# src/rerankers/gemini.py
from __future__ import annotations

import logging
import math
import re
from typing import Any

from google import genai
from google.genai import types
from langchain_core.documents import Document

from .constants import (
    DEFAULT_MAX_RERANK_CHARS,
    DEFAULT_RERANK_MODEL,
    DEFAULT_SCORING_PROMPT,
)
from .protocol import Reranker
 17  
 18  
 19  class GeminiReranker(Reranker):
 20      """Reranker backed by Gemini API for relevance scoring."""
 21  
 22      def __init__(
 23          self,
 24          model: str = DEFAULT_RERANK_MODEL,
 25          api_key: str | None = None,
 26          max_chars: int = DEFAULT_MAX_RERANK_CHARS,
 27          scoring_prompt: str | None = None,
 28      ):
 29          """Initialize Gemini reranker.
 30  
 31          Parameters
 32          ----------
 33          model
 34              Gemini model name for reranking
 35          api_key
 36              Google API key (if None, uses environment variable)
 37          max_chars
 38              Maximum characters of document content to send for scoring
 39          scoring_prompt
 40              Custom prompt template for scoring (use {query} and {document} placeholders)
 41          """
 42          self.client = genai.Client(api_key=api_key) if api_key else genai.Client()
 43          self.model = model
 44          self.max_chars = max_chars
 45          self.scoring_prompt = scoring_prompt or DEFAULT_SCORING_PROMPT
 46          self.logger = logging.getLogger(__name__)
 47  
 48      def create_scoring_prompt(self, query: str, document_content: str) -> str:
 49          """Generate the scoring prompt for document relevance evaluation.
 50  
 51          Parameters
 52          ----------
 53          query
 54              User's search query
 55          document_content
 56              Document text to evaluate (may include metadata)
 57  
 58          Returns
 59          -------
 60          Formatted prompt string for LLM scoring
 61          """
 62          return self.scoring_prompt.format(query=query, document=document_content)
 63  
 64      def rerank(
 65          self, query: str,
 66          documents: list[tuple[Document, float]]
 67      ) -> list[tuple[Document, float]]:
 68          """Rerank documents using Gemini to score relevance.
 69  
 70          Parameters
 71          ----------
 72          query
 73              User's search query
 74          documents
 75              List of (Document, score) tuples from initial retrieval
 76  
 77          Returns
 78          -------
 79          List of (Document, score) tuples reordered by relevance with new scores
 80          """
 81          if not documents:
 82              return documents
 83  
 84          self.logger.info(f"🔄 Reranking {len(documents)} documents with Gemini")
 85  
 86          reranked = []
 87          successful_scores = 0
 88          zero_scores = 0
 89  
 90          for i, (doc, _original_score) in enumerate(documents, 1):
 91              self.logger.info(f"--- Document {i}/{len(documents)} ---")
 92  
 93              content = doc.page_content[: self.max_chars]
 94              if len(doc.page_content) > self.max_chars:
 95                  content += "...[truncated]"
 96  
 97              # Add metadata context if available
 98              metadata = doc.metadata or {}
 99  
100              source = metadata.get("source")
101  
102              if source:
103                  content = f"[METADATA]\n{source}\n[CONTENT]\n{content}"
104  
105              # Score the document relevance
106              score = self._score_document(query, content)
107              reranked.append((doc, score))
108  
109              if score > 0.0:
110                  successful_scores += 1
111              else:
112                  zero_scores += 1
113  
114          # Sort by new scores (higher is better)
115          reranked.sort(key=lambda x: x[1], reverse=True)
116  
117          self.logger.info(
118              f"✅ Reranking completed: {successful_scores} docs scored > 0, "
119              f"{zero_scores} docs scored 0"
120          )
121          return reranked
122  
123      def _score_document(self, query: str, document_content: str) -> float:
124          """Score a single document's relevance to the query.
125  
126          Parameters
127          ----------
128          query
129              User's search query
130          document_content
131              Document text content (possibly truncated)
132  
133          Returns
134          -------
135          Relevance score from 0.0 to 10.0 (higher is more relevant)
136          """
137          doc_preview = document_content[:150].replace("\n", " ")
138          self.logger.info(f"📝 Evaluating: {doc_preview}...")
139  
140          prompt = self.create_scoring_prompt(query, document_content)
141  
142          try:
143  
144              config = types.GenerateContentConfig(
145                  temperature=0.0,
146                  max_output_tokens=5,
147              )
148  
149              resp = self.client.models.generate_content(
150                  model=self.model,
151                  contents=prompt,
152                  config=config,
153              )
154  
155              # Check if response is None or empty
156              if not resp or not resp.text:
157                  self.logger.info(
158                      f"LLM returned empty response for model {self.model}. "
159                      "This might be a model availability issue. Returning 0.0"
160                  )
161                  return 0.0
162  
163              score_text = resp.text.strip()
164  
165              self.logger.debug(f"LLM prompt: {prompt}")
166              self.logger.debug(f"LLM score: {score_text}")
167  
168              # Try to extract a number from the response
169              # First try to parse the whole thing as a float
170              try:
171                  score = float(score_text)
172                  score = max(0.0, min(10.0, score))
173                  self.logger.info(f"Score: {score:.1f}")
174                  return score
175              except ValueError:
176                  # If that fails, try to find a number with regex
177                  match = re.search(r"\d+\.?\d*", score_text)
178                  if match:
179                      score = float(match.group())
180                      score = max(0.0, min(10.0, score))
181                      self.logger.info(f"Score: {score:.1f} (extracted from text)")
182                      return score
183                  else:
184                      self.logger.info(
185                          f"Could not parse score from LLM response: '{score_text[:100]}'. "
186                          f"Returning 0.0"
187                      )
188                      return 0.0
189  
190          except Exception as e:
191              self.logger.info(f"Error scoring document: {e}")
192              return 0.0
193  
194  
def create_gemini_reranker(config: dict[str, Any]) -> Reranker:
    """Factory function to create a Gemini reranker from a config dict.

    Parameters
    ----------
    config
        Configuration dictionary with optional keys: ``model``, ``api_key``,
        ``max_chars``, ``scoring_prompt``. A key that is absent *or* set to
        ``None`` (e.g. ``max_chars: null`` in a YAML file) falls back to the
        reranker's default.

    Returns
    -------
    Configured GeminiReranker instance
    """

    def _setting(key: str, default: Any) -> Any:
        # dict.get(key, default) applies the default only when the key is
        # missing. An explicit None would otherwise reach GeminiReranker and
        # fail later (e.g. ``len(text) > None`` raises TypeError in rerank).
        value = config.get(key)
        return default if value is None else value

    return GeminiReranker(
        model=_setting("model", DEFAULT_RERANK_MODEL),
        # GeminiReranker already treats None as "not provided" for these two.
        api_key=config.get("api_key"),
        max_chars=_setting("max_chars", DEFAULT_MAX_RERANK_CHARS),
        scoring_prompt=config.get("scoring_prompt"),
    )