# integrations/mistral.py
1 """Mistral provider implementation of the LLMProvider protocol.""" 2 3 __all__ = ["MistralClient"] 4 5 import logging 6 import threading 7 from collections.abc import Sequence 8 9 from mistralai.client import Mistral 10 from mistralai.client.errors import MistralError, NoResponseError 11 from mistralai.client.types import UNSET 12 from pydantic import BaseModel 13 from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential 14 15 from exceptions import APIError, ParseError, UsageError 16 from models.llm import Attachment 17 18 logger = logging.getLogger(__name__) 19 20 _TIMEOUT_MS = 120_000 # 2 minutes — matches typical LLM response latency ceiling 21 22 23 def _is_transient(exc: BaseException) -> bool: 24 if isinstance(exc, MistralError): 25 # 429 is retried explicitly: unlike the OpenAI/Anthropic/Gemini SDKs, the native 26 # mistralai SDK has no built-in retry logic — 429 must be handled at this layer. 27 return exc.status_code >= 500 or exc.status_code == 429 28 return isinstance(exc, NoResponseError) 29 30 31 class MistralClient: 32 """LLMProvider backed by the native mistralai SDK.""" 33 34 def __init__(self, api_key: str, model: str, max_tokens: int | None = None) -> None: 35 self._client = Mistral(api_key=api_key, timeout_ms=_TIMEOUT_MS) 36 self._model = model 37 self._max_tokens = max_tokens 38 self._tls = threading.local() 39 40 @retry( 41 retry=retry_if_exception(_is_transient), 42 wait=wait_exponential(multiplier=1, min=2, max=60), 43 stop=stop_after_attempt(5), 44 reraise=True, 45 ) 46 def complete( 47 self, 48 system: str, 49 user: str, 50 *, 51 temperature: float | None = None, 52 seed: int | None = None, 53 response_schema: type[BaseModel] | None = None, 54 attachments: Sequence[Attachment] | None = None, 55 ) -> str: 56 """Send a prompt to Mistral and return the response. 57 58 Args: 59 system: System prompt text. 60 user: User message text. 61 temperature: Sampling temperature, or None to use the model's default. 62 seed: Passed as random_seed when provided (Mistral's parameter name). 63 response_schema: If provided, requests structured output via chat.parse. 64 attachments: Not yet implemented — raises ``APIError`` if provided. 65 66 Returns: 67 Response text or JSON string. 68 69 Raises: 70 APIError: On non-retriable HTTP errors (4xx, excluding 429). 71 ParseError: When structured output is requested but the model returns no parsed content. 72 UsageError: If attachments are provided (not yet implemented for Mistral). 73 """ 74 if attachments is not None: 75 raise UsageError( 76 "Mistral file attachment support not yet implemented" 77 ) 78 79 logger.debug("Calling mistral model=%s", self._model) 80 81 messages = [ 82 {"role": "system", "content": system}, 83 {"role": "user", "content": user}, 84 ] 85 86 # Use UNSET (not None) for optional params so the SDK omits them from the 87 # JSON payload entirely — passing None would serialize as null. 
        max_tokens = self._max_tokens if self._max_tokens is not None else UNSET
        temp = temperature if temperature is not None else UNSET
        random_seed = seed if seed is not None else UNSET

        try:
            if response_schema is not None:
                response = self._client.chat.parse(
                    response_schema,
                    model=self._model,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temp,
                    random_seed=random_seed,
                )
                parsed = response.choices[0].message.parsed
                if parsed is None:
                    raise ParseError("mistral chat.parse", "model returned no parsed content")
                result = str(parsed.model_dump_json())
            else:
                response = self._client.chat.complete(
                    model=self._model,
                    messages=messages,  # type: ignore[arg-type]
                    max_tokens=max_tokens,
                    temperature=temp,
                    random_seed=random_seed,
                )
                result = str(response.choices[0].message.content or "")
        except MistralError as exc:
            if exc.status_code < 500 and exc.status_code != 429:
                raise APIError("mistral", exc.status_code, str(exc)) from exc
            raise
        except NoResponseError:
            raise

        self._tls.input_tokens = response.usage.prompt_tokens or 0
        self._tls.output_tokens = response.usage.completion_tokens or 0

        logger.info("Response received (%d chars)", len(result))
        return result

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )
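

# Minimal usage sketch. Everything below is illustrative only: the MISTRAL_API_KEY
# environment variable, the model alias, and the Recipe schema are assumptions,
# not part of the provider contract above.
if __name__ == "__main__":
    import os

    class Recipe(BaseModel):
        title: str
        steps: list[str]

    client = MistralClient(
        api_key=os.environ["MISTRAL_API_KEY"], model="mistral-small-latest"
    )

    # Plain text completion.
    text = client.complete(system="You are a helpful assistant.", user="Say: OK")
    print(text)

    # Structured output: complete() returns the parsed schema serialized as JSON.
    recipe_json = client.complete(
        system="You extract recipes as structured data.",
        user="Pancakes: mix flour, milk and eggs, then fry.",
        response_schema=Recipe,
    )
    print(recipe_json)
    print("tokens (in, out):", client.last_usage)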