# integrations/llm.py
"""LLM provider Protocol and factory function."""

__all__ = ["LLMProvider", "create_llm_provider"]

from collections.abc import Sequence
from importlib import import_module
from typing import TYPE_CHECKING, Protocol

from pydantic import BaseModel

from exceptions import ConfigurationError
from models.llm import Attachment

if TYPE_CHECKING:
    from config import Settings


class LLMProvider(Protocol):
    """Common interface for all LLM provider implementations."""

    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt and return a response string.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature, or None to use the model's default.
            seed: Optional reproducibility seed (ignored by providers that don't support it).
            response_schema: If provided, the provider uses its native structured output
                mechanism and returns a JSON string conforming to this schema.
            attachments: Binary files to include in the prompt. Each provider
                translates attachments into its native content block format.
                Providers that do not support attachments raise ``UsageError``.

        Returns:
            Response text, or a JSON string when response_schema is provided.
        """
        ...

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful complete() call
        on the current thread.

        Returns (0, 0) before the first successful call on this thread.
        Unchanged after a failed call — retains the previous successful call's counts.
        Thread-safe: each thread reads only its own counts.
        """
        ...

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable.

        Args:
            temperature: Sampling temperature to use, or None to use the model's default.

        Raises:
            APIError: If the provider returns a 4xx error (bad model, bad key).
        """
        ...


# Single source of truth for all supported providers, built once at import time
# (previously an equivalent dict was rebuilt on every create_llm_provider call,
# and the provider list was duplicated in a parallel if/elif import chain).
# Each entry maps a provider name to:
#   (settings attribute holding the API key,
#    env-var name reported when that key is missing,
#    module path of the client implementation,
#    client class name,
#    whether the client constructor accepts max_tokens)
# Client modules are imported lazily in the factory so unused provider SDKs
# are never loaded.
_PROVIDER_REGISTRY: dict[str, tuple[str, str, str, str, bool]] = {
    "anthropic": ("anthropic_api_key", "ANTHROPIC_API_KEY", "integrations.anthropic", "AnthropicClient", True),
    "openai": ("openai_api_key", "OPENAI_API_KEY", "integrations.openai", "OpenAIClient", False),
    "grok": ("grok_api_key", "XAI_API_KEY", "integrations.grok", "GrokClient", False),
    "gemini": ("gemini_api_key", "GEMINI_API_KEY", "integrations.gemini", "GeminiClient", False),
    "deepseek": ("deepseek_api_key", "DEEPSEEK_API_KEY", "integrations.deepseek", "DeepSeekClient", False),
    "mistral": ("mistral_api_key", "MISTRAL_API_KEY", "integrations.mistral", "MistralClient", True),
}


def create_llm_provider(
    provider: str, model: str, settings: "Settings", *, max_tokens: int = 1024
) -> LLMProvider:
    """Instantiate the correct LLM provider implementation from settings.

    Args:
        provider: One of "anthropic", "openai", "grok", "gemini", "deepseek", "mistral".
        model: Model ID to use for this provider.
        settings: Application settings instance (used for API keys).
        max_tokens: Maximum tokens in the response. Required by Anthropic; passed
            as an optional cap by Mistral. Ignored by other providers.

    Returns:
        An object implementing the LLMProvider protocol.

    Raises:
        ConfigurationError: If the required API key for the selected provider
            is absent from settings.
        ConfigurationError: If the provider name is not recognised.
    """
    if provider not in _PROVIDER_REGISTRY:
        raise ConfigurationError(
            f"llm_provider={provider!r} — valid values: {sorted(_PROVIDER_REGISTRY)}"
        )

    key_attr, key_env, module_path, class_name, takes_max_tokens = _PROVIDER_REGISTRY[provider]

    # getattr with a None default: a Settings object that simply lacks the
    # attribute is treated the same as an empty key — both report the env var.
    api_key: str | None = getattr(settings, key_attr, None)
    if not api_key:
        raise ConfigurationError(key_env)

    # Lazy import keeps heavyweight provider SDKs out of module import time.
    client_cls = getattr(import_module(module_path), class_name)
    if takes_max_tokens:
        return client_cls(api_key=api_key, model=model, max_tokens=max_tokens)
    return client_cls(api_key=api_key, model=model)