/ integrations / gemini.py
gemini.py
  1  """Gemini provider implementation of the LLMProvider protocol."""
  2  
  3  __all__ = ["GeminiClient"]
  4  
  5  import logging
  6  import threading
  7  from collections.abc import Sequence
  8  
  9  from google import genai
 10  from google.genai import types
 11  from google.genai.errors import ClientError as _GenAIClientError
 12  from google.genai.errors import ServerError as _GenAIServerError
 13  from pydantic import BaseModel
 14  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
 15  
 16  from exceptions import APIError
 17  from models.llm import Attachment
 18  
 19  logger = logging.getLogger(__name__)
 20  
 21  
 22  def _is_transient(exc: BaseException) -> bool:
 23      if isinstance(exc, _GenAIClientError) and exc.code == 429:
 24          return True
 25      return isinstance(exc, _GenAIServerError)
 26  
 27  
 28  class GeminiClient:
 29      """LLMProvider backed by the google-genai SDK."""
 30  
 31      def __init__(self, api_key: str, model: str) -> None:
 32          self._client = genai.Client(api_key=api_key)
 33          self._model = model
 34          self._tls = threading.local()
 35  
 36      @retry(
 37          retry=retry_if_exception(_is_transient),
 38          wait=wait_exponential(multiplier=1, min=2, max=60),
 39          stop=stop_after_attempt(5),
 40          reraise=True,
 41      )
 42      def complete(
 43          self,
 44          system: str,
 45          user: str,
 46          *,
 47          temperature: float | None = None,
 48          seed: int | None = None,
 49          response_schema: type[BaseModel] | None = None,
 50          attachments: Sequence[Attachment] | None = None,
 51      ) -> str:
 52          """Send a prompt to Gemini and return the response.
 53  
 54          Args:
 55              system: System prompt text.
 56              user: User message text.
 57              temperature: Sampling temperature.
 58              seed: Passed in GenerateContentConfig when provided.
 59              response_schema: If provided, requests JSON output conforming to the schema.
 60              attachments: Binary files sent as ``Part.from_bytes()`` content parts.
 61  
 62          Returns:
 63              Response text or JSON string.
 64  
 65          Raises:
 66              APIError: On non-retriable errors.
 67          """
 68          logger.debug("Calling gemini model=%s", self._model)
 69  
 70          contents: str | list[object]
 71          if attachments:
 72              parts: list[object] = [
 73                  types.Part.from_bytes(data=a.data, mime_type=a.media_type)
 74                  for a in attachments
 75              ]
 76              parts.append(user)
 77              contents = parts
 78          else:
 79              contents = user
 80  
 81          config_kwargs: dict[str, object] = {
 82              "system_instruction": system,
 83          }
 84          if temperature is not None:
 85              config_kwargs["temperature"] = temperature
 86          if seed is not None:
 87              config_kwargs["seed"] = seed
 88          if response_schema is not None:
 89              config_kwargs["response_mime_type"] = "application/json"
 90              config_kwargs["response_schema"] = response_schema
 91  
 92          try:
 93              response = self._client.models.generate_content(
 94                  model=self._model,
 95                  contents=contents,
 96                  config=types.GenerateContentConfig(**config_kwargs),
 97              )
 98          except _GenAIClientError as exc:
 99              if exc.code == 429:
100                  raise  # rate limit — transient, let tenacity retry
101              raise APIError("gemini", exc.code, str(exc)) from exc
102          except _GenAIServerError:
103              # 5xx: transient — re-raise so tenacity can retry
104              raise
105  
106          self._tls.input_tokens = (
107              (response.usage_metadata.prompt_token_count or 0)
108              if response.usage_metadata
109              else 0
110          )
111          self._tls.output_tokens = (
112              (response.usage_metadata.candidates_token_count or 0)
113              if response.usage_metadata
114              else 0
115          )
116  
117          result = response.text or ""
118          logger.info("Response received (%d chars)", len(result))
119          return result
120  
121      @property
122      def last_usage(self) -> tuple[int, int]:
123          """(input_tokens, output_tokens) from the most recent successful call."""
124          return (
125              getattr(self._tls, "input_tokens", 0),
126              getattr(self._tls, "output_tokens", 0),
127          )
128  
129      def ping(self, temperature: float | None = None) -> None:
130          """Send a minimal request to verify the provider and model are reachable."""
131          self.complete(
132              system="You are a helpful assistant.", user="Say: OK", temperature=temperature
133          )