/ src / pipeline / steps / rerank_step.py
rerank_step.py
"""Step that reranks retrieved documents using a cross-encoder."""

from __future__ import annotations

import logging

from ...rerankers.protocol import Reranker
from ..contexts.query_context import QueryContext
from ..step import PipelineStep
10  
11  
logger = logging.getLogger(__name__)


class RerankStep(PipelineStep):
    """Step that reranks retrieved documents using a cross-encoder.

    This step takes the ``(document, score)`` pairs produced by initial
    retrieval and reorders them with the injected ``Reranker`` (a
    cross-encoder model), which can better assess relevance between the
    query and document content than the first-stage scores.

    NOTE(review): the first-stage score is logged as ``distance_score`` while
    earlier documentation called it a vector similarity score — confirm which
    one the retriever actually produces.
    """

    # Number of characters of ``page_content`` shown in each log preview.
    _PREVIEW_CHARS = 100

    def __init__(self, reranker: Reranker, *, log_top_k: int = 5):
        """Initialize the rerank step.

        Parameters
        ----------
        reranker
            Reranker instance created by RerankerFactory.
        log_top_k
            Keyword-only. How many of the top reranked results to log after
            reranking. Defaults to 5.

        Raises
        ------
        ValueError
            If ``log_top_k`` is negative (a negative slice bound would
            silently log the wrong subset).
        """
        if log_top_k < 0:
            raise ValueError(f"log_top_k must be non-negative, got {log_top_k}")
        self.reranker = reranker
        self.log_top_k = log_top_k

    def run(self, context: QueryContext) -> None:
        """Rerank retrieved documents for improved relevance.

        Reads ``context.user_query`` and ``context.retrieved_docs`` (an
        indexable sequence of ``(document, score)`` pairs, where each document
        exposes ``metadata`` and ``page_content``). It then replaces
        ``context.retrieved_docs`` with whatever ``self.reranker.rerank``
        returns. If there are no retrieved documents, the context is left
        untouched.

        Parameters
        ----------
        context
            Query context with retrieved_docs set from RetrieveStep.
        """
        if not context.retrieved_docs:
            logger.info("No documents to rerank. Skipping rerank step.")
            return

        logger.info(
            "Reranking %d retrieved documents using cross-encoder",
            len(context.retrieved_docs),
        )
        self._log_initial_scores(context.retrieved_docs)

        reranked_docs = self.reranker.rerank(
            context.user_query,
            context.retrieved_docs,
        )
        context.retrieved_docs = reranked_docs

        self._log_top_results(reranked_docs)

    @staticmethod
    def _log_initial_scores(docs) -> None:
        """Log every retrieved document's first-stage score and source."""
        logger.info("Initial retrieval scores (before reranking):")
        for rank, (doc, score) in enumerate(docs, 1):
            logger.info(
                "  [%d] distance_score=%.4f source=%s",
                rank,
                score,
                doc.metadata.get("source", "unknown"),
            )

    def _log_top_results(self, docs) -> None:
        """Log the first ``self.log_top_k`` reranked documents with a preview."""
        logger.info("Top reranked results (higher score = more relevant):")
        for rank, (doc, score) in enumerate(docs[: self.log_top_k], 1):
            content = doc.page_content
            # Flatten newlines so each result stays on one log line.
            preview = content[: self._PREVIEW_CHARS].replace("\n", " ")
            # Only mark the preview as truncated when content was actually cut.
            suffix = "..." if len(content) > self._PREVIEW_CHARS else ""
            # Same precision as the initial scores so the two logs compare.
            logger.info(
                "  [%d] Score: %.4f | source=%s | %s%s",
                rank,
                score,
                doc.metadata.get("source", "unknown"),
                preview,
                suffix,
            )