integrations/mistral.py
  1  """Mistral provider implementation of the LLMProvider protocol."""
  2  
  3  __all__ = ["MistralClient"]
  4  
  5  import logging
  6  import threading
  7  from collections.abc import Sequence
  8  
from mistralai import Mistral
from mistralai.models import MistralError, NoResponseError
from mistralai.types import UNSET
from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

from exceptions import APIError, ParseError, UsageError
from models.llm import Attachment

logger = logging.getLogger(__name__)

_TIMEOUT_MS = 120_000  # 2 minutes, matching the typical LLM response latency ceiling


def _is_transient(exc: BaseException) -> bool:
    if isinstance(exc, MistralError):
        # 429 is retried explicitly: unlike the OpenAI/Anthropic/Gemini SDKs, the native
        # mistralai SDK has no built-in retry logic, so 429 must be handled at this layer.
        return exc.status_code >= 500 or exc.status_code == 429
    return isinstance(exc, NoResponseError)


class MistralClient:
    """LLMProvider backed by the native mistralai SDK."""

    def __init__(self, api_key: str, model: str, max_tokens: int | None = None) -> None:
        self._client = Mistral(api_key=api_key, timeout_ms=_TIMEOUT_MS)
        self._model = model
        self._max_tokens = max_tokens
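        # Token usage is tracked per thread so concurrent complete() calls
        # don't clobber each other's last_usage.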
        self._tls = threading.local()

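    # Retries only transient failures (see _is_transient), waiting with
    # exponential backoff between 2s and 60s and stopping after 5 attempts.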
    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt to Mistral and return the response.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature, or None to use the model's default.
            seed: Passed as random_seed when provided (Mistral's parameter name).
            response_schema: If provided, requests structured output via chat.parse.
            attachments: Not yet implemented; raises ``UsageError`` if provided.

        Returns:
            Response text or JSON string.

        Raises:
            APIError: On non-retriable HTTP errors (4xx, excluding 429).
            ParseError: When structured output is requested but the model returns no parsed content.
            UsageError: If attachments are provided (not yet implemented for Mistral).
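
        Example (illustrative; the model name here is an assumption):

            client = MistralClient(api_key=key, model="mistral-small-latest")
            text = client.complete(system="Be terse.", user="Say OK.")
            input_tokens, output_tokens = client.last_usage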
        """
        if attachments is not None:
            raise UsageError("Mistral file attachment support not yet implemented")

        logger.debug("Calling mistral model=%s", self._model)

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]

        # Use UNSET (not None) for optional params so the SDK omits them from the
        # JSON payload entirely; passing None would serialize as null.
        max_tokens = self._max_tokens if self._max_tokens is not None else UNSET
        temp = temperature if temperature is not None else UNSET
        random_seed = seed if seed is not None else UNSET

        try:
            if response_schema is not None:
                response = self._client.chat.parse(
                    model=self._model,
                    messages=messages,
                    response_format=response_schema,
                    max_tokens=max_tokens,
                    temperature=temp,
                    random_seed=random_seed,
                )
                parsed = response.choices[0].message.parsed
                if parsed is None:
                    raise ParseError("mistral chat.parse", "model returned no parsed content")
                result = parsed.model_dump_json()
            else:
                response = self._client.chat.complete(
                    model=self._model,
                    messages=messages,  # type: ignore[arg-type]
                    max_tokens=max_tokens,
                    temperature=temp,
                    random_seed=random_seed,
                )
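                # message.content may be a plain string or a list of content
                # chunks; coerce defensively to str in either case.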
                result = str(response.choices[0].message.content or "")
        except MistralError as exc:
            if exc.status_code < 500 and exc.status_code != 429:
                raise APIError("mistral", exc.status_code, str(exc)) from exc
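            # 5xx and 429 fall through and are re-raised unchanged, as is
            # NoResponseError, so that tenacity can retry via _is_transient.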
            raise

        self._tls.input_tokens = response.usage.prompt_tokens or 0
        self._tls.output_tokens = response.usage.completion_tokens or 0

        logger.info("Response received (%d chars)", len(result))
        return result

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
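        # Delegates to complete(), so a ping exercises the full request path,
        # including retries and error mapping.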
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )