# rerank_step.py
"""Step that reranks retrieved documents using a cross-encoder."""

from __future__ import annotations

import logging

from ...rerankers.protocol import Reranker
from ..contexts.query_context import QueryContext
from ..step import PipelineStep

# Resolved once at import time instead of on every ``run()`` call.
logger = logging.getLogger(__name__)


class RerankStep(PipelineStep):
    """Step that reranks retrieved documents using a cross-encoder.

    This step takes documents from initial retrieval (with vector similarity scores)
    and reorders them using a more sophisticated cross-encoder model that can
    better assess relevance between the query and document content.
    """

    # How many reranked results are echoed to the log, and how many characters
    # of each document's content are shown in that log preview.
    _LOG_TOP_N = 5
    _LOG_PREVIEW_CHARS = 100

    def __init__(self, reranker: Reranker):
        """Initialize the rerank step.

        Parameters
        ----------
        reranker
            Reranker instance created by RerankerFactory.
        """
        self.reranker = reranker

    def run(self, context: QueryContext) -> None:
        """Rerank retrieved documents for improved relevance.

        Replaces ``context.retrieved_docs`` with the result of
        ``self.reranker.rerank(context.user_query, context.retrieved_docs)``.
        When ``context.retrieved_docs`` is empty or ``None`` the step logs a
        message and returns without calling the reranker.

        Parameters
        ----------
        context
            Query context with retrieved_docs set from RetrieveStep. Both the
            input list and the reranker's result are iterated as
            ``(doc, score)`` pairs for logging, where ``doc`` exposes
            ``metadata`` (a mapping) and ``page_content`` (a string).
        """
        if not context.retrieved_docs:
            logger.info("No documents to rerank. Skipping rerank step.")
            return

        # The per-document log lines below are purely diagnostic; skip the
        # loops (and the preview string building) when INFO is disabled.
        info_enabled = logger.isEnabledFor(logging.INFO)

        logger.info(
            "Reranking %d retrieved documents using cross-encoder",
            len(context.retrieved_docs),
        )

        if info_enabled:
            logger.info("Initial retrieval scores (before reranking):")
            for i, (doc, score) in enumerate(context.retrieved_docs, 1):
                source = doc.metadata.get("source", "unknown")
                logger.info(" [%d] distance_score=%.4f source=%s", i, score, source)

        reranked_docs = self.reranker.rerank(
            context.user_query,
            context.retrieved_docs,
        )

        context.retrieved_docs = reranked_docs

        # Log top reranked results with scores.
        if info_enabled:
            logger.info("Top reranked results (higher score = more relevant):")
            for i, (doc, score) in enumerate(reranked_docs[: self._LOG_TOP_N], 1):
                # Flatten newlines so each result stays on a single log line.
                content_preview = doc.page_content[: self._LOG_PREVIEW_CHARS].replace(
                    "\n", " "
                )
                logger.info(" [%d] Score: %.1f | %s...", i, score, content_preview)