# integrations/gemini.py
1 """Gemini provider implementation of the LLMProvider protocol.""" 2 3 __all__ = ["GeminiClient"] 4 5 import logging 6 import threading 7 from collections.abc import Sequence 8 9 from google import genai 10 from google.genai import types 11 from google.genai.errors import ClientError as _GenAIClientError 12 from google.genai.errors import ServerError as _GenAIServerError 13 from pydantic import BaseModel 14 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential 15 16 from exceptions import APIError 17 from models.llm import Attachment 18 19 logger = logging.getLogger(__name__) 20 21 22 def _is_transient(exc: BaseException) -> bool: 23 if isinstance(exc, _GenAIClientError) and exc.code == 429: 24 return True 25 return isinstance(exc, _GenAIServerError) 26 27 28 class GeminiClient: 29 """LLMProvider backed by the google-genai SDK.""" 30 31 def __init__(self, api_key: str, model: str) -> None: 32 self._client = genai.Client(api_key=api_key) 33 self._model = model 34 self._tls = threading.local() 35 36 @retry( 37 retry=retry_if_exception(_is_transient), 38 wait=wait_exponential(multiplier=1, min=2, max=60), 39 stop=stop_after_attempt(5), 40 reraise=True, 41 ) 42 def complete( 43 self, 44 system: str, 45 user: str, 46 *, 47 temperature: float | None = None, 48 seed: int | None = None, 49 response_schema: type[BaseModel] | None = None, 50 attachments: Sequence[Attachment] | None = None, 51 ) -> str: 52 """Send a prompt to Gemini and return the response. 53 54 Args: 55 system: System prompt text. 56 user: User message text. 57 temperature: Sampling temperature. 58 seed: Passed in GenerateContentConfig when provided. 59 response_schema: If provided, requests JSON output conforming to the schema. 60 attachments: Binary files sent as ``Part.from_bytes()`` content parts. 61 62 Returns: 63 Response text or JSON string. 64 65 Raises: 66 APIError: On non-retriable errors. 
67 """ 68 logger.debug("Calling gemini model=%s", self._model) 69 70 contents: str | list[object] 71 if attachments: 72 parts: list[object] = [ 73 types.Part.from_bytes(data=a.data, mime_type=a.media_type) 74 for a in attachments 75 ] 76 parts.append(user) 77 contents = parts 78 else: 79 contents = user 80 81 config_kwargs: dict[str, object] = { 82 "system_instruction": system, 83 } 84 if temperature is not None: 85 config_kwargs["temperature"] = temperature 86 if seed is not None: 87 config_kwargs["seed"] = seed 88 if response_schema is not None: 89 config_kwargs["response_mime_type"] = "application/json" 90 config_kwargs["response_schema"] = response_schema 91 92 try: 93 response = self._client.models.generate_content( 94 model=self._model, 95 contents=contents, 96 config=types.GenerateContentConfig(**config_kwargs), 97 ) 98 except _GenAIClientError as exc: 99 if exc.code == 429: 100 raise # rate limit — transient, let tenacity retry 101 raise APIError("gemini", exc.code, str(exc)) from exc 102 except _GenAIServerError: 103 # 5xx: transient — re-raise so tenacity can retry 104 raise 105 106 self._tls.input_tokens = ( 107 (response.usage_metadata.prompt_token_count or 0) 108 if response.usage_metadata 109 else 0 110 ) 111 self._tls.output_tokens = ( 112 (response.usage_metadata.candidates_token_count or 0) 113 if response.usage_metadata 114 else 0 115 ) 116 117 result = response.text or "" 118 logger.info("Response received (%d chars)", len(result)) 119 return result 120 121 @property 122 def last_usage(self) -> tuple[int, int]: 123 """(input_tokens, output_tokens) from the most recent successful call.""" 124 return ( 125 getattr(self._tls, "input_tokens", 0), 126 getattr(self._tls, "output_tokens", 0), 127 ) 128 129 def ping(self, temperature: float | None = None) -> None: 130 """Send a minimal request to verify the provider and model are reachable.""" 131 self.complete( 132 system="You are a helpful assistant.", user="Say: OK", temperature=temperature 133 )