/ src / pipeline / steps / answer_generation_step.py
answer_generation_step.py
  1  from __future__ import annotations
  2  
  3  """Step that generates the final answer from retrieved documents."""
  4  
  5  import logging
  6  from pathlib import Path
  7  
  8  from langchain_core.documents import Document
  9  
 10  from ...llms.protocol import LLM
 11  from ..contexts.query_context import QueryContext
 12  from ..step import PipelineStep
 13  from .constants import DEFAULT_GENERATION_PROMPT, DEFAULT_MAX_CONTEXT_CHARS
 14  
 15  
 16  class GenerationStep(PipelineStep):
 17      """Step that generates the final answer from retrieved documents."""
 18  
 19      def __init__(
 20          self,
 21          llm: LLM,
 22          max_context_chars: int = DEFAULT_MAX_CONTEXT_CHARS,
 23          prompt_template: str | None = None,
 24      ):
 25          """Initialize the generation step.
 26  
 27          Parameters
 28          ----------
 29          llm
 30              LLM instance created by LLMFactory (or wired manually).
 31          max_context_chars
 32              Max characters of retrieved context injected into the prompt.
 33          prompt_template
 34              Custom prompt template (use {context} and {query} placeholders).
 35              If None, uses default prompt.
 36          """
 37          self.llm = llm
 38          self.max_context_chars = max_context_chars
 39          self.prompt_template = prompt_template or DEFAULT_GENERATION_PROMPT
 40          self._logger = logging.getLogger(__name__)
 41  
 42      def create_prompt(self, context_text: str, query: str) -> str:
 43          """Create the generation prompt using template.
 44  
 45          Parameters
 46          ----------
 47          context_text
 48              Formatted context from retrieved documents
 49          query
 50              User's query
 51  
 52          Returns
 53          -------
 54          Formatted prompt string for LLM generation
 55          """
 56          return self.prompt_template.format(
 57              context=context_text,
 58              query=query
 59          ).strip()
 60  
 61      def run(self, context: QueryContext) -> None:
 62          """Generate the final answer.
 63  
 64          Parameters
 65          ----------
 66          context
 67              Query context with retrieved_docs set.
 68          """
 69          if not context.retrieved_docs:
 70              self._logger.info("No retrieved docs available. Skipping answer generation.")
 71              self._append_access_denial_notice(context)
 72              return
 73  
 74          context_text = self._build_context_and_citations(context)
 75  
 76          prompt = self.create_prompt(context_text, context.user_query)
 77          context.prompt = prompt
 78  
 79          if not self._generate_response(context, prompt):
 80              return
 81  
 82          self._append_access_denial_notice(context)
 83          self._logger.info("Generated final answer successfully")
 84  
 85      def _build_context_and_citations(self, context: QueryContext) -> str:
 86          """Build context text and citations from retrieved documents."""
 87          blocks: list[str] = []
 88          citations: list[dict] = []
 89  
 90          for i, (doc, score) in enumerate(context.retrieved_docs, start=1):
 91              source, page = self._extract_source_and_page(doc)
 92              citations.append({
 93                  "id": i,
 94                  "source": source,
 95                  "page": page,
 96                  "score": score,
 97                  "content": doc.page_content,
 98              })
 99              blocks.append(
100                  f"[{i}] SOURCE={source} PAGE={page} SCORE={score}\n{doc.page_content}"
101              )
102  
103          context.citations = citations
104  
105          context_text = "\n\n---\n\n".join(blocks)
106          if len(context_text) > self.max_context_chars:
107              context_text = context_text[: self.max_context_chars] + "\n\n[TRUNCATED]\n"
108  
109          return context_text
110  
111      def _extract_source_and_page(self, doc: Document) -> tuple[str | None, int | None]:
112          """Extract source and page from document metadata."""
113          meta = doc.metadata or {}
114          source = meta.get("source") or meta.get("file_path") or meta.get("filename")
115          page = meta.get("page") or meta.get("page_number")
116  
117          # Extract only filename from path
118          if source:
119              source = Path(source).name
120  
121          return source, page
122  
123      def _generate_response(self, context: QueryContext, prompt: str) -> bool:
124          """Generate LLM response. Returns True on success, False on failure."""
125          try:
126              context.llm_response = self.llm.generate(prompt)
127              return True
128          except Exception as e:
129              context.mark_failed(f"LLM generation failed: {e}")
130              return False
131  
132      def _append_access_denial_notice(self, context: QueryContext) -> None:
133          """Append access denial notice if there are restricted documents."""
134          if not context.has_restricted_content or not context.restricted_docs:
135              return
136  
137          denial_message = self._format_access_denial_message(context.restricted_docs)
138          context.llm_response = (context.llm_response or "") + denial_message
139          self._logger.info(
140              f"Appended access denial notice for {len(context.restricted_docs)} restricted documents"
141          )
142  
143      def _format_access_denial_message(self, restricted_docs: list[dict]) -> str:
144          """Format the access denial message for restricted documents."""
145          seen: set[tuple[str, tuple[str, ...]]] = set()
146          lines = []
147  
148          for doc in restricted_docs:
149              source = doc.get("source")
150              source_name = Path(source).name if source else "Unknown source"
151              required_tags = tuple(sorted(doc.get("required_tags", [])))
152              key = (source_name, required_tags)
153  
154              if key in seen:
155                  continue
156              seen.add(key)
157  
158              if required_tags:
159                  tags_str = ", ".join(required_tags)
160                  lines.append(f"- {source_name} (requires: {tags_str})")
161              else:
162                  lines.append(f"- {source_name}")
163  
164          source_list = "\n".join(lines)
165  
166          return (
167              f"\n\n---\n**Note:** {len(lines)} additional document(s) "
168              f"matched your query but you lack permission to access them:\n{source_list}"
169          )