# integrations/grok.py
"""Grok (xAI) provider implementation of the LLMProvider protocol."""

__all__ = ["GrokClient"]

import logging
import threading
from collections.abc import Sequence

import openai
from openai import OpenAI
from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

from exceptions import APIError, UsageError
from models.llm import Attachment

logger = logging.getLogger(__name__)

# xAI exposes an OpenAI-compatible REST API at this endpoint, which is why the
# plain OpenAI SDK client can be reused with only the base URL swapped.
_GROK_BASE_URL = "https://api.x.ai/v1"


def _is_transient(exc: BaseException) -> bool:
    """Return True for errors worth retrying: 429, 5xx, timeouts, connection drops."""
    if isinstance(exc, openai.APIStatusError):
        return exc.status_code == 429 or exc.status_code >= 500
    return isinstance(exc, (openai.APITimeoutError, openai.APIConnectionError))


class GrokClient:
    """LLMProvider backed by the OpenAI SDK pointed at the xAI (Grok) base URL."""

    def __init__(self, api_key: str, model: str) -> None:
        self._client = OpenAI(api_key=api_key, base_url=_GROK_BASE_URL)
        self._model = model
        # Usage counters live in thread-local storage so concurrent calls on a
        # shared client instance do not clobber each other's last_usage.
        self._tls = threading.local()

    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt to Grok (xAI) and return the response.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature.
            seed: Passed to the API when provided.
            response_schema: If provided, requests structured output via the
                beta parse endpoint; the parsed model is returned as JSON.
            attachments: Not yet implemented -- raises ``UsageError`` when any
                attachments are supplied.

        Returns:
            Response text, or a JSON string when ``response_schema`` is given.

        Raises:
            APIError: On non-retriable HTTP errors (4xx other than 429).
            UsageError: If attachments are provided (not yet implemented for Grok).
        """
        # Truthiness check: an empty sequence means "no attachments", so only
        # reject when the caller actually supplied some.
        if attachments:
            raise UsageError(
                "Grok file attachment support not yet implemented"
            )

        logger.debug("Calling grok model=%s", self._model)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]

        try:
            # Build the request payload once; both endpoints share these keys.
            kwargs: dict[str, object] = {
                "model": self._model,
                "messages": messages,
            }
            if temperature is not None:
                kwargs["temperature"] = temperature
            if seed is not None:
                kwargs["seed"] = seed

            if response_schema is not None:
                kwargs["response_format"] = response_schema
                completion = self._client.beta.chat.completions.parse(**kwargs)  # type: ignore[arg-type]
                # NOTE(review): `parsed` can be None when the model refuses or
                # output validation fails; that currently surfaces as an
                # AttributeError on the next line -- confirm desired handling.
                parsed = completion.choices[0].message.parsed
                result = parsed.model_dump_json()  # type: ignore[attr-defined]
            else:
                completion = self._client.chat.completions.create(**kwargs)  # type: ignore[call-overload]
                result = completion.choices[0].message.content or ""
        except openai.APIStatusError as exc:
            if _is_transient(exc):
                # 429 / 5xx: re-raise so tenacity retries; with reraise=True the
                # original exception surfaces once attempts are exhausted.
                raise
            raise APIError("grok", exc.status_code, str(exc)) from exc
        # Timeouts and connection errors propagate unwrapped for tenacity; no
        # explicit except clause is needed for a bare re-raise.

        self._tls.input_tokens = completion.usage.prompt_tokens if completion.usage else 0
        self._tls.output_tokens = (
            completion.usage.completion_tokens if completion.usage else 0
        )

        logger.info("Response received (%d chars)", len(result))
        return result  # type: ignore[no-any-return]

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )