openai.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 import os 6 from typing import Any 7 8 from openai import OpenAI, Stream 9 from openai.types.chat import ChatCompletion, ChatCompletionChunk 10 11 from haystack import component, default_from_dict, default_to_dict, logging 12 from haystack.components.generators.chat.openai import ( 13 _check_finish_reason, 14 _convert_chat_completion_chunk_to_streaming_chunk, 15 _convert_chat_completion_to_chat_message, 16 ) 17 from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message 18 from haystack.dataclasses import ( 19 ChatMessage, 20 ComponentInfo, 21 StreamingCallbackT, 22 StreamingChunk, 23 select_streaming_callback, 24 ) 25 from haystack.utils import Secret, deserialize_callable, serialize_callable 26 from haystack.utils.http_client import init_http_client 27 28 logger = logging.getLogger(__name__) 29 30 31 @component 32 class OpenAIGenerator: 33 """ 34 Generates text using OpenAI's large language models (LLMs). 35 36 It works with the gpt-4 and gpt-5 series models and supports streaming responses 37 from OpenAI API. It uses strings as input and output. 38 39 You can customize how the text is generated by passing parameters to the 40 OpenAI API. Use the `**generation_kwargs` argument when you initialize 41 the component or when you run it. Any parameter that works with 42 `openai.ChatCompletion.create` will work here too. 43 44 45 For details on OpenAI API parameters, see 46 [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat). 47 48 ### Usage example 49 50 ```python 51 from haystack.components.generators import OpenAIGenerator 52 client = OpenAIGenerator() 53 response = client.run("What's Natural Language Processing? Be brief.") 54 print(response) 55 56 # >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on 57 # >> the interaction between computers and human language. It involves enabling computers to understand, interpret, 58 # >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model': 59 # >> 'gpt-5-mini', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16, 60 # >> 'completion_tokens': 49, 'total_tokens': 65}}]} 61 ``` 62 """ 63 64 def __init__( 65 self, 66 api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), 67 model: str = "gpt-5-mini", 68 streaming_callback: StreamingCallbackT | None = None, 69 api_base_url: str | None = None, 70 organization: str | None = None, 71 system_prompt: str | None = None, 72 generation_kwargs: dict[str, Any] | None = None, 73 timeout: float | None = None, 74 max_retries: int | None = None, 75 http_client_kwargs: dict[str, Any] | None = None, 76 ) -> None: 77 """ 78 Creates an instance of OpenAIGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-5-mini 79 80 By setting the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES' you can change the timeout and max_retries parameters 81 in the OpenAI client. 82 83 :param api_key: The OpenAI API key to connect to OpenAI. 84 :param model: The name of the model to use. 85 :param streaming_callback: A callback function that is called when a new token is received from the stream. 86 The callback function accepts StreamingChunk as an argument. 87 :param api_base_url: An optional base URL. 88 :param organization: The Organization ID, defaults to `None`. 89 :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is 90 omitted, and the default system prompt of the model is used. 91 :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to 92 the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for 93 more details. 94 Some of the supported parameters: 95 - `max_completion_tokens`: An upper bound for the number of tokens that can be generated for a completion, 96 including visible output tokens and reasoning tokens. 97 - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. 98 Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer. 99 - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model 100 considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens 101 comprising the top 10% probability mass are considered. 102 - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, 103 it will generate two completions for each of the three prompts, ending up with 6 completions in total. 104 - `stop`: One or more sequences after which the LLM should stop generating tokens. 105 - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean 106 the model will be less likely to repeat the same token in the text. 107 - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. 108 Bigger values mean the model will be less likely to repeat the same token in the text. 109 - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the 110 values are the bias to add to that token. 111 :param timeout: 112 Timeout for OpenAI Client calls, if not set it is inferred from the `OPENAI_TIMEOUT` environment variable 113 or set to 30. 114 :param max_retries: 115 Maximum retries to establish contact with OpenAI if it returns an internal error, if not set it is inferred 116 from the `OPENAI_MAX_RETRIES` environment variable or set to 5. 117 :param http_client_kwargs: 118 A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`. 119 For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client). 120 """ 121 self.api_key = api_key 122 self.model = model 123 self.generation_kwargs = generation_kwargs or {} 124 self.system_prompt = system_prompt 125 self.streaming_callback = streaming_callback 126 127 self.api_base_url = api_base_url 128 self.organization = organization 129 self.http_client_kwargs = http_client_kwargs 130 131 if timeout is None: 132 timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0")) 133 if max_retries is None: 134 max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5")) 135 136 self.client = OpenAI( 137 api_key=api_key.resolve_value(), 138 organization=organization, 139 base_url=api_base_url, 140 timeout=timeout, 141 max_retries=max_retries, 142 http_client=init_http_client(self.http_client_kwargs, async_client=False), 143 ) 144 145 def _get_telemetry_data(self) -> dict[str, Any]: 146 """ 147 Data that is sent to Posthog for usage analytics. 148 """ 149 return {"model": self.model} 150 151 def to_dict(self) -> dict[str, Any]: 152 """ 153 Serialize this component to a dictionary. 154 155 :returns: 156 The serialized component as a dictionary. 157 """ 158 callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None 159 return default_to_dict( 160 self, 161 model=self.model, 162 streaming_callback=callback_name, 163 api_base_url=self.api_base_url, 164 organization=self.organization, 165 generation_kwargs=self.generation_kwargs, 166 system_prompt=self.system_prompt, 167 api_key=self.api_key, 168 http_client_kwargs=self.http_client_kwargs, 169 ) 170 171 @classmethod 172 def from_dict(cls, data: dict[str, Any]) -> "OpenAIGenerator": 173 """ 174 Deserialize this component from a dictionary. 175 176 :param data: 177 The dictionary representation of this component. 178 :returns: 179 The deserialized component instance. 180 """ 181 init_params = data.get("init_parameters", {}) 182 serialized_callback_handler = init_params.get("streaming_callback") 183 if serialized_callback_handler: 184 data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) 185 return default_from_dict(cls, data) 186 187 @component.output_types(replies=list[str], meta=list[dict[str, Any]]) 188 def run( 189 self, 190 prompt: str, 191 system_prompt: str | None = None, 192 streaming_callback: StreamingCallbackT | None = None, 193 generation_kwargs: dict[str, Any] | None = None, 194 ) -> dict[str, list[str] | list[dict[str, Any]]]: 195 """ 196 Invoke the text generation inference based on the provided messages and generation parameters. 197 198 :param prompt: 199 The string prompt to use for text generation. 200 :param system_prompt: 201 The system prompt to use for text generation. If this run time system prompt is omitted, the system 202 prompt, if defined at initialisation time, is used. 203 :param streaming_callback: 204 A callback function that is called when a new token is received from the stream. 205 :param generation_kwargs: 206 Additional keyword arguments for text generation. These parameters will potentially override the parameters 207 passed in the `__init__` method. For more details on the parameters supported by the OpenAI API, refer to 208 the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat/create). 209 :returns: 210 A list of strings containing the generated responses and a list of dictionaries containing the metadata 211 for each response. 212 """ 213 message = ChatMessage.from_user(prompt) 214 if system_prompt is not None: 215 messages = [ChatMessage.from_system(system_prompt), message] 216 elif self.system_prompt: 217 messages = [ChatMessage.from_system(self.system_prompt), message] 218 else: 219 messages = [message] 220 221 # update generation kwargs by merging with the generation kwargs passed to the run method 222 generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} 223 224 # check if streaming_callback is passed 225 streaming_callback = select_streaming_callback( 226 init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False 227 ) 228 229 # adapt ChatMessage(s) to the format expected by the OpenAI API 230 openai_formatted_messages = [message.to_openai_dict_format() for message in messages] 231 232 completion: Stream[ChatCompletionChunk] | ChatCompletion = self.client.chat.completions.create( 233 model=self.model, 234 messages=openai_formatted_messages, # type: ignore 235 stream=streaming_callback is not None, 236 **generation_kwargs, 237 ) 238 239 completions: list[ChatMessage] = [] 240 if streaming_callback is not None: 241 num_responses = generation_kwargs.pop("n", 1) 242 if num_responses > 1: 243 raise ValueError("Cannot stream multiple responses, please set n=1.") 244 245 component_info = ComponentInfo.from_component(self) 246 chunks: list[StreamingChunk] = [] 247 for chunk in completion: 248 chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk( 249 chunk=chunk, # type: ignore 250 previous_chunks=chunks, 251 component_info=component_info, 252 ) 253 chunks.append(chunk_delta) 254 streaming_callback(chunk_delta) 255 256 completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)] 257 elif isinstance(completion, ChatCompletion): 258 completions = [ 259 _convert_chat_completion_to_chat_message(completion=completion, choice=choice) 260 for choice in completion.choices 261 ] 262 263 # before returning, do post-processing of the completions 264 for response in completions: 265 _check_finish_reason(response.meta) 266 267 return { 268 "replies": [message.text or "" for message in completions], 269 "meta": [message.meta for message in completions], 270 }