"""
|
|
Pydantic models for data validation and serialization.
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional, Literal
|
|
from datetime import datetime
|
|
from uuid import UUID
|
|
from pydantic import BaseModel, Field, ConfigDict, field_validator
|
|
from enum import Enum
|
|
|
|
# Enums
|
|
class SearchType(str, Enum):
|
|
"""Search type enum."""
|
|
SEMANTIC = "semantic"
|
|
KEYWORD = "keyword"
|
|
HYBRID = "hybrid"
|
|
|
|
class MessageRole(str, Enum):
|
|
"""Message role enum."""
|
|
USER = "user"
|
|
ASSISTANT = "assistant"
|
|
SYSTEM = "system"
|
|
|
|
# Request Models
|
|
class SearchRequest(BaseModel):
|
|
"""Search request model."""
|
|
query: str = Field(..., description="Search query")
|
|
search_type: SearchType = Field(default=SearchType.SEMANTIC, description="Type of search")
|
|
limit: int = Field(default=10, ge=1, le=50, description="Maximum results")
|
|
filters: Dict[str, Any] = Field(default_factory=dict, description="Search filters")
|
|
|
|
model_config = ConfigDict(use_enum_values=True)
|
|
|
|
|
|
# Response Models
|
|
class DocumentMetadata(BaseModel):
|
|
"""Document metadata model."""
|
|
id: str
|
|
title: str
|
|
source: str
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
created_at: datetime
|
|
updated_at: datetime
|
|
chunk_count: Optional[int] = None
|
|
|
|
|
|
class ChunkResult(BaseModel):
|
|
"""Chunk search result model."""
|
|
chunk_id: str
|
|
document_id: str
|
|
content: str
|
|
score: float
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
document_title: str
|
|
document_source: str
|
|
|
|
@field_validator('score')
|
|
@classmethod
|
|
def validate_score(cls, v: float) -> float:
|
|
"""Ensure score is between 0 and 1."""
|
|
return max(0.0, min(1.0, v))
|
|
|
|
|
|
|
|
|
|
class SearchResponse(BaseModel):
|
|
"""Search response model."""
|
|
results: List[ChunkResult] = Field(default_factory=list)
|
|
total_results: int = 0
|
|
search_type: SearchType
|
|
query_time_ms: float
|
|
|
|
|
|
class ToolCall(BaseModel):
|
|
"""Tool call information model."""
|
|
tool_name: str
|
|
args: Dict[str, Any] = Field(default_factory=dict)
|
|
tool_call_id: Optional[str] = None
|
|
|
|
|
|
class ChatResponse(BaseModel):
|
|
"""Chat response model."""
|
|
message: str
|
|
session_id: str
|
|
sources: List[DocumentMetadata] = Field(default_factory=list)
|
|
tools_used: List[ToolCall] = Field(default_factory=list)
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class StreamDelta(BaseModel):
|
|
"""Streaming response delta."""
|
|
content: str
|
|
delta_type: Literal["text", "tool_call", "end"] = "text"
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
# Database Models
|
|
class Document(BaseModel):
|
|
"""Document model."""
|
|
id: Optional[str] = None
|
|
title: str
|
|
source: str
|
|
content: str
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
created_at: Optional[datetime] = None
|
|
updated_at: Optional[datetime] = None
|
|
|
|
|
|
class Chunk(BaseModel):
|
|
"""Document chunk model."""
|
|
id: Optional[str] = None
|
|
document_id: str
|
|
content: str
|
|
embedding: Optional[List[float]] = None
|
|
chunk_index: int
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
token_count: Optional[int] = None
|
|
created_at: Optional[datetime] = None
|
|
|
|
@field_validator('embedding')
|
|
@classmethod
|
|
def validate_embedding(cls, v: Optional[List[float]]) -> Optional[List[float]]:
|
|
"""Validate embedding dimensions."""
|
|
if v is not None and len(v) != 1536: # OpenAI text-embedding-3-small
|
|
raise ValueError(f"Embedding must have 1536 dimensions, got {len(v)}")
|
|
return v
|
|
|
|
|
|
class Session(BaseModel):
|
|
"""Session model."""
|
|
id: Optional[str] = None
|
|
user_id: Optional[str] = None
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
created_at: Optional[datetime] = None
|
|
updated_at: Optional[datetime] = None
|
|
expires_at: Optional[datetime] = None
|
|
|
|
|
|
class Message(BaseModel):
|
|
"""Message model."""
|
|
id: Optional[str] = None
|
|
session_id: str
|
|
role: MessageRole
|
|
content: str
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
created_at: Optional[datetime] = None
|
|
|
|
model_config = ConfigDict(use_enum_values=True)
|
|
|
|
|
|
# Agent Models
|
|
class AgentDependencies(BaseModel):
|
|
"""Dependencies for the agent."""
|
|
session_id: str
|
|
database_url: Optional[str] = None
|
|
openai_api_key: Optional[str] = None
|
|
|
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
|
|
|
|
|
|
|
class AgentContext(BaseModel):
|
|
"""Agent execution context."""
|
|
session_id: str
|
|
messages: List[Message] = Field(default_factory=list)
|
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
|
search_results: List[ChunkResult] = Field(default_factory=list)
|
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
# Ingestion Models
|
|
class IngestionConfig(BaseModel):
|
|
"""Configuration for document ingestion."""
|
|
chunk_size: int = Field(default=1000, ge=100, le=5000)
|
|
chunk_overlap: int = Field(default=200, ge=0, le=1000)
|
|
max_chunk_size: int = Field(default=2000, ge=500, le=10000)
|
|
use_semantic_chunking: bool = True
|
|
|
|
@field_validator('chunk_overlap')
|
|
@classmethod
|
|
def validate_overlap(cls, v: int, info) -> int:
|
|
"""Ensure overlap is less than chunk size."""
|
|
chunk_size = info.data.get('chunk_size', 1000)
|
|
if v >= chunk_size:
|
|
raise ValueError(f"Chunk overlap ({v}) must be less than chunk size ({chunk_size})")
|
|
return v
|
|
|
|
|
|
class IngestionResult(BaseModel):
|
|
"""Result of document ingestion."""
|
|
document_id: str
|
|
title: str
|
|
chunks_created: int
|
|
processing_time_ms: float
|
|
errors: List[str] = Field(default_factory=list) |
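# Illustrative usage sketch (not part of the original module): shows how the
# request, result, and ingestion models above might be constructed and how the
# validators behave. The literal values below are hypothetical examples.
if __name__ == "__main__":
    # Request with enum values; use_enum_values=True stores "hybrid" as a plain string.
    request = SearchRequest(query="vector search", search_type=SearchType.HYBRID, limit=5)

    # Scores are clamped into [0, 1] by validate_score rather than rejected.
    chunk = ChunkResult(
        chunk_id="c1",
        document_id="d1",
        content="example chunk",
        score=1.7,  # clamped to 1.0
        document_title="Example Doc",
        document_source="example.md",
    )

    # Overlap must stay below chunk_size or validate_overlap raises a ValueError.
    config = IngestionConfig(chunk_size=800, chunk_overlap=100)

    print(request.model_dump(), chunk.score, config.model_dump())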