gemini.py
from __future__ import annotations

import logging
import re
from typing import Any

from google import genai
from google.genai import types
from langchain_core.documents import Document

from .constants import (
    DEFAULT_MAX_RERANK_CHARS,
    DEFAULT_RERANK_MODEL,
    DEFAULT_SCORING_PROMPT,
)
from .protocol import Reranker


class GeminiReranker(Reranker):
    """Reranker backed by the Gemini API for relevance scoring."""

    def __init__(
        self,
        model: str = DEFAULT_RERANK_MODEL,
        api_key: str | None = None,
        max_chars: int = DEFAULT_MAX_RERANK_CHARS,
        scoring_prompt: str | None = None,
    ):
        """Initialize the Gemini reranker.

        Parameters
        ----------
        model
            Gemini model name for reranking
        api_key
            Google API key (if None, uses environment variable)
        max_chars
            Maximum characters of document content to send for scoring
        scoring_prompt
            Custom prompt template for scoring (use {query} and {document} placeholders)
        """
        self.client = genai.Client(api_key=api_key) if api_key else genai.Client()
        self.model = model
        self.max_chars = max_chars
        self.scoring_prompt = scoring_prompt or DEFAULT_SCORING_PROMPT
        self.logger = logging.getLogger(__name__)

    def create_scoring_prompt(self, query: str, document_content: str) -> str:
        """Generate the scoring prompt for document relevance evaluation.

        Parameters
        ----------
        query
            User's search query
        document_content
            Document text to evaluate (may include metadata)

        Returns
        -------
        Formatted prompt string for LLM scoring
        """
        return self.scoring_prompt.format(query=query, document=document_content)

    def rerank(
        self,
        query: str,
        documents: list[tuple[Document, float]],
    ) -> list[tuple[Document, float]]:
        """Rerank documents using Gemini to score relevance.

        Parameters
        ----------
        query
            User's search query
        documents
            List of (Document, score) tuples from initial retrieval

        Returns
        -------
        List of (Document, score) tuples reordered by relevance with new scores
        """
        if not documents:
            return documents

        self.logger.info(f"🔄 Reranking {len(documents)} documents with Gemini")

        reranked = []
        successful_scores = 0
        zero_scores = 0

        for i, (doc, _original_score) in enumerate(documents, 1):
            self.logger.info(f"--- Document {i}/{len(documents)} ---")

            # Truncate document content to keep the scoring prompt within budget
            content = doc.page_content[: self.max_chars]
            if len(doc.page_content) > self.max_chars:
                content += "...[truncated]"

            # Add metadata context if available
            metadata = doc.metadata or {}
            source = metadata.get("source")
            if source:
                content = f"[METADATA]\n{source}\n[CONTENT]\n{content}"

            # Score the document relevance
            score = self._score_document(query, content)
            reranked.append((doc, score))

            if score > 0.0:
                successful_scores += 1
            else:
                zero_scores += 1

        # Sort by new scores (higher is better)
        reranked.sort(key=lambda x: x[1], reverse=True)

        self.logger.info(
            f"✅ Reranking completed: {successful_scores} docs scored > 0, "
            f"{zero_scores} docs scored 0"
        )
        return reranked

    def _score_document(self, query: str, document_content: str) -> float:
        """Score a single document's relevance to the query.

        Parameters
        ----------
        query
            User's search query
        document_content
            Document text content (possibly truncated)

        Returns
        -------
        Relevance score from 0.0 to 10.0 (higher is more relevant)
        """
        doc_preview = document_content[:150].replace("\n", " ")
        self.logger.info(f"📝 Evaluating: {doc_preview}...")

        prompt = self.create_scoring_prompt(query, document_content)

        try:
            config = types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=5,
            )

            resp = self.client.models.generate_content(
                model=self.model,
                contents=prompt,
                config=config,
            )

            # Check if the response is None or empty
            if not resp or not resp.text:
                self.logger.info(
                    f"LLM returned empty response for model {self.model}. "
                    "This might be a model availability issue. Returning 0.0"
                )
                return 0.0

            score_text = resp.text.strip()

            self.logger.debug(f"LLM prompt: {prompt}")
            self.logger.debug(f"LLM score: {score_text}")

            # Try to extract a number from the response.
            # First, try to parse the whole response as a float.
            try:
                score = float(score_text)
                score = max(0.0, min(10.0, score))
                self.logger.info(f"Score: {score:.1f}")
                return score
            except ValueError:
                # If that fails, try to find a number with a regex
                match = re.search(r"\d+\.?\d*", score_text)
                if match:
                    score = float(match.group())
                    score = max(0.0, min(10.0, score))
                    self.logger.info(f"Score: {score:.1f} (extracted from text)")
                    return score
                else:
                    self.logger.info(
                        f"Could not parse score from LLM response: '{score_text[:100]}'. "
                        f"Returning 0.0"
                    )
                    return 0.0

        except Exception as e:
            self.logger.info(f"Error scoring document: {e}")
            return 0.0


def create_gemini_reranker(config: dict[str, Any]) -> Reranker:
    """Factory function to create a Gemini reranker from a config dict.

    Parameters
    ----------
    config
        Configuration dictionary with optional keys: model, api_key, max_chars, scoring_prompt

    Returns
    -------
    Configured GeminiReranker instance
    """
    return GeminiReranker(
        model=config.get("model", DEFAULT_RERANK_MODEL),
        api_key=config.get("api_key"),
        max_chars=config.get("max_chars", DEFAULT_MAX_RERANK_CHARS),
        scoring_prompt=config.get("scoring_prompt"),
    )
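# Example usage (illustrative sketch, not part of the module's API surface).
# Assumes GOOGLE_API_KEY is available in the environment and that the candidate
# documents come from an earlier retrieval step; the config values shown here
# (e.g. max_chars=4000) are hypothetical.
#
#   from langchain_core.documents import Document
#
#   reranker = create_gemini_reranker({"max_chars": 4000})
#   candidates = [
#       (Document(page_content="Gemini reranking overview",
#                 metadata={"source": "docs/rerank.md"}), 0.42),
#       (Document(page_content="Unrelated changelog entry"), 0.40),
#   ]
#   results = reranker.rerank("how does the Gemini reranker score documents?", candidates)
#   for doc, score in results:
#       print(f"{score:.1f}  {doc.metadata.get('source', '<no source>')}")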