mirror of
https://github.com/coleam00/context-engineering-intro.git
synced 2025-12-29 16:14:56 +00:00
AI Agent Factory with Claude Code Subagents
This commit is contained in:
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Database utilities for PostgreSQL connection and operations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import UUID
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
from asyncpg.pool import Pool
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabasePool:
    """Thin wrapper that owns a lazily created asyncpg connection pool."""

    def __init__(self, database_url: Optional[str] = None):
        """
        Initialize the pool wrapper (no connections are opened yet).

        Args:
            database_url: PostgreSQL connection URL; falls back to the
                DATABASE_URL environment variable when omitted.

        Raises:
            ValueError: If no connection URL can be resolved.
        """
        resolved = database_url or os.getenv("DATABASE_URL")
        if not resolved:
            raise ValueError("DATABASE_URL environment variable not set")
        self.database_url = resolved
        # Created on first initialize()/acquire(); None while closed.
        self.pool: Optional[Pool] = None

    async def initialize(self):
        """Create the underlying connection pool (idempotent)."""
        if self.pool:
            return
        self.pool = await asyncpg.create_pool(
            self.database_url,
            min_size=5,
            max_size=20,
            # Recycle connections idle longer than 5 minutes.
            max_inactive_connection_lifetime=300,
            command_timeout=60,
        )
        logger.info("Database connection pool initialized")

    async def close(self):
        """Dispose of the pool, if one exists."""
        if not self.pool:
            return
        await self.pool.close()
        self.pool = None
        logger.info("Database connection pool closed")

    @asynccontextmanager
    async def acquire(self):
        """Yield a pooled connection, initializing the pool on first use."""
        if self.pool is None:
            await self.initialize()
        async with self.pool.acquire() as connection:
            yield connection
|
||||
|
||||
|
||||
# Global database pool instance shared by the module-level helpers below.
# NOTE: constructing it requires DATABASE_URL to be set (raises ValueError otherwise).
db_pool = DatabasePool()
|
||||
|
||||
|
||||
async def initialize_database():
    """Initialize the shared database connection pool (app-startup hook)."""
    await db_pool.initialize()
|
||||
|
||||
|
||||
async def close_database():
    """Close the shared database connection pool (app-shutdown hook)."""
    await db_pool.close()
|
||||
|
||||
# Document Management Functions
async def get_document(document_id: str) -> Optional[Dict[str, Any]]:
    """
    Get document by ID.

    Args:
        document_id: Document UUID (as text; cast to uuid in SQL)

    Returns:
        Document data dict (timestamps as ISO-8601 strings, metadata as a
        dict) or None if not found.
    """
    async with db_pool.acquire() as conn:
        result = await conn.fetchrow(
            """
            SELECT
                id::text,
                title,
                source,
                content,
                metadata,
                created_at,
                updated_at
            FROM documents
            WHERE id = $1::uuid
            """,
            document_id
        )

        if result is None:
            return None

        # asyncpg returns jsonb as JSON text by default; a NULL metadata
        # column comes back as None, so guard before decoding to avoid a
        # TypeError from json.loads(None).
        raw_metadata = result["metadata"]
        return {
            "id": result["id"],
            "title": result["title"],
            "source": result["source"],
            "content": result["content"],
            "metadata": json.loads(raw_metadata) if raw_metadata else {},
            "created_at": result["created_at"].isoformat(),
            "updated_at": result["updated_at"].isoformat(),
        }
|
||||
|
||||
|
||||
async def list_documents(
    limit: int = 100,
    offset: int = 0,
    metadata_filter: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """
    List documents with optional filtering.

    Args:
        limit: Maximum number of documents to return
        offset: Number of documents to skip
        metadata_filter: Optional metadata filter (matched with jsonb @>
            containment)

    Returns:
        List of document dicts, newest first, each including a chunk_count.
    """
    async with db_pool.acquire() as conn:
        query = """
            SELECT
                d.id::text,
                d.title,
                d.source,
                d.metadata,
                d.created_at,
                d.updated_at,
                COUNT(c.id) AS chunk_count
            FROM documents d
            LEFT JOIN chunks c ON d.id = c.document_id
        """

        params: List[Any] = []
        conditions: List[str] = []

        if metadata_filter:
            # Placeholder index tracks params so later LIMIT/OFFSET stay correct.
            conditions.append(f"d.metadata @> ${len(params) + 1}::jsonb")
            params.append(json.dumps(metadata_filter))

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # f-string placeholders for consistency with the filter clause above.
        query += f"""
            GROUP BY d.id, d.title, d.source, d.metadata, d.created_at, d.updated_at
            ORDER BY d.created_at DESC
            LIMIT ${len(params) + 1} OFFSET ${len(params) + 2}
        """

        params.extend([limit, offset])

        results = await conn.fetch(query, *params)

        return [
            {
                "id": row["id"],
                "title": row["title"],
                "source": row["source"],
                # jsonb arrives as JSON text; NULL arrives as None — guard it.
                "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
                "created_at": row["created_at"].isoformat(),
                "updated_at": row["updated_at"].isoformat(),
                "chunk_count": row["chunk_count"],
            }
            for row in results
        ]
|
||||
|
||||
# Utility Functions
async def execute_query(query: str, *params) -> List[Dict[str, Any]]:
    """
    Execute a custom query.

    Args:
        query: SQL query
        *params: Positional query parameters ($1, $2, ...)

    Returns:
        Query results as a list of plain dicts
    """
    async with db_pool.acquire() as conn:
        rows = await conn.fetch(query, *params)
        return list(map(dict, rows))
|
||||
|
||||
|
||||
async def test_connection() -> bool:
    """
    Test database connection.

    Returns:
        True if a trivial round-trip query succeeds, False otherwise
        (the failure is logged, never raised).
    """
    try:
        async with db_pool.acquire() as conn:
            await conn.fetchval("SELECT 1")
    except Exception as e:
        logger.error(f"Database connection test failed: {e}")
        return False
    return True
|
||||
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Pydantic models for data validation and serialization.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator
|
||||
from enum import Enum
|
||||
|
||||
# Enums
class SearchType(str, Enum):
    """Supported search strategies (string-valued for JSON serialization)."""
    SEMANTIC = "semantic"
    KEYWORD = "keyword"
    HYBRID = "hybrid"
|
||||
|
||||
class MessageRole(str, Enum):
    """Chat message author role (string-valued for JSON serialization)."""
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
|
||||
|
||||
# Request Models
class SearchRequest(BaseModel):
    """Incoming search request payload."""
    query: str = Field(..., description="Search query")
    search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search")
    limit: int = Field(default=10, ge=1, le=50, description="Maximum results")
    filters: Dict[str, Any] = Field(default_factory=dict, description="Search filters")

    # Serialize enum members as their string values (e.g. "semantic").
    model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
# Response Models
class DocumentMetadata(BaseModel):
    """Document-level metadata returned in responses (no content body)."""
    id: str  # document UUID as text
    title: str
    source: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    created_at: datetime
    updated_at: datetime
    # Only populated by queries that count chunks (e.g. listing); else None.
    chunk_count: Optional[int] = None
|
||||
|
||||
|
||||
class ChunkResult(BaseModel):
    """A single chunk hit returned from search."""
    chunk_id: str
    document_id: str
    content: str
    score: float  # relevance score; clamped to [0, 1] by the validator below
    metadata: Dict[str, Any] = Field(default_factory=dict)
    document_title: str
    document_source: str

    @field_validator('score')
    @classmethod
    def validate_score(cls, v: float) -> float:
        """Ensure score is between 0 and 1."""
        # Clamp rather than reject: out-of-range scores are silently coerced.
        return max(0.0, min(1.0, v))
|
||||
|
||||
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
    """Search response envelope."""
    results: List[ChunkResult] = Field(default_factory=list)
    total_results: int = 0
    search_type: SearchType  # strategy actually used for this query
    query_time_ms: float  # wall-clock query time in milliseconds
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """Record of a single tool invocation made by the agent."""
    tool_name: str
    args: Dict[str, Any] = Field(default_factory=dict)
    tool_call_id: Optional[str] = None  # provider-assigned id, when available
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
    """Complete (non-streaming) chat response."""
    message: str
    session_id: str
    sources: List[DocumentMetadata] = Field(default_factory=list)  # cited documents
    tools_used: List[ToolCall] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StreamDelta(BaseModel):
    """One incremental unit of a streaming response."""
    content: str
    # "text" for content tokens, "tool_call" for tool events, "end" for stream close.
    delta_type: Literal["text", "tool_call", "end"] = "text"
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Database Models
class Document(BaseModel):
    """A stored document row, including its full content."""
    id: Optional[str] = None  # assigned by the database on insert
    title: str
    source: str
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class Chunk(BaseModel):
    """A chunk of a document, optionally carrying its embedding vector."""
    id: Optional[str] = None  # assigned by the database on insert
    document_id: str
    content: str
    embedding: Optional[List[float]] = None  # must be exactly 1536 floats when present
    chunk_index: int  # position of this chunk within its document
    metadata: Dict[str, Any] = Field(default_factory=dict)
    token_count: Optional[int] = None
    created_at: Optional[datetime] = None

    @field_validator('embedding')
    @classmethod
    def validate_embedding(cls, v: Optional[List[float]]) -> Optional[List[float]]:
        """Validate embedding dimensions."""
        # 1536 matches OpenAI's text-embedding-3-small output dimension.
        if v is not None and len(v) != 1536:
            raise ValueError(f"Embedding must have 1536 dimensions, got {len(v)}")
        return v
|
||||
|
||||
|
||||
class Session(BaseModel):
    """A conversation session row."""
    id: Optional[str] = None  # assigned by the database on insert
    user_id: Optional[str] = None  # anonymous sessions allowed
    metadata: Dict[str, Any] = Field(default_factory=dict)
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None  # None means no expiry — TODO confirm
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single chat message belonging to a session."""
    id: Optional[str] = None  # assigned by the database on insert
    session_id: str
    role: MessageRole
    content: str
    metadata: Dict[str, Any] = Field(default_factory=dict)
    created_at: Optional[datetime] = None

    # Serialize role as its string value ("user" / "assistant" / "system").
    model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
# Agent Models
class AgentDependencies(BaseModel):
    """Dependencies injected into the agent at run time."""
    session_id: str
    database_url: Optional[str] = None  # presumably falls back to env config when None — verify at call site
    openai_api_key: Optional[str] = None

    # Permit non-pydantic dependency types if richer objects are added later.
    model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
    """Mutable state accumulated during one agent run."""
    session_id: str
    messages: List[Message] = Field(default_factory=list)
    tool_calls: List[ToolCall] = Field(default_factory=list)
    search_results: List[ChunkResult] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# Ingestion Models
class IngestionConfig(BaseModel):
    """Configuration for document ingestion and chunking."""
    chunk_size: int = Field(default=1000, ge=100, le=5000)
    chunk_overlap: int = Field(default=200, ge=0, le=1000)
    max_chunk_size: int = Field(default=2000, ge=500, le=10000)
    use_semantic_chunking: bool = True

    @field_validator('chunk_overlap')
    @classmethod
    def validate_overlap(cls, v: int, info) -> int:
        """Ensure overlap is less than chunk size."""
        # Relies on chunk_size being declared BEFORE chunk_overlap, so it is
        # already present in info.data when this validator runs.
        chunk_size = info.data.get('chunk_size', 1000)
        if v >= chunk_size:
            raise ValueError(f"Chunk overlap ({v}) must be less than chunk size ({chunk_size})")
        return v
|
||||
|
||||
|
||||
class IngestionResult(BaseModel):
    """Summary of one document ingestion run."""
    document_id: str
    title: str
    chunks_created: int
    processing_time_ms: float  # wall-clock ingestion time in milliseconds
    errors: List[str] = Field(default_factory=list)  # empty list means full success
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Simplified provider configuration for OpenAI models only.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from pydantic_ai.models.openai import OpenAIModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
import openai
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def get_llm_model() -> OpenAIModel:
    """
    Build the configured OpenAI chat model.

    Reads LLM_CHOICE (default 'gpt-4.1-mini') and OPENAI_API_KEY from the
    environment.

    Returns:
        Configured OpenAI model

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    model_name = os.getenv('LLM_CHOICE', 'gpt-4.1-mini')
    key = os.getenv('OPENAI_API_KEY')
    if not key:
        raise ValueError("OPENAI_API_KEY environment variable is required")
    return OpenAIModel(model_name, provider=OpenAIProvider(api_key=key))
|
||||
|
||||
|
||||
def get_embedding_client() -> openai.AsyncOpenAI:
    """
    Get OpenAI client for embeddings.

    Returns:
        Configured async OpenAI client for embedding calls

    Raises:
        ValueError: If OPENAI_API_KEY is not set.
    """
    key = os.getenv('OPENAI_API_KEY')
    if not key:
        raise ValueError("OPENAI_API_KEY environment variable is required")
    return openai.AsyncOpenAI(api_key=key)
|
||||
|
||||
|
||||
def get_embedding_model() -> str:
    """
    Get embedding model name.

    Returns:
        Embedding model name from EMBEDDING_MODEL, defaulting to
        'text-embedding-3-small'.
    """
    return os.environ.get('EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
|
||||
|
||||
def get_ingestion_model() -> OpenAIModel:
    """
    Get model for ingestion tasks (uses same model as main LLM).

    Returns:
        Configured model for ingestion tasks
    """
    # Ingestion deliberately reuses the chat model; change here if a
    # cheaper/faster model is ever wanted for ingestion only.
    return get_llm_model()
|
||||
|
||||
|
||||
def validate_configuration() -> bool:
    """
    Validate that required environment variables are set.

    Prints the names of any missing variables.

    Returns:
        True if configuration is valid, False otherwise
    """
    required = ('OPENAI_API_KEY', 'DATABASE_URL')
    missing_vars = [name for name in required if not os.getenv(name)]

    if missing_vars:
        print(f"Missing required environment variables: {', '.join(missing_vars)}")
        return False
    return True
|
||||
|
||||
|
||||
def get_model_info() -> dict:
    """
    Get information about current model configuration.

    Returns:
        Dictionary with model configuration info (providers and model names)
    """
    info = {
        "llm_provider": "openai",
        "llm_model": os.getenv('LLM_CHOICE', 'gpt-4.1-mini'),
        "embedding_provider": "openai",
        # Same env var and default as the embedding-model helper.
        "embedding_model": os.getenv('EMBEDDING_MODEL', 'text-embedding-3-small'),
    }
    return info
|
||||
Reference in New Issue
Block a user