mirror of
https://github.com/coleam00/context-engineering-intro.git
synced 2025-12-18 18:25:26 +00:00
217 lines
5.7 KiB
Python
217 lines
5.7 KiB
Python
"""
|
|
Database utilities for PostgreSQL connection and operations.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import asyncio
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from datetime import datetime, timedelta, timezone
|
|
from contextlib import asynccontextmanager
|
|
from uuid import UUID
|
|
import logging
|
|
|
|
import asyncpg
|
|
from asyncpg.pool import Pool
|
|
from dotenv import load_dotenv
|
|
|
|
# Populate the process environment from a .env file (if any) before anything
# below reads settings such as DATABASE_URL
load_dotenv()

# Module-scoped logger, named after this module
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabasePool:
    """Manages PostgreSQL connection pool.

    The underlying asyncpg pool is created lazily: either explicitly through
    ``initialize()`` or implicitly on the first ``acquire()``.
    """

    def __init__(self, database_url: Optional[str] = None):
        """
        Initialize database pool.

        No connection is opened here; only the configuration is stored.

        Args:
            database_url: PostgreSQL connection URL. When omitted (or empty),
                the ``DATABASE_URL`` environment variable is used instead.

        Raises:
            ValueError: If neither the argument nor ``DATABASE_URL`` provides
                a URL.
        """
        self.database_url = database_url or os.getenv("DATABASE_URL")
        if not self.database_url:
            raise ValueError("DATABASE_URL environment variable not set")

        self.pool: Optional[Pool] = None
        # Built lazily inside initialize() so the lock is never constructed
        # outside a running event loop (instances may be created at import).
        self._init_lock: Optional[asyncio.Lock] = None

    async def initialize(self):
        """
        Create connection pool.

        Does nothing when the pool already exists. Concurrent callers are
        serialized on a lock so that only one pool is ever created; without
        it, two tasks racing through the "pool is missing" check would each
        open a pool and one of them would be leaked.
        """
        if self.pool is not None:
            return

        if self._init_lock is None:
            self._init_lock = asyncio.Lock()

        async with self._init_lock:
            # Re-check: another task may have finished creating the pool
            # while this one was waiting for the lock.
            if self.pool is None:
                self.pool = await asyncpg.create_pool(
                    self.database_url,
                    min_size=5,
                    max_size=20,
                    max_inactive_connection_lifetime=300,
                    command_timeout=60,
                )
                logger.info("Database connection pool initialized")

    async def close(self):
        """Close connection pool, if one is open, and forget it."""
        if self.pool is not None:
            await self.pool.close()
            self.pool = None
            logger.info("Database connection pool closed")

    @asynccontextmanager
    async def acquire(self):
        """
        Acquire a connection from the pool.

        Creates the pool first when it does not exist yet. Use as
        ``async with pool.acquire() as conn:``; the connection is handed back
        to the pool when the ``async with`` block exits.

        Yields:
            A connection borrowed from the pool.
        """
        if self.pool is None:
            await self.initialize()

        async with self.pool.acquire() as connection:
            yield connection
|
|
|
|
|
|
# Shared, module-level pool used by the helper functions in this module.
# Constructing it reads DATABASE_URL, so importing this module raises
# ValueError when that variable is not set.
db_pool = DatabasePool()
|
|
|
|
|
|
async def initialize_database():
    """Create the shared connection pool by delegating to ``db_pool.initialize()``."""
    await db_pool.initialize()
|
|
|
|
|
|
async def close_database():
    """Shut down the shared connection pool by delegating to ``db_pool.close()``."""
    await db_pool.close()
|
|
|
|
# Document Management Functions
|
|
async def get_document(document_id: str) -> Optional[Dict[str, Any]]:
    """
    Get document by ID.

    Args:
        document_id: Document UUID (as a string; cast to ``uuid`` in SQL)

    Returns:
        Document data or None if not found. The dict carries ``id``,
        ``title``, ``source``, ``content``, ``metadata`` (decoded JSON) and
        ``created_at`` / ``updated_at`` as ISO-8601 strings.
    """
    async with db_pool.acquire() as conn:
        result = await conn.fetchrow(
            """
            SELECT
                id::text,
                title,
                source,
                content,
                metadata,
                created_at,
                updated_at
            FROM documents
            WHERE id = $1::uuid
            """,
            document_id,
        )

    if not result:
        return None

    # asyncpg returns json/jsonb as text unless a codec is registered on the
    # connection. Tolerate a NULL column and an already-decoded value instead
    # of letting json.loads raise TypeError.
    raw_metadata = result["metadata"]
    if raw_metadata is None:
        metadata = {}
    elif isinstance(raw_metadata, str):
        metadata = json.loads(raw_metadata)
    else:
        metadata = raw_metadata

    return {
        "id": result["id"],
        "title": result["title"],
        "source": result["source"],
        "content": result["content"],
        "metadata": metadata,
        "created_at": result["created_at"].isoformat(),
        "updated_at": result["updated_at"].isoformat(),
    }
|
|
|
|
|
|
async def list_documents(
    limit: int = 100,
    offset: int = 0,
    metadata_filter: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    List documents with optional filtering.

    Documents are returned newest first (``created_at`` descending), each
    with the number of chunks that reference it.

    Args:
        limit: Maximum number of documents to return
        offset: Number of documents to skip
        metadata_filter: Optional metadata filter; only documents whose
            metadata contains this JSON object (``@>``) are returned

    Returns:
        List of documents, each with ``id``, ``title``, ``source``,
        ``metadata`` (decoded JSON), ISO-8601 ``created_at`` / ``updated_at``
        and ``chunk_count``
    """
    async with db_pool.acquire() as conn:
        query = """
            SELECT
                d.id::text,
                d.title,
                d.source,
                d.metadata,
                d.created_at,
                d.updated_at,
                COUNT(c.id) AS chunk_count
            FROM documents d
            LEFT JOIN chunks c ON d.id = c.document_id
        """

        params: List[Any] = []
        conditions: List[str] = []

        if metadata_filter:
            params.append(json.dumps(metadata_filter))
            conditions.append(f"d.metadata @> ${len(params)}::jsonb")

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # Number the LIMIT/OFFSET placeholders after whatever the filters
        # used. Only placeholder indexes are interpolated into the SQL text;
        # every value travels as a bound parameter.
        params.extend([limit, offset])
        limit_pos = len(params) - 1
        offset_pos = len(params)

        query += f"""
            GROUP BY d.id, d.title, d.source, d.metadata, d.created_at, d.updated_at
            ORDER BY d.created_at DESC
            LIMIT ${limit_pos} OFFSET ${offset_pos}
        """

        results = await conn.fetch(query, *params)

    documents: List[Dict[str, Any]] = []
    for row in results:
        # asyncpg returns json/jsonb as text unless a codec is registered on
        # the connection. Tolerate a NULL column and an already-decoded value
        # instead of letting json.loads raise TypeError.
        raw_metadata = row["metadata"]
        if raw_metadata is None:
            metadata = {}
        elif isinstance(raw_metadata, str):
            metadata = json.loads(raw_metadata)
        else:
            metadata = raw_metadata

        documents.append(
            {
                "id": row["id"],
                "title": row["title"],
                "source": row["source"],
                "metadata": metadata,
                "created_at": row["created_at"].isoformat(),
                "updated_at": row["updated_at"].isoformat(),
                "chunk_count": row["chunk_count"],
            }
        )

    return documents
|
|
|
|
# Utility Functions
|
|
async def execute_query(query: str, *params) -> List[Dict[str, Any]]:
    """
    Run an arbitrary SQL statement on a pooled connection.

    The query text is sent as given; values should be supplied through
    ``params`` so they are bound to the statement's placeholders.

    Args:
        query: SQL query
        *params: Query parameters

    Returns:
        Query results, one plain dict per returned row
    """
    async with db_pool.acquire() as conn:
        rows = await conn.fetch(query, *params)
        return list(map(dict, rows))
|
|
|
|
|
|
async def test_connection() -> bool:
    """
    Check that the database is reachable by running ``SELECT 1``.

    Any exception raised while acquiring a connection or running the probe
    is logged and reported as a failure rather than propagated.

    Returns:
        True if connection successful, False otherwise
    """
    try:
        async with db_pool.acquire() as conn:
            await conn.fetchval("SELECT 1")
    except Exception as e:
        logger.error(f"Database connection test failed: {e}")
        return False

    return True