""" Document embedding generation for vector search. """ import os import asyncio import logging from typing import List, Dict, Any, Optional, Tuple from datetime import datetime import json from openai import RateLimitError, APIError from dotenv import load_dotenv from .chunker import DocumentChunk # Import flexible providers try: from ..utils.providers import get_embedding_client, get_embedding_model except ImportError: # For direct execution or testing import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils.providers import get_embedding_client, get_embedding_model # Load environment variables load_dotenv() logger = logging.getLogger(__name__) # Initialize client with flexible provider embedding_client = get_embedding_client() EMBEDDING_MODEL = get_embedding_model() class EmbeddingGenerator: """Generates embeddings for document chunks.""" def __init__( self, model: str = EMBEDDING_MODEL, batch_size: int = 100, max_retries: int = 3, retry_delay: float = 1.0 ): """ Initialize embedding generator. Args: model: OpenAI embedding model to use batch_size: Number of texts to process in parallel max_retries: Maximum number of retry attempts retry_delay: Delay between retries in seconds """ self.model = model self.batch_size = batch_size self.max_retries = max_retries self.retry_delay = retry_delay # Model-specific configurations self.model_configs = { "text-embedding-3-small": {"dimensions": 1536, "max_tokens": 8191}, "text-embedding-3-large": {"dimensions": 3072, "max_tokens": 8191}, "text-embedding-ada-002": {"dimensions": 1536, "max_tokens": 8191} } if model not in self.model_configs: logger.warning(f"Unknown model {model}, using default config") self.config = {"dimensions": 1536, "max_tokens": 8191} else: self.config = self.model_configs[model] async def generate_embedding(self, text: str) -> List[float]: """ Generate embedding for a single text. Args: text: Text to embed Returns: Embedding vector """ # Truncate text if too long if len(text) > self.config["max_tokens"] * 4: # Rough token estimation text = text[:self.config["max_tokens"] * 4] for attempt in range(self.max_retries): try: response = await embedding_client.embeddings.create( model=self.model, input=text ) return response.data[0].embedding except RateLimitError as e: if attempt == self.max_retries - 1: raise # Exponential backoff for rate limits delay = self.retry_delay * (2 ** attempt) logger.warning(f"Rate limit hit, retrying in {delay}s") await asyncio.sleep(delay) except APIError as e: logger.error(f"OpenAI API error: {e}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) except Exception as e: logger.error(f"Unexpected error generating embedding: {e}") if attempt == self.max_retries - 1: raise await asyncio.sleep(self.retry_delay) async def generate_embeddings_batch( self, texts: List[str] ) -> List[List[float]]: """ Generate embeddings for a batch of texts. 
    async def generate_embeddings_batch(
        self,
        texts: List[str]
    ) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors, in the same order as the input
        """
        # Truncate long texts and substitute a placeholder for empty ones,
        # since some embedding APIs reject empty strings in batch input.
        # Placeholder results are overwritten with zero vectors below.
        max_chars = self.config["max_tokens"] * 4
        processed_texts: List[str] = []
        empty_indices: List[int] = []
        for i, text in enumerate(texts):
            if not text or not text.strip():
                empty_indices.append(i)
                processed_texts.append(" ")
                continue
            processed_texts.append(text[:max_chars])

        for attempt in range(self.max_retries):
            try:
                response = await embedding_client.embeddings.create(
                    model=self.model,
                    input=processed_texts
                )
                embeddings = [data.embedding for data in response.data]

                # Match the individual fallback: empty inputs get zero vectors
                for i in empty_indices:
                    embeddings[i] = [0.0] * self.config["dimensions"]

                return embeddings

            except RateLimitError:
                if attempt == self.max_retries - 1:
                    raise
                delay = self.retry_delay * (2 ** attempt)
                logger.warning(f"Rate limit hit, retrying batch in {delay}s")
                await asyncio.sleep(delay)

            except APIError as e:
                logger.error(f"OpenAI API error in batch: {e}")
                if attempt == self.max_retries - 1:
                    # Fall back to individual processing
                    return await self._process_individually(processed_texts)
                await asyncio.sleep(self.retry_delay)

            except Exception as e:
                logger.error(f"Unexpected error in batch embedding: {e}")
                if attempt == self.max_retries - 1:
                    return await self._process_individually(processed_texts)
                await asyncio.sleep(self.retry_delay)

        # Unreachable: the final attempt either returns or raises above
        raise RuntimeError("Batch embedding failed after all retries")

    async def _process_individually(
        self,
        texts: List[str]
    ) -> List[List[float]]:
        """
        Process texts one at a time as a fallback.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors
        """
        embeddings = []
        for text in texts:
            try:
                if not text or not text.strip():
                    embeddings.append([0.0] * self.config["dimensions"])
                    continue

                embedding = await self.generate_embedding(text)
                embeddings.append(embedding)

                # Small delay to avoid overwhelming the API
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Failed to embed text: {e}")
                # Use zero vector as fallback
                embeddings.append([0.0] * self.config["dimensions"])

        return embeddings
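    # Failed or empty texts come back as zero vectors rather than raising, so
    # a batch never hard-fails. Zero vectors carry no semantic signal, though:
    # callers may want to filter out chunks for which embed_chunks() below has
    # recorded an "embedding_error" in metadata before indexing them.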
    async def embed_chunks(
        self,
        chunks: List[DocumentChunk],
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> List[DocumentChunk]:
        """
        Generate embeddings for document chunks.

        Args:
            chunks: List of document chunks
            progress_callback: Optional callback receiving (current_batch, total_batches)

        Returns:
            Chunks with embeddings added
        """
        if not chunks:
            return chunks

        logger.info(f"Generating embeddings for {len(chunks)} chunks")

        # Process chunks in batches
        embedded_chunks = []
        total_batches = (len(chunks) + self.batch_size - 1) // self.batch_size

        for i in range(0, len(chunks), self.batch_size):
            batch_chunks = chunks[i:i + self.batch_size]
            batch_texts = [chunk.content for chunk in batch_chunks]
            current_batch = (i // self.batch_size) + 1

            try:
                # Generate embeddings for this batch
                embeddings = await self.generate_embeddings_batch(batch_texts)

                # Create new chunks with embeddings attached
                for chunk, embedding in zip(batch_chunks, embeddings):
                    embedded_chunk = DocumentChunk(
                        content=chunk.content,
                        index=chunk.index,
                        start_char=chunk.start_char,
                        end_char=chunk.end_char,
                        metadata={
                            **chunk.metadata,
                            "embedding_model": self.model,
                            "embedding_generated_at": datetime.now().isoformat()
                        },
                        token_count=chunk.token_count
                    )
                    # Embedding is attached as a separate attribute
                    embedded_chunk.embedding = embedding
                    embedded_chunks.append(embedded_chunk)

                # Progress update
                if progress_callback:
                    progress_callback(current_batch, total_batches)

                logger.info(f"Processed batch {current_batch}/{total_batches}")

            except Exception as e:
                logger.error(f"Failed to process batch {current_batch}: {e}")

                # Add chunks with zero-vector embeddings as a fallback
                for chunk in batch_chunks:
                    chunk.metadata.update({
                        "embedding_error": str(e),
                        "embedding_generated_at": datetime.now().isoformat()
                    })
                    chunk.embedding = [0.0] * self.config["dimensions"]
                    embedded_chunks.append(chunk)

        logger.info(f"Generated embeddings for {len(embedded_chunks)} chunks")
        return embedded_chunks

    async def embed_query(self, query: str) -> List[float]:
        """
        Generate an embedding for a search query.

        Args:
            query: Search query

        Returns:
            Query embedding
        """
        return await self.generate_embedding(query)

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings for this model."""
        return self.config["dimensions"]


# Cache for embeddings
class EmbeddingCache:
    """Simple in-memory LRU cache for embeddings."""

    def __init__(self, max_size: int = 1000):
        """Initialize the cache."""
        self.cache: Dict[str, List[float]] = {}
        self.access_times: Dict[str, datetime] = {}
        self.max_size = max_size

    def get(self, text: str) -> Optional[List[float]]:
        """Get an embedding from the cache, or None on a miss."""
        text_hash = self._hash_text(text)

        if text_hash in self.cache:
            self.access_times[text_hash] = datetime.now()
            return self.cache[text_hash]

        return None

    def put(self, text: str, embedding: List[float]):
        """Store an embedding in the cache."""
        text_hash = self._hash_text(text)

        # Evict the least recently used entry if the cache is full
        if len(self.cache) >= self.max_size:
            oldest_key = min(self.access_times, key=self.access_times.get)
            del self.cache[oldest_key]
            del self.access_times[oldest_key]

        self.cache[text_hash] = embedding
        self.access_times[text_hash] = datetime.now()

    def _hash_text(self, text: str) -> str:
        """Generate a cache key for text (md5 is fine here; not security-sensitive)."""
        return hashlib.md5(text.encode()).hexdigest()
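# Note on caching: create_embedder() below wraps only generate_embedding(), so
# repeated query and single-text embeddings are served from the cache, while
# the batched call in generate_embeddings_batch() always goes to the API.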
# Factory function
def create_embedder(
    model: str = EMBEDDING_MODEL,
    use_cache: bool = True,
    **kwargs
) -> EmbeddingGenerator:
    """
    Create an embedding generator with optional caching.

    Args:
        model: Embedding model to use
        use_cache: Whether to cache single-text embeddings
        **kwargs: Additional arguments for EmbeddingGenerator

    Returns:
        EmbeddingGenerator instance
    """
    embedder = EmbeddingGenerator(model=model, **kwargs)

    if use_cache:
        # Wrap generate_embedding with a read-through cache
        cache = EmbeddingCache()
        original_generate = embedder.generate_embedding

        async def cached_generate(text: str) -> List[float]:
            cached = cache.get(text)
            if cached is not None:
                return cached

            embedding = await original_generate(text)
            cache.put(text, embedding)
            return embedding

        embedder.generate_embedding = cached_generate

    return embedder


# Example usage
async def main():
    """Example usage of the embedder."""
    from .chunker import ChunkingConfig, create_chunker

    # Create chunker and embedder
    config = ChunkingConfig(chunk_size=200, use_semantic_splitting=False)
    chunker = create_chunker(config)
    embedder = create_embedder()

    sample_text = """
    Google's AI initiatives include advanced language models, computer vision,
    and machine learning research. The company has invested heavily in
    transformer architectures and neural network optimization.

    Microsoft's partnership with OpenAI has led to integration of GPT models
    into various products and services, making AI accessible to enterprise
    customers through Azure cloud services.
    """

    # Chunk the document
    chunks = chunker.chunk_document(
        content=sample_text,
        title="AI Initiatives",
        source="example.md"
    )

    print(f"Created {len(chunks)} chunks")

    # Generate embeddings with a progress callback
    def progress_callback(current, total):
        print(f"Processing batch {current}/{total}")

    embedded_chunks = await embedder.embed_chunks(chunks, progress_callback)

    for i, chunk in enumerate(embedded_chunks):
        print(f"Chunk {i}: {len(chunk.content)} chars, embedding dim: {len(chunk.embedding)}")

    # Test query embedding
    query_embedding = await embedder.embed_query("Google AI research")
    print(f"Query embedding dimension: {len(query_embedding)}")


if __name__ == "__main__":
    asyncio.run(main())