""" Database utilities for PostgreSQL connection and operations. """ import os import json import asyncio from typing import List, Dict, Any, Optional, Tuple from datetime import datetime, timedelta, timezone from contextlib import asynccontextmanager from uuid import UUID import logging import asyncpg from asyncpg.pool import Pool from dotenv import load_dotenv # Load environment variables load_dotenv() logger = logging.getLogger(__name__) class DatabasePool: """Manages PostgreSQL connection pool.""" def __init__(self, database_url: Optional[str] = None): """ Initialize database pool. Args: database_url: PostgreSQL connection URL """ self.database_url = database_url or os.getenv("DATABASE_URL") if not self.database_url: raise ValueError("DATABASE_URL environment variable not set") self.pool: Optional[Pool] = None async def initialize(self): """Create connection pool.""" if not self.pool: self.pool = await asyncpg.create_pool( self.database_url, min_size=5, max_size=20, max_inactive_connection_lifetime=300, command_timeout=60 ) logger.info("Database connection pool initialized") async def close(self): """Close connection pool.""" if self.pool: await self.pool.close() self.pool = None logger.info("Database connection pool closed") @asynccontextmanager async def acquire(self): """Acquire a connection from the pool.""" if not self.pool: await self.initialize() async with self.pool.acquire() as connection: yield connection # Global database pool instance db_pool = DatabasePool() async def initialize_database(): """Initialize database connection pool.""" await db_pool.initialize() async def close_database(): """Close database connection pool.""" await db_pool.close() # Document Management Functions async def get_document(document_id: str) -> Optional[Dict[str, Any]]: """ Get document by ID. Args: document_id: Document UUID Returns: Document data or None if not found """ async with db_pool.acquire() as conn: result = await conn.fetchrow( """ SELECT id::text, title, source, content, metadata, created_at, updated_at FROM documents WHERE id = $1::uuid """, document_id ) if result: return { "id": result["id"], "title": result["title"], "source": result["source"], "content": result["content"], "metadata": json.loads(result["metadata"]), "created_at": result["created_at"].isoformat(), "updated_at": result["updated_at"].isoformat() } return None async def list_documents( limit: int = 100, offset: int = 0, metadata_filter: Optional[Dict[str, Any]] = None ) -> List[Dict[str, Any]]: """ List documents with optional filtering. Args: limit: Maximum number of documents to return offset: Number of documents to skip metadata_filter: Optional metadata filter Returns: List of documents """ async with db_pool.acquire() as conn: query = """ SELECT d.id::text, d.title, d.source, d.metadata, d.created_at, d.updated_at, COUNT(c.id) AS chunk_count FROM documents d LEFT JOIN chunks c ON d.id = c.document_id """ params = [] conditions = [] if metadata_filter: conditions.append(f"d.metadata @> ${len(params) + 1}::jsonb") params.append(json.dumps(metadata_filter)) if conditions: query += " WHERE " + " AND ".join(conditions) query += """ GROUP BY d.id, d.title, d.source, d.metadata, d.created_at, d.updated_at ORDER BY d.created_at DESC LIMIT $%d OFFSET $%d """ % (len(params) + 1, len(params) + 2) params.extend([limit, offset]) results = await conn.fetch(query, *params) return [ { "id": row["id"], "title": row["title"], "source": row["source"], "metadata": json.loads(row["metadata"]), "created_at": row["created_at"].isoformat(), "updated_at": row["updated_at"].isoformat(), "chunk_count": row["chunk_count"] } for row in results ] # Utility Functions async def execute_query(query: str, *params) -> List[Dict[str, Any]]: """ Execute a custom query. Args: query: SQL query *params: Query parameters Returns: Query results """ async with db_pool.acquire() as conn: results = await conn.fetch(query, *params) return [dict(row) for row in results] async def test_connection() -> bool: """ Test database connection. Returns: True if connection successful """ try: async with db_pool.acquire() as conn: await conn.fetchval("SELECT 1") return True except Exception as e: logger.error(f"Database connection test failed: {e}") return False