/ src / pipeline / steps / embedding_generation_step.py
embedding_generation_step.py
 1  from __future__ import annotations
 2  
 3  """Step that generates embeddings for document chunks."""
 4  
 5  import logging
 6  
 7  from langchain_core.documents import Document
 8  
 9  from ...embeddings.constants import EMBEDDING_METADATA_KEY, EMBEDDING_MODEL_METADATA_KEY
10  from ...embeddings.protocol import Embeddings
11  from ...embeddings.types import EmbeddingModelType
12  from ..contexts.ingestion_context import IngestionContext
13  
14  
class EmbeddingGenerationStep:
    """Pipeline step that attaches an embedding vector to every document chunk.

    The step reads ``context.chunks``, embeds their text in a single batch call
    and writes back both the raw vectors (``context.vectors``) and new chunk
    documents whose metadata carries the vector and the model identifier.
    """

    def __init__(self, embedding_model: Embeddings, model_name: EmbeddingModelType) -> None:
        """Initialize the embedding generation step.

        Parameters
        ----------
        embedding_model
            Embedding model instance created by EmbeddingModelFactory.
        model_name
            Name of the embedding model (for metadata tracking). Its ``.value``
            is stored in each chunk's metadata.
        """
        self.embedding_model = embedding_model
        self.model_name = model_name

    def run(self, context: IngestionContext) -> None:
        """Generate embeddings for all chunks.

        On success ``context.vectors`` holds the embeddings returned by the
        model and ``context.chunks`` is replaced by new documents whose
        metadata is the original metadata plus the embedding and model name.

        The context is marked as failed (and left otherwise untouched) when
        there are no chunks, or when the model returns a different number of
        embeddings than there are chunks.

        Parameters
        ----------
        context
            Ingestion context with chunks set.
        """
        logger = logging.getLogger(__name__)
        chunks = context.chunks
        if not chunks:
            context.mark_failed("No chunks available. Chunk step must run first.")
            return

        logger.info("Generating embeddings for %d chunks", len(chunks))

        embeddings = self.embedding_model.embed_documents(
            [chunk.page_content for chunk in chunks]
        )

        # A non-strict zip would silently drop the unmatched tail and leave
        # context.vectors and context.chunks out of step; fail loudly instead.
        if len(embeddings) != len(chunks):
            context.mark_failed(
                f"Embedding count mismatch: expected {len(chunks)} embeddings, "
                f"got {len(embeddings)}."
            )
            return

        model_label = self.model_name.value
        # Build fresh Document objects rather than mutating the input chunks.
        embedded_chunks = [
            Document(
                page_content=chunk.page_content,
                metadata={
                    **chunk.metadata,
                    EMBEDDING_METADATA_KEY: embedding,
                    EMBEDDING_MODEL_METADATA_KEY: model_label,
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

        context.vectors = embeddings
        context.chunks = embedded_chunks

        logger.info("Generated embeddings for %d chunks", len(embedded_chunks))