# answer_generation_step.py
1 from __future__ import annotations 2 3 """Step that generates the final answer from retrieved documents.""" 4 5 import logging 6 from pathlib import Path 7 8 from langchain_core.documents import Document 9 10 from ...llms.protocol import LLM 11 from ..contexts.query_context import QueryContext 12 from ..step import PipelineStep 13 from .constants import DEFAULT_GENERATION_PROMPT, DEFAULT_MAX_CONTEXT_CHARS 14 15 16 class GenerationStep(PipelineStep): 17 """Step that generates the final answer from retrieved documents.""" 18 19 def __init__( 20 self, 21 llm: LLM, 22 max_context_chars: int = DEFAULT_MAX_CONTEXT_CHARS, 23 prompt_template: str | None = None, 24 ): 25 """Initialize the generation step. 26 27 Parameters 28 ---------- 29 llm 30 LLM instance created by LLMFactory (or wired manually). 31 max_context_chars 32 Max characters of retrieved context injected into the prompt. 33 prompt_template 34 Custom prompt template (use {context} and {query} placeholders). 35 If None, uses default prompt. 36 """ 37 self.llm = llm 38 self.max_context_chars = max_context_chars 39 self.prompt_template = prompt_template or DEFAULT_GENERATION_PROMPT 40 self._logger = logging.getLogger(__name__) 41 42 def create_prompt(self, context_text: str, query: str) -> str: 43 """Create the generation prompt using template. 44 45 Parameters 46 ---------- 47 context_text 48 Formatted context from retrieved documents 49 query 50 User's query 51 52 Returns 53 ------- 54 Formatted prompt string for LLM generation 55 """ 56 return self.prompt_template.format( 57 context=context_text, 58 query=query 59 ).strip() 60 61 def run(self, context: QueryContext) -> None: 62 """Generate the final answer. 63 64 Parameters 65 ---------- 66 context 67 Query context with retrieved_docs set. 68 """ 69 if not context.retrieved_docs: 70 self._logger.info("No retrieved docs available. 
Skipping answer generation.") 71 self._append_access_denial_notice(context) 72 return 73 74 context_text = self._build_context_and_citations(context) 75 76 prompt = self.create_prompt(context_text, context.user_query) 77 context.prompt = prompt 78 79 if not self._generate_response(context, prompt): 80 return 81 82 self._append_access_denial_notice(context) 83 self._logger.info("Generated final answer successfully") 84 85 def _build_context_and_citations(self, context: QueryContext) -> str: 86 """Build context text and citations from retrieved documents.""" 87 blocks: list[str] = [] 88 citations: list[dict] = [] 89 90 for i, (doc, score) in enumerate(context.retrieved_docs, start=1): 91 source, page = self._extract_source_and_page(doc) 92 citations.append({ 93 "id": i, 94 "source": source, 95 "page": page, 96 "score": score, 97 "content": doc.page_content, 98 }) 99 blocks.append( 100 f"[{i}] SOURCE={source} PAGE={page} SCORE={score}\n{doc.page_content}" 101 ) 102 103 context.citations = citations 104 105 context_text = "\n\n---\n\n".join(blocks) 106 if len(context_text) > self.max_context_chars: 107 context_text = context_text[: self.max_context_chars] + "\n\n[TRUNCATED]\n" 108 109 return context_text 110 111 def _extract_source_and_page(self, doc: Document) -> tuple[str | None, int | None]: 112 """Extract source and page from document metadata.""" 113 meta = doc.metadata or {} 114 source = meta.get("source") or meta.get("file_path") or meta.get("filename") 115 page = meta.get("page") or meta.get("page_number") 116 117 # Extract only filename from path 118 if source: 119 source = Path(source).name 120 121 return source, page 122 123 def _generate_response(self, context: QueryContext, prompt: str) -> bool: 124 """Generate LLM response. 
Returns True on success, False on failure.""" 125 try: 126 context.llm_response = self.llm.generate(prompt) 127 return True 128 except Exception as e: 129 context.mark_failed(f"LLM generation failed: {e}") 130 return False 131 132 def _append_access_denial_notice(self, context: QueryContext) -> None: 133 """Append access denial notice if there are restricted documents.""" 134 if not context.has_restricted_content or not context.restricted_docs: 135 return 136 137 denial_message = self._format_access_denial_message(context.restricted_docs) 138 context.llm_response = (context.llm_response or "") + denial_message 139 self._logger.info( 140 f"Appended access denial notice for {len(context.restricted_docs)} restricted documents" 141 ) 142 143 def _format_access_denial_message(self, restricted_docs: list[dict]) -> str: 144 """Format the access denial message for restricted documents.""" 145 seen: set[tuple[str, tuple[str, ...]]] = set() 146 lines = [] 147 148 for doc in restricted_docs: 149 source = doc.get("source") 150 source_name = Path(source).name if source else "Unknown source" 151 required_tags = tuple(sorted(doc.get("required_tags", []))) 152 key = (source_name, required_tags) 153 154 if key in seen: 155 continue 156 seen.add(key) 157 158 if required_tags: 159 tags_str = ", ".join(required_tags) 160 lines.append(f"- {source_name} (requires: {tags_str})") 161 else: 162 lines.append(f"- {source_name}") 163 164 source_list = "\n".join(lines) 165 166 return ( 167 f"\n\n---\n**Note:** {len(lines)} additional document(s) " 168 f"matched your query but you lack permission to access them:\n{source_list}" 169 )