"""
|
|
Document embedding generation for vector search.
|
|
"""
|
|
|
|
import asyncio
import hashlib
import logging
from typing import List, Dict, Optional, Callable
from datetime import datetime

from openai import RateLimitError, APIError
from dotenv import load_dotenv

# Import the chunker and flexible embedding providers, falling back to
# absolute imports for direct execution or testing
try:
    from .chunker import DocumentChunk
    from ..utils.providers import get_embedding_client, get_embedding_model
except ImportError:
    import os
    import sys

    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from chunker import DocumentChunk
    from utils.providers import get_embedding_client, get_embedding_model

# Load environment variables
load_dotenv()

logger = logging.getLogger(__name__)

# Initialize client with flexible provider
embedding_client = get_embedding_client()
EMBEDDING_MODEL = get_embedding_model()
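# NOTE: The embedding client and model name are resolved once at import
# time, so provider environment variables (or the .env file) must be in
# place before this module is imported.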


class EmbeddingGenerator:
    """Generates embeddings for document chunks."""

    def __init__(
        self,
        model: str = EMBEDDING_MODEL,
        batch_size: int = 100,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ):
        """
        Initialize embedding generator.

        Args:
            model: OpenAI embedding model to use
            batch_size: Number of texts to send per embedding API call
            max_retries: Maximum number of retry attempts
            retry_delay: Base delay between retries in seconds
        """
        self.model = model
        self.batch_size = batch_size
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Model-specific configurations
        self.model_configs = {
            "text-embedding-3-small": {"dimensions": 1536, "max_tokens": 8191},
            "text-embedding-3-large": {"dimensions": 3072, "max_tokens": 8191},
            "text-embedding-ada-002": {"dimensions": 1536, "max_tokens": 8191}
        }
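        # NOTE: These are the models' native output sizes. The
        # text-embedding-3-* models also accept a "dimensions" request
        # parameter for shorter vectors, which this module does not use.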

        if model not in self.model_configs:
            logger.warning(f"Unknown model {model}, using default config")
            self.config = {"dimensions": 1536, "max_tokens": 8191}
        else:
            self.config = self.model_configs[model]

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        # Truncate text if too long, estimating roughly 4 characters per token
        if len(text) > self.config["max_tokens"] * 4:
            text = text[:self.config["max_tokens"] * 4]
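        # NOTE: Character-based truncation is only an approximation; a
        # tokenizer such as tiktoken would enforce the token limit exactly,
        # at the cost of an extra dependency.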

        for attempt in range(self.max_retries):
            try:
                response = await embedding_client.embeddings.create(
                    model=self.model,
                    input=text
                )

                return response.data[0].embedding

            except RateLimitError:
                if attempt == self.max_retries - 1:
                    raise

                # Exponential backoff for rate limits
                delay = self.retry_delay * (2 ** attempt)
                logger.warning(f"Rate limit hit, retrying in {delay}s")
                await asyncio.sleep(delay)

            except APIError as e:
                logger.error(f"OpenAI API error: {e}")
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay)

            except Exception as e:
                logger.error(f"Unexpected error generating embedding: {e}")
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay)

    async def generate_embeddings_batch(
        self,
        texts: List[str]
    ) -> List[List[float]]:
        """
        Generate embeddings for a batch of texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors
        """
        # Filter and truncate texts
        processed_texts = []
        for text in texts:
            if not text or not text.strip():
                # The embeddings API rejects empty strings, so use a single
                # space as a placeholder
                processed_texts.append(" ")
                continue

            # Truncate if too long
            if len(text) > self.config["max_tokens"] * 4:
                text = text[:self.config["max_tokens"] * 4]

            processed_texts.append(text)

        for attempt in range(self.max_retries):
            try:
                response = await embedding_client.embeddings.create(
                    model=self.model,
                    input=processed_texts
                )

                return [data.embedding for data in response.data]

            except RateLimitError:
                if attempt == self.max_retries - 1:
                    raise

                delay = self.retry_delay * (2 ** attempt)
                logger.warning(f"Rate limit hit, retrying batch in {delay}s")
                await asyncio.sleep(delay)

            except APIError as e:
                logger.error(f"OpenAI API error in batch: {e}")
                if attempt == self.max_retries - 1:
                    # Fall back to individual processing
                    return await self._process_individually(processed_texts)
                await asyncio.sleep(self.retry_delay)

            except Exception as e:
                logger.error(f"Unexpected error in batch embedding: {e}")
                if attempt == self.max_retries - 1:
                    return await self._process_individually(processed_texts)
                await asyncio.sleep(self.retry_delay)

    async def _process_individually(
        self,
        texts: List[str]
    ) -> List[List[float]]:
        """
        Process texts individually as a fallback.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors
        """
        embeddings = []

        for text in texts:
            try:
                if not text or not text.strip():
                    embeddings.append([0.0] * self.config["dimensions"])
                    continue

                embedding = await self.generate_embedding(text)
                embeddings.append(embedding)

                # Small delay to avoid overwhelming the API
                await asyncio.sleep(0.1)

            except Exception as e:
                logger.error(f"Failed to embed text: {e}")
                # Use zero vector as fallback
                embeddings.append([0.0] * self.config["dimensions"])

        return embeddings
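
    # NOTE: Zero vectors stand in for failed embeddings. Cosine similarity
    # is undefined for the zero vector, so downstream search should treat
    # these as non-matches rather than meaningful results.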

    async def embed_chunks(
        self,
        chunks: List[DocumentChunk],
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> List[DocumentChunk]:
        """
        Generate embeddings for document chunks.

        Args:
            chunks: List of document chunks
            progress_callback: Optional callback for progress updates,
                called with (current_batch, total_batches)

        Returns:
            Chunks with embeddings added
        """
        if not chunks:
            return chunks

        logger.info(f"Generating embeddings for {len(chunks)} chunks")

        # Process chunks in batches (ceiling division for the batch count)
        embedded_chunks = []
        total_batches = (len(chunks) + self.batch_size - 1) // self.batch_size

        for i in range(0, len(chunks), self.batch_size):
            batch_chunks = chunks[i:i + self.batch_size]
            batch_texts = [chunk.content for chunk in batch_chunks]

            try:
                # Generate embeddings for this batch
                embeddings = await self.generate_embeddings_batch(batch_texts)

                # Add embeddings to chunks
                for chunk, embedding in zip(batch_chunks, embeddings):
                    # Create a new chunk with embedding metadata
                    embedded_chunk = DocumentChunk(
                        content=chunk.content,
                        index=chunk.index,
                        start_char=chunk.start_char,
                        end_char=chunk.end_char,
                        metadata={
                            **chunk.metadata,
                            "embedding_model": self.model,
                            "embedding_generated_at": datetime.now().isoformat()
                        },
                        token_count=chunk.token_count
                    )

                    # Add embedding as a separate attribute
                    embedded_chunk.embedding = embedding
                    embedded_chunks.append(embedded_chunk)

                # Progress update
                current_batch = (i // self.batch_size) + 1
                if progress_callback:
                    progress_callback(current_batch, total_batches)

                logger.info(f"Processed batch {current_batch}/{total_batches}")

            except Exception as e:
                logger.error(f"Failed to process batch {i // self.batch_size + 1}: {e}")

                # Add chunks without embeddings as fallback
                for chunk in batch_chunks:
                    chunk.metadata.update({
                        "embedding_error": str(e),
                        "embedding_generated_at": datetime.now().isoformat()
                    })
                    chunk.embedding = [0.0] * self.config["dimensions"]
                    embedded_chunks.append(chunk)

        logger.info(f"Generated embeddings for {len(embedded_chunks)} chunks")
        return embedded_chunks
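
    # NOTE: On success this method returns new DocumentChunk objects; on
    # failure it mutates and reuses the originals. Callers should not rely
    # on the input chunks being left unmodified.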

    async def embed_query(self, query: str) -> List[float]:
        """
        Generate embedding for a search query.

        Args:
            query: Search query

        Returns:
            Query embedding
        """
        return await self.generate_embedding(query)

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings for this model."""
        return self.config["dimensions"]


# Cache for embeddings
class EmbeddingCache:
    """Simple in-memory cache for embeddings."""

    def __init__(self, max_size: int = 1000):
        """Initialize cache."""
        self.cache: Dict[str, List[float]] = {}
        self.access_times: Dict[str, datetime] = {}
        self.max_size = max_size

    def get(self, text: str) -> Optional[List[float]]:
        """Get embedding from cache."""
        text_hash = self._hash_text(text)
        if text_hash in self.cache:
            self.access_times[text_hash] = datetime.now()
            return self.cache[text_hash]
        return None

    def put(self, text: str, embedding: List[float]):
        """Store embedding in cache."""
        text_hash = self._hash_text(text)

        # Evict the least recently used entry if the cache is full
        if len(self.cache) >= self.max_size:
            oldest_key = min(self.access_times.keys(), key=lambda k: self.access_times[k])
            del self.cache[oldest_key]
            del self.access_times[oldest_key]

        self.cache[text_hash] = embedding
        self.access_times[text_hash] = datetime.now()

    def _hash_text(self, text: str) -> str:
        """Generate hash for text."""
        return hashlib.md5(text.encode()).hexdigest()
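
    # NOTE: Cache keys are the MD5 of the exact text, so texts differing
    # only in whitespace are cached separately. MD5 is used here purely as
    # a cache key, not for security.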


# Factory function
def create_embedder(
    model: str = EMBEDDING_MODEL,
    use_cache: bool = True,
    **kwargs
) -> EmbeddingGenerator:
    """
    Create embedding generator with optional caching.

    Args:
        model: Embedding model to use
        use_cache: Whether to cache single-text embeddings
        **kwargs: Additional arguments for EmbeddingGenerator

    Returns:
        EmbeddingGenerator instance
    """
    embedder = EmbeddingGenerator(model=model, **kwargs)

    if use_cache:
        # Add caching by wrapping the instance's generate_embedding method
        cache = EmbeddingCache()
        original_generate = embedder.generate_embedding

        async def cached_generate(text: str) -> List[float]:
            cached = cache.get(text)
            if cached is not None:
                return cached

            embedding = await original_generate(text)
            cache.put(text, embedding)
            return embedding

        embedder.generate_embedding = cached_generate

    return embedder
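
# NOTE: The cache wraps only generate_embedding (and therefore embed_query).
# embed_chunks calls the batch endpoint directly, so batched texts bypass
# the cache except on the individual-processing fallback path.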


# Example usage
async def main():
    """Example usage of the embedder."""
    # Fall back to an absolute import when run directly rather than as a
    # package module
    try:
        from .chunker import ChunkingConfig, create_chunker
    except ImportError:
        from chunker import ChunkingConfig, create_chunker

    # Create chunker and embedder
    config = ChunkingConfig(chunk_size=200, use_semantic_splitting=False)
    chunker = create_chunker(config)
    embedder = create_embedder()

    sample_text = """
    Google's AI initiatives include advanced language models, computer vision,
    and machine learning research. The company has invested heavily in
    transformer architectures and neural network optimization.

    Microsoft's partnership with OpenAI has led to integration of GPT models
    into various products and services, making AI accessible to enterprise
    customers through Azure cloud services.
    """

    # Chunk the document
    chunks = chunker.chunk_document(
        content=sample_text,
        title="AI Initiatives",
        source="example.md"
    )

    print(f"Created {len(chunks)} chunks")

    # Generate embeddings
    def progress_callback(current, total):
        print(f"Processing batch {current}/{total}")

    embedded_chunks = await embedder.embed_chunks(chunks, progress_callback)

    for i, chunk in enumerate(embedded_chunks):
        print(f"Chunk {i}: {len(chunk.content)} chars, embedding dim: {len(chunk.embedding)}")

    # Test query embedding
    query_embedding = await embedder.embed_query("Google AI research")
    print(f"Query embedding dimension: {len(query_embedding)}")


if __name__ == "__main__":
    asyncio.run(main())