  1  """Grok (xAI) provider implementation of the LLMProvider protocol."""
  2  
  3  __all__ = ["GrokClient"]
  4  
  5  import logging
  6  import threading
  7  from collections.abc import Sequence
  8  
  9  import openai
 10  from openai import OpenAI
 11  from pydantic import BaseModel
 12  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
 13  
 14  from exceptions import APIError, UsageError
 15  from models.llm import Attachment
 16  
 17  logger = logging.getLogger(__name__)
 18  
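# xAI serves an OpenAI-compatible API, so the stock OpenAI SDK is reused here by
# pointing it at the xAI base URL.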
_GROK_BASE_URL = "https://api.x.ai/v1"


def _is_transient(exc: BaseException) -> bool:
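    """Return True for retriable failures: HTTP 429, 5xx, timeouts, connection errors."""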
    if isinstance(exc, openai.APIStatusError):
        return exc.status_code == 429 or exc.status_code >= 500
    return isinstance(exc, (openai.APITimeoutError, openai.APIConnectionError))


class GrokClient:
    """LLMProvider backed by the OpenAI SDK pointed at the xAI (Grok) base URL."""

    def __init__(self, api_key: str, model: str) -> None:
        self._client = OpenAI(api_key=api_key, base_url=_GROK_BASE_URL)
        self._model = model
        self._tls = threading.local()

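    # Retry transient failures with exponential backoff (bounded between 2s and 60s)
    # for at most five attempts, re-raising the final error once attempts are exhausted.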
    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt to Grok (xAI) and return the response.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature.
            seed: Passed to the API when provided.
            response_schema: If provided, requests structured output via the beta parse endpoint.
            attachments: Not yet implemented; raises ``UsageError`` if provided.

        Returns:
            Response text, or a JSON string when ``response_schema`` is given.

        Raises:
            APIError: On non-retriable HTTP errors (4xx).
            UsageError: If attachments are provided (not yet implemented for Grok).
        """
        if attachments is not None:
            raise UsageError("Grok file attachment support not yet implemented")

        logger.debug("Calling grok model=%s", self._model)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]

        try:
            if response_schema is not None:
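                # Structured output path: the SDK's beta parse endpoint accepts a
                # pydantic model as response_format, returns a validated instance on
                # .parsed, and the instance is re-serialized to JSON so complete()
                # always returns a str.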
                kwargs: dict[str, object] = {
                    "model": self._model,
                    "messages": messages,
                    "response_format": response_schema,
                }
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if seed is not None:
                    kwargs["seed"] = seed
                completion = self._client.beta.chat.completions.parse(**kwargs)  # type: ignore[arg-type]
                parsed = completion.choices[0].message.parsed
                result = parsed.model_dump_json()  # type: ignore[attr-defined]
            else:
                kwargs = {
                    "model": self._model,
                    "messages": messages,
                }
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if seed is not None:
                    kwargs["seed"] = seed
                completion = self._client.chat.completions.create(**kwargs)  # type: ignore[call-overload]
                result = completion.choices[0].message.content or ""
        except openai.APIStatusError as exc:
            if exc.status_code == 429 or exc.status_code >= 500:
                # Transient (rate limit / server error): re-raise so tenacity retries.
                raise
            raise APIError("grok", exc.status_code, str(exc)) from exc
        except (openai.APITimeoutError, openai.APIConnectionError):
            # Network-level failures are likewise left to the retry decorator.
            raise

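        # Record token usage on thread-local storage so concurrent callers sharing
        # this client each see the usage from their own most recent call.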
        self._tls.input_tokens = completion.usage.prompt_tokens if completion.usage else 0
        self._tls.output_tokens = (
            completion.usage.completion_tokens if completion.usage else 0
        )

        logger.info("Response received (%d chars)", len(result))
        return result  # type: ignore[no-any-return]

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )
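

# Illustrative usage sketch (not exercised by the provider itself): the model name,
# the XAI_API_KEY environment variable, and the _Answer schema below are placeholder
# assumptions, not values required by this module.
if __name__ == "__main__":
    import os

    class _Answer(BaseModel):
        answer: str

    # Hypothetical model name; substitute whichever Grok model the deployment uses.
    client = GrokClient(api_key=os.environ["XAI_API_KEY"], model="grok-2-latest")

    # Plain text completion.
    print(client.complete(system="You are terse.", user="Say OK."))

    # Structured output: returns the JSON serialization of a validated _Answer.
    print(
        client.complete(
            system="Answer briefly.",
            user="What is 2 + 2?",
            response_schema=_Answer,
        )
    )
    print("tokens (in, out):", client.last_usage)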