  1  """OpenAI provider implementation of the LLMProvider protocol."""
  2  
  3  __all__ = ["OpenAIClient"]
  4  
  5  import logging
  6  import threading
  7  from base64 import b64encode
  8  from collections.abc import Sequence
  9  
 10  import openai
 11  from openai import OpenAI
 12  from pydantic import BaseModel
 13  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
 14  
 15  from exceptions import APIError
 16  from models.llm import Attachment
 17  
 18  logger = logging.getLogger(__name__)


def _is_transient(exc: BaseException) -> bool:
    """Return True for failures worth retrying: rate limits (429), 5xx, timeouts, connection errors."""
    if isinstance(exc, openai.APIStatusError):
        return exc.status_code == 429 or exc.status_code >= 500
    return isinstance(exc, (openai.APITimeoutError, openai.APIConnectionError))


class OpenAIClient:
    """LLMProvider backed by the OpenAI SDK."""

    def __init__(self, api_key: str, model: str) -> None:
        self._client = OpenAI(api_key=api_key)
        self._model = model
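        # Per-thread storage for usage counters, so concurrent callers sharing
        # this client don't clobber each other's token counts.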
        self._tls = threading.local()

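    # Retry transient failures (as classified by _is_transient) with exponential
    # backoff between 2 s and 60 s, up to 5 attempts, then re-raise the last error.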
    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def complete(
        self,
        system: str,
        user: str,
        *,
        temperature: float | None = None,
        seed: int | None = None,
        response_schema: type[BaseModel] | None = None,
        attachments: Sequence[Attachment] | None = None,
    ) -> str:
        """Send a prompt to OpenAI and return the response.

        Args:
            system: System prompt text.
            user: User message text.
            temperature: Sampling temperature, or None to use the model's default.
            seed: Passed to the API when provided.
            response_schema: If provided, uses beta structured output parsing.
            attachments: Binary files sent as inline file content blocks.

        Returns:
            Response text or JSON string.

        Raises:
            APIError: On non-retriable HTTP errors (4xx), or when structured
                output parsing yields no result.
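
        Example (illustrative only; ``Summary`` is a hypothetical schema and
        ``"gpt-4o-mini"`` an arbitrary model name, neither defined here)::

            class Summary(BaseModel):
                title: str
                bullets: list[str]

            client = OpenAIClient(api_key="sk-...", model="gpt-4o-mini")
            payload = client.complete(
                system="Summarize the user's text.",
                user="...long text...",
                response_schema=Summary,
            )  # payload is a JSON string matching Summary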
        """
        logger.debug("Calling openai model=%s", self._model)

        content: str | list[dict[str, object]]
        if attachments:
            blocks: list[dict[str, object]] = [
                {
                    "type": "file",
                    "file": {
                        # filename is required by the API for format detection;
                        # the MIME type in file_data alone is not sufficient.
                        "filename": "document.pdf",
                        "file_data": f"data:{a.media_type};base64,{b64encode(a.data).decode()}",
                    },
                }
                for a in attachments
            ]
            blocks.append({"type": "text", "text": user})
            content = blocks
        else:
            content = user

        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": content},
        ]

        try:
            if response_schema is not None:
                kwargs: dict[str, object] = {
                    "model": self._model,
                    "messages": messages,
                    "response_format": response_schema,
                }
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if seed is not None:
                    kwargs["seed"] = seed
                completion = self._client.beta.chat.completions.parse(**kwargs)  # type: ignore[arg-type]
                parsed = completion.choices[0].message.parsed
                if parsed is None:
                    # Parsing can come back empty (e.g. the model refused); fail
                    # loudly instead of crashing on .model_dump_json() below.
                    raise APIError("openai", 0, "structured output returned no parsed content")
                result = parsed.model_dump_json()
            else:
                kwargs = {
                    "model": self._model,
                    "messages": messages,
                }
                if temperature is not None:
                    kwargs["temperature"] = temperature
                if seed is not None:
                    kwargs["seed"] = seed
                completion = self._client.chat.completions.create(**kwargs)  # type: ignore[call-overload]
                result = completion.choices[0].message.content or ""
        except openai.APIStatusError as exc:
            if exc.status_code == 429 or exc.status_code >= 500:
                raise  # transient; let tenacity retry
            raise APIError("openai", exc.status_code, str(exc)) from exc
        except (openai.APITimeoutError, openai.APIConnectionError):
            raise  # transient; let tenacity retry

        self._tls.input_tokens = completion.usage.prompt_tokens if completion.usage else 0
        self._tls.output_tokens = (
            completion.usage.completion_tokens if completion.usage else 0
        )

        logger.info("Response received (%d chars)", len(result))
        return result

    @property
    def last_usage(self) -> tuple[int, int]:
        """(input_tokens, output_tokens) from the most recent successful call."""
        return (
            getattr(self._tls, "input_tokens", 0),
            getattr(self._tls, "output_tokens", 0),
        )

    def ping(self, temperature: float | None = None) -> None:
        """Send a minimal request to verify the provider and model are reachable."""
        self.complete(
            system="You are a helpful assistant.", user="Say: OK", temperature=temperature
        )
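

# Minimal usage sketch (assumptions: OPENAI_API_KEY is set in the environment and
# "gpt-4o-mini" names a model your account can access; adjust both as needed).
if __name__ == "__main__":
    import os

    client = OpenAIClient(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4o-mini")
    client.ping()  # raises if the provider or model is unreachable
    print(client.complete(system="You are terse.", user="Say hello."))
    print("usage (in, out):", client.last_usage)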