/ integrations / llm.py
llm.py
  1  """LLM provider Protocol and factory function."""
  2  
  3  __all__ = ["LLMProvider", "create_llm_provider"]
  4  
  5  from collections.abc import Sequence
  6  from typing import TYPE_CHECKING, Protocol
  7  
  8  from pydantic import BaseModel
  9  
 10  from exceptions import ConfigurationError
 11  from models.llm import Attachment
 12  
 13  if TYPE_CHECKING:
 14      from config import Settings
 15  
 16  
 17  class LLMProvider(Protocol):
 18      """Common interface for all LLM provider implementations."""
 19  
 20      def complete(
 21          self,
 22          system: str,
 23          user: str,
 24          *,
 25          temperature: float | None = None,
 26          seed: int | None = None,
 27          response_schema: type[BaseModel] | None = None,
 28          attachments: Sequence[Attachment] | None = None,
 29      ) -> str:
 30          """Send a prompt and return a response string.
 31  
 32          Args:
 33              system: System prompt text.
 34              user: User message text.
 35              temperature: Sampling temperature, or None to use the model's default.
 36              seed: Optional reproducibility seed (ignored by providers that don't support it).
 37              response_schema: If provided, the provider uses its native structured output
 38                  mechanism and returns a JSON string conforming to this schema.
 39              attachments: Binary files to include in the prompt. Each provider
 40                  translates attachments into its native content block format.
 41                  Providers that do not support attachments raise ``UsageError``.
 42  
 43          Returns:
 44              Response text, or a JSON string when response_schema is provided.
 45          """
 46          ...
 47  
 48      @property
 49      def last_usage(self) -> tuple[int, int]:
 50          """(input_tokens, output_tokens) from the most recent successful complete() call
 51          on the current thread.
 52  
 53          Returns (0, 0) before the first successful call on this thread.
 54          Unchanged after a failed call — retains the previous successful call's counts.
 55          Thread-safe: each thread reads only its own counts.
 56          """
 57          ...
 58  
 59      def ping(self, temperature: float | None = None) -> None:
 60          """Send a minimal request to verify the provider and model are reachable.
 61  
 62          Args:
 63              temperature: Sampling temperature to use, or None to use the model's default.
 64  
 65          Raises:
 66              APIError: If the provider returns a 4xx error (bad model, bad key).
 67          """
 68          ...
 69  
 70  
 71  def create_llm_provider(
 72      provider: str, model: str, settings: "Settings", *, max_tokens: int = 1024
 73  ) -> LLMProvider:
 74      """Instantiate the correct LLM provider implementation from settings.
 75  
 76      Args:
 77          provider: One of "anthropic", "openai", "grok", "gemini", "deepseek", "mistral".
 78          model: Model ID to use for this provider.
 79          settings: Application settings instance (used for API keys).
 80          max_tokens: Maximum tokens in the response. Required by Anthropic; passed
 81              as an optional cap by Mistral. Ignored by other providers.
 82  
 83      Returns:
 84          An object implementing the LLMProvider protocol.
 85  
 86      Raises:
 87          ConfigurationError: If the required API key for the selected provider
 88              is absent from settings.
 89          ConfigurationError: If the provider name is not recognised.
 90      """
 91      _PROVIDERS: dict[str, tuple[str, str]] = {
 92          "anthropic": ("anthropic_api_key", "ANTHROPIC_API_KEY"),
 93          "openai": ("openai_api_key", "OPENAI_API_KEY"),
 94          "grok": ("grok_api_key", "XAI_API_KEY"),
 95          "gemini": ("gemini_api_key", "GEMINI_API_KEY"),
 96          "deepseek": ("deepseek_api_key", "DEEPSEEK_API_KEY"),
 97          "mistral": ("mistral_api_key", "MISTRAL_API_KEY"),
 98      }
 99  
100      if provider not in _PROVIDERS:
101          raise ConfigurationError(
102              f"llm_provider={provider!r} — valid values: {sorted(_PROVIDERS)}"
103          )
104  
105      key_attr, key_env = _PROVIDERS[provider]
106      api_key: str | None = getattr(settings, key_attr, None)
107      if not api_key:
108          raise ConfigurationError(key_env)
109  
110      if provider == "anthropic":
111          from integrations.anthropic import AnthropicClient
112          return AnthropicClient(api_key=api_key, model=model, max_tokens=max_tokens)
113      elif provider == "openai":
114          from integrations.openai import OpenAIClient
115          return OpenAIClient(api_key=api_key, model=model)
116      elif provider == "grok":
117          from integrations.grok import GrokClient
118          return GrokClient(api_key=api_key, model=model)
119      elif provider == "gemini":
120          from integrations.gemini import GeminiClient
121          return GeminiClient(api_key=api_key, model=model)
122      elif provider == "deepseek":
123          from integrations.deepseek import DeepSeekClient
124          return DeepSeekClient(api_key=api_key, model=model)
125      else:  # mistral
126          from integrations.mistral import MistralClient
127          return MistralClient(api_key=api_key, model=model, max_tokens=max_tokens)