# integrations/openai.py
"""OpenAI provider implementation of the LLMProvider protocol."""

__all__ = ["OpenAIClient"]

import logging
import mimetypes
import threading
from base64 import b64encode
from collections.abc import Sequence

import openai
from openai import OpenAI
from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

from exceptions import APIError
from models.llm import Attachment

logger = logging.getLogger(__name__)


def _is_transient(exc: BaseException) -> bool:
    """Return True for errors worth retrying: 429, 5xx, timeouts, connection drops."""
    if isinstance(exc, openai.APIStatusError):
        return exc.status_code == 429 or exc.status_code >= 500
    return isinstance(exc, (openai.APITimeoutError, openai.APIConnectionError))


class OpenAIClient:
    """LLMProvider backed by the OpenAI SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        self._client = OpenAI(api_key=api_key)
        self._model = model
        # Usage counters live in thread-local storage so concurrent callers
        # sharing one client don't clobber each other's last_usage readings.
        self._tls = threading.local()

    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt to OpenAI and return the response.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature, or None to use the model's default.
            seed: Passed to the API when provided.
            response_schema: If provided, uses beta structured output parsing.
            attachments: Binary files sent as inline file content blocks.

        Returns:
            Response text or JSON string.

        Raises:
            APIError: On non-retriable HTTP errors (4xx), or when structured
                output parsing yields no result (e.g. a model refusal).
        """
        logger.debug("Calling openai model=%s", self._model)

        content: str | list[dict[str, object]]
        if attachments:
            blocks: list[dict[str, object]] = [
                {
                    "type": "file",
                    "file": {
                        # A filename is required by the API for format detection;
                        # the MIME type in file_data alone is not sufficient.
                        # FIX: was hard-coded to "document.pdf" for every
                        # attachment, mislabeling non-PDF files. Derive the
                        # extension from the media type; keep the old ".pdf"
                        # fallback for unguessable types.
                        "filename": f"document{mimetypes.guess_extension(a.media_type) or '.pdf'}",
                        "file_data": f"data:{a.media_type};base64,{b64encode(a.data).decode()}",
                    },
                }
                for a in attachments
            ]
            blocks.append({"type": "text", "text": user})
            content = blocks
        else:
            content = user

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": content},
        ]

        # Shared request options for both API paths; temperature/seed are
        # omitted entirely when unset so the API applies its own defaults.
        kwargs: dict[str, object] = {"model": self._model, "messages": messages}
        if temperature is not None:
            kwargs["temperature"] = temperature
        if seed is not None:
            kwargs["seed"] = seed

        try:
            if response_schema is not None:
                kwargs["response_format"] = response_schema
                completion = self._client.beta.chat.completions.parse(**kwargs)  # type: ignore[arg-type]
                message = completion.choices[0].message
                parsed = message.parsed
                if parsed is None:
                    # FIX: the SDK leaves .parsed as None when the model
                    # refuses or the output fails schema validation; the old
                    # code crashed with AttributeError here. Surface it as an
                    # APIError instead (status 200: HTTP succeeded but the
                    # content is unusable).
                    refusal = getattr(message, "refusal", None) or "no parsed output"
                    raise APIError("openai", 200, f"structured output unavailable: {refusal}")
                result = parsed.model_dump_json()
            else:
                completion = self._client.chat.completions.create(**kwargs)  # type: ignore[call-overload]
                result = completion.choices[0].message.content or ""
        except openai.APIStatusError as exc:
            if exc.status_code == 429 or exc.status_code >= 500:
                # Transient per _is_transient; let tenacity's retry handle it.
                raise
            raise APIError("openai", exc.status_code, str(exc)) from exc
        # Timeouts and connection errors propagate unchanged so tenacity
        # retries them (they match _is_transient).

        self._tls.input_tokens = completion.usage.prompt_tokens if completion.usage else 0
        self._tls.output_tokens = (
            completion.usage.completion_tokens if completion.usage else 0
        )

        logger.info("Response received (%d chars)", len(result))
        return result

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )