openai.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat.openai import (
    _check_finish_reason,
    _convert_chat_completion_chunk_to_streaming_chunk,
    _convert_chat_completion_to_chat_message,
)
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import (
    ChatMessage,
    ComponentInfo,
    StreamingCallbackT,
    StreamingChunk,
    select_streaming_callback,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)


@component
class OpenAIGenerator:
    """
    Generates text using OpenAI's large language models (LLMs).

    It works with the gpt-4 and gpt-5 series models and supports streaming responses
    from the OpenAI API. It uses strings as input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    `openai.chat.completions.create` will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators import OpenAIGenerator
    client = OpenAIGenerator()
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)

    # >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    # >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    # >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model':
    # >> 'gpt-5-mini', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    # >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```
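
    One way to stream tokens as they arrive is to pass a streaming callback, for example the
    `print_streaming_chunk` utility shipped with Haystack:

    ```python
    from haystack.components.generators import OpenAIGenerator
    from haystack.components.generators.utils import print_streaming_chunk

    client = OpenAIGenerator(streaming_callback=print_streaming_chunk)
    response = client.run("What's Natural Language Processing? Be brief.")
    ```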
    """

    def __init__(
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "gpt-5-mini",
        streaming_callback: StreamingCallbackT | None = None,
        api_base_url: str | None = None,
        organization: str | None = None,
        system_prompt: str | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        timeout: float | None = None,
        max_retries: int | None = None,
        http_client_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Creates an instance of OpenAIGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-5-mini.

        By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the timeout
        and max_retries parameters in the OpenAI client.

        :param api_key: The OpenAI API key to connect to OpenAI.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url: An optional base URL.
        :param organization: The Organization ID, defaults to `None`.
        :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is
            omitted, and the default system prompt of the model is used.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_completion_tokens`: An upper bound for the number of tokens that can be generated for a completion,
                including visible output tokens and reasoning tokens.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. For example, 0.1 means only the
                tokens comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty to apply if a token is already present in the text. Higher values make
                the model less likely to repeat the same token.
            - `frequency_penalty`: The penalty to apply if a token has already been generated in the text. Higher
                values make the model less likely to repeat the same token.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are token IDs, and the
                values are the bias to add to those tokens.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable
            or defaults to 30.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error. If not set, it is inferred from the
            `OPENAI_MAX_RETRIES` environment variable or defaults to 5.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
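
        A minimal configuration sketch; the `timeout` and `generation_kwargs` values below are illustrative:

        ```python
        generator = OpenAIGenerator(
            model="gpt-5-mini",
            generation_kwargs={"max_completion_tokens": 256},
            timeout=60.0,
        )
        ```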
        """
        self.api_key = api_key
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback

        self.api_base_url = api_base_url
        self.organization = organization
        self.http_client_kwargs = http_client_kwargs

        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        self.client = OpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
            timeout=timeout,
            max_retries=max_retries,
            http_client=init_http_client(self.http_client_kwargs, async_client=False),
        )

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
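
        A sketch of the serialization round trip, assuming `generator` is an `OpenAIGenerator` instance:

        ```python
        data = generator.to_dict()
        restored = OpenAIGenerator.from_dict(data)
        ```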
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            organization=self.organization,
            generation_kwargs=self.generation_kwargs,
            system_prompt=self.system_prompt,
            api_key=self.api_key.to_dict(),
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OpenAIGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, keys=["api_key"])
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[str], meta=list[dict[str, Any]])
    def run(
        self,
        prompt: str,
        system_prompt: str | None = None,
        streaming_callback: StreamingCallbackT | None = None,
        generation_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, list[str] | list[dict[str, Any]]]:
        """
        Invoke the text generation inference based on the provided prompt and generation parameters.

        :param prompt:
            The string prompt to use for text generation.
        :param system_prompt:
            The system prompt to use for text generation. If this runtime system prompt is omitted, the system
            prompt defined at initialization time, if any, is used.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation. These parameters potentially override the parameters
            passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to
            the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create).
        :returns:
            A dictionary with the following keys:
            - `replies`: A list of strings containing the generated responses.
            - `meta`: A list of dictionaries containing the metadata for each response.
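
        A sketch of a runtime override; `client` is the instance from the class usage example:

        ```python
        result = client.run(
            "Summarize NLP in one sentence.",
            system_prompt="You answer concisely.",
        )
        print(result["replies"][0])
        ```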
        """
        message = ChatMessage.from_user(prompt)
        if system_prompt is not None:
            messages = [ChatMessage.from_system(system_prompt), message]
        elif self.system_prompt:
            messages = [ChatMessage.from_system(self.system_prompt), message]
        else:
            messages = [message]

        # update generation kwargs by merging with the generation kwargs passed to the run method
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # pick the runtime streaming callback if provided, otherwise fall back to the init callback
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        # adapt ChatMessage(s) to the format expected by the OpenAI API
        openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

        if streaming_callback is not None:
            # streaming supports only a single response: validate and drop `n` before sending the request
            num_responses = generation_kwargs.pop("n", 1)
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")

        completion: Stream[ChatCompletionChunk] | ChatCompletion = self.client.chat.completions.create(
            model=self.model,
            messages=openai_formatted_messages,  # type: ignore
            stream=streaming_callback is not None,
            **generation_kwargs,
        )

        completions: list[ChatMessage] = []
        if streaming_callback is not None:
            component_info = ComponentInfo.from_component(self)
            chunks: list[StreamingChunk] = []
            for chunk in completion:
                chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
                    chunk=chunk,  # type: ignore
                    previous_chunks=chunks,
                    component_info=component_info,
                )
                chunks.append(chunk_delta)
                streaming_callback(chunk_delta)

            completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
        elif isinstance(completion, ChatCompletion):
            completions = [
                _convert_chat_completion_to_chat_message(completion=completion, choice=choice)
                for choice in completion.choices
            ]

        # before returning, do post-processing of the completions
        for response in completions:
            _check_finish_reason(response.meta)

        return {
            "replies": [message.text or "" for message in completions],
            "meta": [message.meta for message in completions],
        }