/ src / components.py
components.py
"""Utility module for initializing RAG pipeline components from configuration."""

from __future__ import annotations

import logging
from typing import NamedTuple

from . import (
    Config,
    EmbeddingModelFactory,
    LLMFactory,
    RerankerFactory,
    RetrieverFactory,
    VectorStoreFactory,
)
from .embeddings.protocol import Embeddings
from .llms.protocol import LLM
from .pipeline import PipelineExecutor, QueryContext
from .pipeline.steps import (
    AccessControlStep,
    GenerationStep,
    QueryEmbeddingStep,
    RerankStep,
    RetrieveStep,
)
from .rerankers.protocol import Reranker
from .retrievers.protocol import Retriever
from .vector_stores.protocol import VectorStore
 29  
 30  
 31  class RAGComponents(NamedTuple):
 32      """Container for initialized RAG pipeline components."""
 33  
 34      embedding_model: Embeddings | None
 35      vector_store: VectorStore | None
 36      retriever: Retriever | None
 37      reranker: Reranker | None
 38      llm: LLM | None
 39  
 40  
 41  def initialize_rag_components() -> RAGComponents:
 42      """Initialize all RAG pipeline components from configuration.
 43  
 44      Returns
 45      -------
 46      RAGComponents
 47          Named tuple containing all initialized components
 48      """
 49      config = Config.get_config()
 50  
 51      logger = logging.getLogger(__name__)
 52      logger.info("Initializing RAG components...")
 53  
 54      pipeline_config = config.pipeline.query
 55  
 56      embedding_model = None
 57      vector_store = None
 58      retriever = None
 59      reranker = None
 60      llm = None
 61  
 62      logger.info(f"Retrieve step - enabled: {pipeline_config.retrieve_enabled}")
 63      logger.info(f"Rerank step - enabled: {pipeline_config.rerank_enabled}")
 64      logger.info(f"Generation step - enabled: {pipeline_config.generation_enabled}")
 65  
 66      # Only create embedding, vector_store and retriever if retrieve step is enabled
 67      # (retrieve includes query embedding as they are dependent)
 68      if pipeline_config.retrieve_enabled:
 69          embedding_model = EmbeddingModelFactory.create(
 70              config.embedding.embed_name,
 71              **(config.embedding.embed_config or {}),
 72          )
 73  
 74          store_config = {
 75                  "embedding_function": embedding_model,
 76                  **(config.vector_store.store_config or {}),
 77          }
 78  
 79          vector_store = VectorStoreFactory.create(
 80              config.vector_store.store_name,
 81              **store_config,
 82          )
 83  
 84          retriever_kwargs = {"vector_store": vector_store, "k": config.retrieval.k}
 85          if config.retrieval.searcher_config:
 86              retriever_kwargs.update(config.retrieval.searcher_config)
 87  
 88          retriever = RetrieverFactory.create(
 89              config.retrieval.searcher_strategy,
 90              **retriever_kwargs,
 91          )
 92  
 93      # Only create reranker if rerank step is enabled
 94      if pipeline_config.rerank_enabled:
 95          reranker = RerankerFactory.create(
 96              config.reranking.reranker_name,
 97              **(config.reranking.reranker_config or {}),
 98          )
 99  
100      if pipeline_config.generation_enabled:
101          llm = LLMFactory.create(
102              config.llm.llm_name,
103              **(config.llm.llm_config or {}),
104          )
105  
106      logger.info("RAG components initialized successfully")
107  
108      return RAGComponents(
109          embedding_model=embedding_model,
110          vector_store=vector_store,
111          retriever=retriever,
112          reranker=reranker,
113          llm=llm,
114      )
115  
116  
def execute_query(
    components: RAGComponents,
    query: str,
    user_role: str | None = None,
    role_mapping: dict[str, list[str]] | None = None,
) -> QueryContext:
    """Execute a RAG query using the provided components.

    The pipeline is assembled from whichever components are present, in this
    order: query embedding -> retrieve -> access control (only after retrieve,
    and only when ``config.access_control.notify_on_denied_access`` is set)
    -> rerank -> generation.

    Parameters
    ----------
    components : RAGComponents
        Initialized RAG components; ``None`` slots are skipped.
    query : str
        User's question or query text
    user_role : str, optional
        User's role for access control (expanded to tags via role_mapping)
    role_mapping : dict[str, list[str]], optional
        Role-to-tags mapping for access control

    Returns
    -------
    QueryContext
        Pipeline context containing the query results
    """
    config = Config.get_config()

    logger = logging.getLogger(__name__)
    # NOTE(review): this logs the raw user query at INFO; consider DEBUG or
    # redaction if queries may contain sensitive data.
    logger.info("Executing query: %s", query)

    context = QueryContext(user_query=query)

    # Access-control fields are only set for non-empty values; otherwise the
    # QueryContext keeps whatever defaults it defines.
    if user_role:
        context.user_role = user_role
    if role_mapping:
        context.role_mapping = role_mapping

    steps = []

    # Compare against None explicitly: a component object whose class defines
    # __len__/__bool__ could be falsy yet still be a valid, present component.
    if components.embedding_model is not None:
        steps.append(QueryEmbeddingStep(components.embedding_model))

    if components.retriever is not None:
        steps.append(RetrieveStep(components.retriever))

        if config.access_control.notify_on_denied_access:
            steps.append(AccessControlStep())

    if components.reranker is not None:
        steps.append(RerankStep(components.reranker))

    if components.llm is not None:
        # Prompt template comes from the query pipeline config, if any.
        pipeline_config = config.pipeline.query if config.pipeline else None
        prompt_template = pipeline_config.generation_prompt if pipeline_config else None

        steps.append(
            GenerationStep(
                components.llm,
                prompt_template=prompt_template,
            )
        )

    return PipelineExecutor(steps).execute(context)
183  
184      return context