"""
Database utilities for PostgreSQL connection and operations.
"""
import os
import json
import asyncio
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timedelta, timezone
from contextlib import asynccontextmanager
from uuid import UUID
import logging
import asyncpg
from asyncpg.pool import Pool
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
logger = logging.getLogger(__name__)
class DatabasePool:
"""Manages PostgreSQL connection pool."""
    def __init__(self, database_url: Optional[str] = None):
        """
        Initialize database pool.

        Args:
            database_url: PostgreSQL connection URL
        """
        self.database_url = database_url or os.getenv("DATABASE_URL")
        if not self.database_url:
            raise ValueError("DATABASE_URL environment variable not set")
        self.pool: Optional[Pool] = None

    async def initialize(self):
        """Create connection pool."""
        if not self.pool:
            self.pool = await asyncpg.create_pool(
                self.database_url,
                min_size=5,
                max_size=20,
                max_inactive_connection_lifetime=300,
                command_timeout=60
            )
            logger.info("Database connection pool initialized")

    async def close(self):
        """Close connection pool."""
        if self.pool:
            await self.pool.close()
            self.pool = None
            logger.info("Database connection pool closed")

    @asynccontextmanager
    async def acquire(self):
        """Acquire a connection from the pool."""
        if not self.pool:
            await self.initialize()
        async with self.pool.acquire() as connection:
            yield connection
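

# Note: `async with db_pool.acquire() as conn:` lazily creates the pool on
# first use, so callers do not strictly need to call initialize() first; the
# module-level helpers below give explicit control over the pool lifecycle.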

# Global database pool instance
db_pool = DatabasePool()


async def initialize_database():
    """Initialize database connection pool."""
    await db_pool.initialize()


async def close_database():
    """Close database connection pool."""
    await db_pool.close()


# Document Management Functions

async def get_document(document_id: str) -> Optional[Dict[str, Any]]:
    """
    Get document by ID.

    Args:
        document_id: Document UUID

    Returns:
        Document data or None if not found
    """
    async with db_pool.acquire() as conn:
        result = await conn.fetchrow(
            """
            SELECT
                id::text,
                title,
                source,
                content,
                metadata,
                created_at,
                updated_at
            FROM documents
            WHERE id = $1::uuid
            """,
            document_id
        )

        if result:
            return {
                "id": result["id"],
                "title": result["title"],
                "source": result["source"],
                "content": result["content"],
                # Guard against NULL metadata so json.loads never sees None
                "metadata": json.loads(result["metadata"]) if result["metadata"] else {},
                "created_at": result["created_at"].isoformat(),
                "updated_at": result["updated_at"].isoformat()
            }

        return None


async def list_documents(
    limit: int = 100,
    offset: int = 0,
    metadata_filter: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    List documents with optional filtering.

    Args:
        limit: Maximum number of documents to return
        offset: Number of documents to skip
        metadata_filter: Optional metadata filter

    Returns:
        List of documents
    """
    async with db_pool.acquire() as conn:
        query = """
            SELECT
                d.id::text,
                d.title,
                d.source,
                d.metadata,
                d.created_at,
                d.updated_at,
                COUNT(c.id) AS chunk_count
            FROM documents d
            LEFT JOIN chunks c ON d.id = c.document_id
        """

        params = []
        conditions = []

        if metadata_filter:
            conditions.append(f"d.metadata @> ${len(params) + 1}::jsonb")
            params.append(json.dumps(metadata_filter))

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        query += f"""
            GROUP BY d.id, d.title, d.source, d.metadata, d.created_at, d.updated_at
            ORDER BY d.created_at DESC
            LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}
        """
        params.extend([limit, offset])

        results = await conn.fetch(query, *params)

        return [
            {
                "id": row["id"],
                "title": row["title"],
                "source": row["source"],
                # Guard against NULL metadata so json.loads never sees None
                "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
                "created_at": row["created_at"].isoformat(),
                "updated_at": row["updated_at"].isoformat(),
                "chunk_count": row["chunk_count"]
            }
            for row in results
        ]
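
# Example (illustrative): list_documents applies metadata_filter via JSONB
# containment (@>), so `await list_documents(metadata_filter={"source": "web"})`
# returns only documents whose metadata contains {"source": "web"}.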


# Utility Functions

async def execute_query(query: str, *params) -> List[Dict[str, Any]]:
    """
    Execute a custom query.

    Args:
        query: SQL query
        *params: Query parameters

    Returns:
        Query results
    """
    async with db_pool.acquire() as conn:
        results = await conn.fetch(query, *params)
        return [dict(row) for row in results]


async def test_connection() -> bool:
    """
    Test database connection.

    Returns:
        True if connection successful
    """
    try:
        async with db_pool.acquire() as conn:
            await conn.fetchval("SELECT 1")
            return True
    except Exception as e:
        logger.error(f"Database connection test failed: {e}")
        return False
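

# Minimal usage sketch (illustrative, not part of the original module): assumes
# DATABASE_URL points at a reachable PostgreSQL instance that already has the
# documents and chunks tables.
async def _smoke_test():
    """Connect, run a trivial check, list a few documents, then clean up."""
    await initialize_database()
    try:
        if await test_connection():
            docs = await list_documents(limit=5)
            logger.info(f"Connection OK, {len(docs)} documents visible")
        else:
            logger.error("Connection test failed")
    finally:
        await close_database()


if __name__ == "__main__":
    asyncio.run(_smoke_test())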