embedding_generation_step.py
1 from __future__ import annotations 2 3 """Step that generates embeddings for document chunks.""" 4 5 import logging 6 7 from langchain_core.documents import Document 8 9 from ...embeddings.constants import EMBEDDING_METADATA_KEY, EMBEDDING_MODEL_METADATA_KEY 10 from ...embeddings.protocol import Embeddings 11 from ...embeddings.types import EmbeddingModelType 12 from ..contexts.ingestion_context import IngestionContext 13 14 15 class EmbeddingGenerationStep: 16 """Step that generates embeddings for document chunks.""" 17 18 def __init__(self, embedding_model: Embeddings, model_name: EmbeddingModelType): 19 """Initialize the embedding generation step. 20 21 Parameters 22 ---------- 23 embedding_model 24 Embedding model instance created by EmbeddingModelFactory. 25 model_name 26 Name of the embedding model (for metadata tracking). 27 """ 28 self.embedding_model = embedding_model 29 self.model_name = model_name 30 31 def run(self, context: IngestionContext) -> None: 32 """Generate embeddings for all chunks. 33 34 Parameters 35 ---------- 36 context 37 Ingestion context with chunks set. 38 """ 39 logger = logging.getLogger(__name__) 40 if not context.chunks: 41 context.mark_failed("No chunks available. Chunk step must run first.") 42 return 43 44 logger.info(f"Generating embeddings for {len(context.chunks)} chunks") 45 46 texts = [chunk.page_content for chunk in context.chunks] 47 embeddings = self.embedding_model.embed_documents(texts) 48 49 embedded_chunks = [] 50 for chunk, embedding in zip(context.chunks, embeddings, strict=False): 51 embedded_chunk = Document( 52 page_content=chunk.page_content, 53 metadata={ 54 **chunk.metadata, 55 EMBEDDING_METADATA_KEY: embedding, 56 EMBEDDING_MODEL_METADATA_KEY: self.model_name.value, 57 } 58 ) 59 embedded_chunks.append(embedded_chunk) 60 61 context.vectors = embeddings 62 context.chunks = embedded_chunks 63 64 logger.info(f"Generated embeddings for {len(embedded_chunks)} chunks")