components.py
"""Utility module for initializing RAG pipeline components from configuration."""

from __future__ import annotations

import logging
from typing import NamedTuple

from . import (
    Config,
    EmbeddingModelFactory,
    LLMFactory,
    RerankerFactory,
    RetrieverFactory,
    VectorStoreFactory,
)
from .embeddings.protocol import Embeddings
from .llms.protocol import LLM
from .pipeline import PipelineExecutor, QueryContext
from .pipeline.steps import (
    AccessControlStep,
    GenerationStep,
    QueryEmbeddingStep,
    RerankStep,
    RetrieveStep,
)
from .rerankers.protocol import Reranker
from .retrievers.protocol import Retriever
from .vector_stores.protocol import VectorStore

# One logger for the whole module instead of re-fetching it inside every call.
logger = logging.getLogger(__name__)


class RAGComponents(NamedTuple):
    """Container for initialized RAG pipeline components.

    Every field is optional: ``initialize_rag_components`` leaves a field as
    ``None`` when the pipeline step that needs it is disabled in the
    configuration.
    """

    # Built together when the retrieve step is enabled.
    embedding_model: Embeddings | None
    vector_store: VectorStore | None
    retriever: Retriever | None
    # Built when the rerank step is enabled.
    reranker: Reranker | None
    # Built when the generation step is enabled.
    llm: LLM | None


def initialize_rag_components() -> RAGComponents:
    """Initialize all RAG pipeline components from configuration.

    The configuration is read via ``Config.get_config()``. Only the
    components required by enabled query-pipeline steps are created; the
    others are returned as ``None``.

    Returns
    -------
    RAGComponents
        Named tuple containing all initialized components
    """
    config = Config.get_config()

    logger.info("Initializing RAG components...")

    pipeline_config = config.pipeline.query

    embedding_model = None
    vector_store = None
    retriever = None
    reranker = None
    llm = None

    logger.info("Retrieve step - enabled: %s", pipeline_config.retrieve_enabled)
    logger.info("Rerank step - enabled: %s", pipeline_config.rerank_enabled)
    logger.info("Generation step - enabled: %s", pipeline_config.generation_enabled)

    # Only create embedding, vector_store and retriever if retrieve step is enabled
    # (retrieve includes query embedding as they are dependent)
    if pipeline_config.retrieve_enabled:
        embedding_model = EmbeddingModelFactory.create(
            config.embedding.embed_name,
            **(config.embedding.embed_config or {}),
        )

        # The configured store options are unpacked last, so a key named
        # "embedding_function" in them overrides the model created above.
        store_config = {
            "embedding_function": embedding_model,
            **(config.vector_store.store_config or {}),
        }

        vector_store = VectorStoreFactory.create(
            config.vector_store.store_name,
            **store_config,
        )

        # Likewise, searcher_config entries override "vector_store" and "k".
        retriever_kwargs = {"vector_store": vector_store, "k": config.retrieval.k}
        if config.retrieval.searcher_config:
            retriever_kwargs.update(config.retrieval.searcher_config)

        retriever = RetrieverFactory.create(
            config.retrieval.searcher_strategy,
            **retriever_kwargs,
        )

    # Only create reranker if rerank step is enabled
    if pipeline_config.rerank_enabled:
        reranker = RerankerFactory.create(
            config.reranking.reranker_name,
            **(config.reranking.reranker_config or {}),
        )

    # Only create the LLM if generation step is enabled
    if pipeline_config.generation_enabled:
        llm = LLMFactory.create(
            config.llm.llm_name,
            **(config.llm.llm_config or {}),
        )

    logger.info("RAG components initialized successfully")

    return RAGComponents(
        embedding_model=embedding_model,
        vector_store=vector_store,
        retriever=retriever,
        reranker=reranker,
        llm=llm,
    )


def execute_query(
    components: RAGComponents,
    query: str,
    user_role: str | None = None,
    role_mapping: dict[str, list[str]] | None = None,
) -> QueryContext:
    """Execute a RAG query using the provided components.

    The step list is assembled in a fixed order (query embedding, retrieve,
    access control, rerank, generation) and a step is included only when its
    component is present (or, for access control, when the configuration flag
    is set).

    Parameters
    ----------
    components : RAGComponents
        Initialized RAG components
    query : str
        User's question or query text
    user_role : str, optional
        User's role for access control (expanded to tags via role_mapping)
    role_mapping : dict[str, list[str]], optional
        Role-to-tags mapping for access control

    Returns
    -------
    QueryContext
        Pipeline context containing the query results
    """
    config = Config.get_config()

    logger.info("Executing query: %s", query)

    context = QueryContext(user_query=query)

    # Set tag-based access control fields if provided
    if user_role:
        context.user_role = user_role
    if role_mapping:
        context.role_mapping = role_mapping

    # Build steps list dynamically based on pipeline configuration.
    # Components are compared against None rather than tested for truthiness:
    # None is the "step disabled" marker, and a real component that happens
    # to evaluate falsy must not have its step silently skipped.
    steps = []

    if components.embedding_model is not None:
        steps.append(QueryEmbeddingStep(components.embedding_model))

    if components.retriever is not None:
        steps.append(RetrieveStep(components.retriever))

    # NOTE(review): the step is gated on a flag whose name suggests
    # notification only - confirm access filtering does not depend on it.
    if config.access_control.notify_on_denied_access:
        steps.append(AccessControlStep())

    if components.reranker is not None:
        steps.append(RerankStep(components.reranker))

    if components.llm is not None:
        # Get prompt configuration from pipeline config
        pipeline_config = config.pipeline.query if config.pipeline else None
        prompt_template = pipeline_config.generation_prompt if pipeline_config else None

        steps.append(
            GenerationStep(
                components.llm,
                prompt_template=prompt_template,
            )
        )

    executor = PipelineExecutor(steps)
    context = executor.execute(context)

    return context