"""
Pydantic models for data validation and serialization.
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator


# Enums
class SearchType(str, Enum):
"""Search type enum."""
SEMANTIC = "semantic"
KEYWORD = "keyword"
    HYBRID = "hybrid"


class MessageRole(str, Enum):
"""Message role enum."""
USER = "user"
ASSISTANT = "assistant"
    SYSTEM = "system"


# Request Models
class SearchRequest(BaseModel):
"""Search request model."""
query: str = Field(..., description="Search query")
search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search")
limit: int = Field(default=10, ge=1, le=50, description="Maximum results")
filters: Dict[str, Any] = Field(default_factory=dict, description="Search filters")
model_config = ConfigDict(use_enum_values=True)
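
# Example usage (illustrative sketch; the query text is made up):
#     req = SearchRequest(query="reset password", search_type=SearchType.HYBRID)
#     req.search_type   # -> "hybrid" (a plain string, because of use_enum_values)
#     req.model_dump()  # -> {"query": "reset password", "search_type": "hybrid",
#                       #     "limit": 10, "filters": {}}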


# Response Models
class DocumentMetadata(BaseModel):
"""Document metadata model."""
id: str
title: str
source: str
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: datetime
updated_at: datetime
    chunk_count: Optional[int] = None


class ChunkResult(BaseModel):
"""Chunk search result model."""
chunk_id: str
document_id: str
content: str
score: float
metadata: Dict[str, Any] = Field(default_factory=dict)
document_title: str
    document_source: str

    @field_validator('score')
    @classmethod
    def validate_score(cls, v: float) -> float:
        """Clamp the score to the [0.0, 1.0] range."""
        return max(0.0, min(1.0, v))
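
# Example (illustrative; shows that out-of-range scores are clamped, not rejected):
#     ChunkResult(
#         chunk_id="c1", document_id="d1", content="...", score=1.7,
#         document_title="Example", document_source="example.md",
#     ).score   # -> 1.0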


class SearchResponse(BaseModel):
"""Search response model."""
results: List[ChunkResult] = Field(default_factory=list)
total_results: int = 0
search_type: SearchType
    query_time_ms: float


class ToolCall(BaseModel):
"""Tool call information model."""
tool_name: str
args: Dict[str, Any] = Field(default_factory=dict)
    tool_call_id: Optional[str] = None


class ChatResponse(BaseModel):
"""Chat response model."""
message: str
session_id: str
sources: List[DocumentMetadata] = Field(default_factory=list)
tools_used: List[ToolCall] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)


class StreamDelta(BaseModel):
"""Streaming response delta."""
content: str
delta_type: Literal["text", "tool_call", "end"] = "text"
metadata: Dict[str, Any] = Field(default_factory=dict)
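
# Example of a streamed sequence (illustrative; the chunking shown is arbitrary):
#     StreamDelta(content="Hel")                    # delta_type defaults to "text"
#     StreamDelta(content="lo")
#     StreamDelta(content="", delta_type="end")     # signals the end of the stream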


# Database Models
class Document(BaseModel):
"""Document model."""
id: Optional[str] = None
title: str
source: str
content: str
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None


class Chunk(BaseModel):
"""Document chunk model."""
id: Optional[str] = None
document_id: str
content: str
embedding: Optional[List[float]] = None
chunk_index: int
metadata: Dict[str, Any] = Field(default_factory=dict)
token_count: Optional[int] = None
    created_at: Optional[datetime] = None

    @field_validator('embedding')
@classmethod
def validate_embedding(cls, v: Optional[List[float]]) -> Optional[List[float]]:
"""Validate embedding dimensions."""
if v is not None and len(v) != 1536: # OpenAI text-embedding-3-small
raise ValueError(f"Embedding must have 1536 dimensions, got {len(v)}")
return v
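
# Example (illustrative): the validator rejects vectors of the wrong size.
#     Chunk(document_id="d1", content="...", chunk_index=0,
#           embedding=[0.1, 0.2, 0.3])     # raises ValidationError (3 != 1536)
#     Chunk(document_id="d1", content="...", chunk_index=0,
#           embedding=[0.0] * 1536)        # passes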


class Session(BaseModel):
"""Session model."""
id: Optional[str] = None
user_id: Optional[str] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None


class Message(BaseModel):
"""Message model."""
id: Optional[str] = None
session_id: str
role: MessageRole
content: str
metadata: Dict[str, Any] = Field(default_factory=dict)
created_at: Optional[datetime] = None
model_config = ConfigDict(use_enum_values=True)
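
# Example (illustrative): use_enum_values stores the role as its string value.
#     Message(session_id="s1", role=MessageRole.USER, content="Hi").role   # -> "user"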


# Agent Models
class AgentDependencies(BaseModel):
"""Dependencies for the agent."""
session_id: str
database_url: Optional[str] = None
openai_api_key: Optional[str] = None
    model_config = ConfigDict(arbitrary_types_allowed=True)


class AgentContext(BaseModel):
"""Agent execution context."""
session_id: str
messages: List[Message] = Field(default_factory=list)
tool_calls: List[ToolCall] = Field(default_factory=list)
search_results: List[ChunkResult] = Field(default_factory=list)
    metadata: Dict[str, Any] = Field(default_factory=dict)


# Ingestion Models
class IngestionConfig(BaseModel):
"""Configuration for document ingestion."""
chunk_size: int = Field(default=1000, ge=100, le=5000)
chunk_overlap: int = Field(default=200, ge=0, le=1000)
max_chunk_size: int = Field(default=2000, ge=500, le=10000)
    use_semantic_chunking: bool = True

    @field_validator('chunk_overlap')
    @classmethod
    def validate_overlap(cls, v: int, info: ValidationInfo) -> int:
        """Ensure the overlap is strictly smaller than the chunk size."""
        chunk_size = info.data.get('chunk_size', 1000)
        if v >= chunk_size:
            raise ValueError(f"Chunk overlap ({v}) must be less than chunk size ({chunk_size})")
        return v
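
# Example (illustrative): an overlap at least as large as the chunk size is rejected.
#     IngestionConfig(chunk_size=500, chunk_overlap=600)   # raises ValidationError
#     IngestionConfig(chunk_size=800, chunk_overlap=100)   # passes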


class IngestionResult(BaseModel):
"""Result of document ingestion."""
document_id: str
title: str
chunks_created: int
processing_time_ms: float
errors: List[str] = Field(default_factory=list)
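

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; all values below are made up).
    config = IngestionConfig(chunk_size=800, chunk_overlap=100)
    result = IngestionResult(
        document_id="doc-1",
        title="Example document",
        chunks_created=12,
        processing_time_ms=431.7,
    )
    print(config.model_dump())
    print(result.model_dump_json())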