""" Pydantic models for data validation and serialization. """ from typing import List, Dict, Any, Optional, Literal from datetime import datetime from uuid import UUID from pydantic import BaseModel, Field, ConfigDict, field_validator from enum import Enum # Enums class SearchType(str, Enum): """Search type enum.""" SEMANTIC = "semantic" KEYWORD = "keyword" HYBRID = "hybrid" class MessageRole(str, Enum): """Message role enum.""" USER = "user" ASSISTANT = "assistant" SYSTEM = "system" # Request Models class SearchRequest(BaseModel): """Search request model.""" query: str = Field(..., description="Search query") search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search") limit: int = Field(default=10, ge=1, le=50, description="Maximum results") filters: Dict[str, Any] = Field(default_factory=dict, description="Search filters") model_config = ConfigDict(use_enum_values=True) # Response Models class DocumentMetadata(BaseModel): """Document metadata model.""" id: str title: str source: str metadata: Dict[str, Any] = Field(default_factory=dict) created_at: datetime updated_at: datetime chunk_count: Optional[int] = None class ChunkResult(BaseModel): """Chunk search result model.""" chunk_id: str document_id: str content: str score: float metadata: Dict[str, Any] = Field(default_factory=dict) document_title: str document_source: str @field_validator('score') @classmethod def validate_score(cls, v: float) -> float: """Ensure score is between 0 and 1.""" return max(0.0, min(1.0, v)) class SearchResponse(BaseModel): """Search response model.""" results: List[ChunkResult] = Field(default_factory=list) total_results: int = 0 search_type: SearchType query_time_ms: float class ToolCall(BaseModel): """Tool call information model.""" tool_name: str args: Dict[str, Any] = Field(default_factory=dict) tool_call_id: Optional[str] = None class ChatResponse(BaseModel): """Chat response model.""" message: str session_id: str sources: List[DocumentMetadata] = Field(default_factory=list) tools_used: List[ToolCall] = Field(default_factory=list) metadata: Dict[str, Any] = Field(default_factory=dict) class StreamDelta(BaseModel): """Streaming response delta.""" content: str delta_type: Literal["text", "tool_call", "end"] = "text" metadata: Dict[str, Any] = Field(default_factory=dict) # Database Models class Document(BaseModel): """Document model.""" id: Optional[str] = None title: str source: str content: str metadata: Dict[str, Any] = Field(default_factory=dict) created_at: Optional[datetime] = None updated_at: Optional[datetime] = None class Chunk(BaseModel): """Document chunk model.""" id: Optional[str] = None document_id: str content: str embedding: Optional[List[float]] = None chunk_index: int metadata: Dict[str, Any] = Field(default_factory=dict) token_count: Optional[int] = None created_at: Optional[datetime] = None @field_validator('embedding') @classmethod def validate_embedding(cls, v: Optional[List[float]]) -> Optional[List[float]]: """Validate embedding dimensions.""" if v is not None and len(v) != 1536: # OpenAI text-embedding-3-small raise ValueError(f"Embedding must have 1536 dimensions, got {len(v)}") return v class Session(BaseModel): """Session model.""" id: Optional[str] = None user_id: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) created_at: Optional[datetime] = None updated_at: Optional[datetime] = None expires_at: Optional[datetime] = None class Message(BaseModel): """Message model.""" id: Optional[str] = None session_id: str role: 
MessageRole content: str metadata: Dict[str, Any] = Field(default_factory=dict) created_at: Optional[datetime] = None model_config = ConfigDict(use_enum_values=True) # Agent Models class AgentDependencies(BaseModel): """Dependencies for the agent.""" session_id: str database_url: Optional[str] = None openai_api_key: Optional[str] = None model_config = ConfigDict(arbitrary_types_allowed=True) class AgentContext(BaseModel): """Agent execution context.""" session_id: str messages: List[Message] = Field(default_factory=list) tool_calls: List[ToolCall] = Field(default_factory=list) search_results: List[ChunkResult] = Field(default_factory=list) metadata: Dict[str, Any] = Field(default_factory=dict) # Ingestion Models class IngestionConfig(BaseModel): """Configuration for document ingestion.""" chunk_size: int = Field(default=1000, ge=100, le=5000) chunk_overlap: int = Field(default=200, ge=0, le=1000) max_chunk_size: int = Field(default=2000, ge=500, le=10000) use_semantic_chunking: bool = True @field_validator('chunk_overlap') @classmethod def validate_overlap(cls, v: int, info) -> int: """Ensure overlap is less than chunk size.""" chunk_size = info.data.get('chunk_size', 1000) if v >= chunk_size: raise ValueError(f"Chunk overlap ({v}) must be less than chunk size ({chunk_size})") return v class IngestionResult(BaseModel): """Result of document ingestion.""" document_id: str title: str chunks_created: int processing_time_ms: float errors: List[str] = Field(default_factory=list)
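

# --- Usage sketch (illustrative, not part of the original module) ------------
# A minimal, hedged example of how these models behave under Pydantic v2:
# enum coercion with use_enum_values, the score-clamping validator, the
# embedding-dimension check, and the chunk_overlap constraint. The sample
# values ("vector databases", "doc-1", etc.) are hypothetical placeholders.
if __name__ == "__main__":
    from pydantic import ValidationError

    # use_enum_values=True stores the plain string; str inputs are coerced to SearchType.
    req = SearchRequest(query="vector databases", search_type="hybrid", limit=5)
    assert req.search_type == "hybrid"

    # validate_score clamps out-of-range scores into [0.0, 1.0].
    hit = ChunkResult(
        chunk_id="chunk-1",
        document_id="doc-1",
        content="example chunk text",
        score=1.7,
        document_title="Example Doc",
        document_source="example.md",
    )
    assert hit.score == 1.0

    # validate_embedding rejects vectors that are not 1536-dimensional.
    try:
        Chunk(document_id="doc-1", content="example", embedding=[0.1, 0.2], chunk_index=0)
    except ValidationError as exc:
        print(f"embedding rejected as expected: {exc.error_count()} error(s)")

    # validate_overlap enforces chunk_overlap < chunk_size.
    try:
        IngestionConfig(chunk_size=500, chunk_overlap=600)
    except ValidationError as exc:
        print(f"overlap rejected as expected: {exc.error_count()} error(s)")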