# openai_document_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import replace
from typing import Any

from more_itertools import batched
from openai import APIError, AsyncOpenAI, OpenAI
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)


@component
class OpenAIDocumentEmbedder:
    """
    Computes document embeddings using OpenAI models.

    ### Usage example
    <!-- test-ignore -->
    ```python
    from haystack import Document
    from haystack.components.embedders import OpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")
    document_embedder = OpenAIDocumentEmbedder()
    result = document_embedder.run([doc])

    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # noqa: PLR0913 (too-many-arguments)
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: int | None = None,
        api_base_url: str | None = None,
        organization: str | None = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: list[str] | None = None,
        embedding_separator: str = "\n",
        timeout: float | None = None,
        max_retries: int | None = None,
        http_client_kwargs: dict[str, Any] | None = None,
        *,
        raise_on_failure: bool = False,
    ) -> None:
        """
        Creates an OpenAIDocumentEmbedder component.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key:
            The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-ada-002`.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
            later models support this parameter.
        :param api_base_url:
            Overrides the default base URL for all HTTP requests.
        :param organization:
            Your OpenAI organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :param raise_on_failure:
            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
        """
        self.api_key = api_key
        self.model = model
        self.dimensions = dimensions
        self.api_base_url = api_base_url
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.timeout = timeout
        self.max_retries = max_retries
        self.http_client_kwargs = http_client_kwargs
        self.raise_on_failure = raise_on_failure

        # Environment variables only act as fallbacks; explicit arguments win.
        # Note: the *unresolved* None is kept on self so to_dict() round-trips faithfully.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        client_kwargs: dict[str, Any] = {
            "api_key": api_key.resolve_value(),
            "organization": organization,
            "base_url": api_base_url,
            "timeout": timeout,
            "max_retries": max_retries,
        }

        self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
        self.async_client = AsyncOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
        )

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_key=self.api_key,
            model=self.model,
            dimensions=self.dimensions,
            api_base_url=self.api_base_url,
            organization=self.organization,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            timeout=self.timeout,
            max_retries=self.max_retries,
            http_client_kwargs=self.http_client_kwargs,
            raise_on_failure=self.raise_on_failure,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :returns: Mapping of document id -> prefixed/suffixed text to send to the API.
        """
        texts_to_embed = {}
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            texts_to_embed[doc.id] = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )

        return texts_to_embed

    def _build_request_args(self, batch: tuple[tuple[str, str], ...]) -> dict[str, Any]:
        """
        Build the keyword arguments for one embeddings API call over `batch` of (doc_id, text) pairs.

        Requesting `encoding_format="float"` explicitly keeps the sync and async paths consistent
        (previously only the sync path requested it).
        """
        args: dict[str, Any] = {
            "model": self.model,
            "input": [text for _, text in batch],
            "encoding_format": "float",
        }
        if self.dimensions is not None:
            args["dimensions"] = self.dimensions
        return args

    @staticmethod
    def _accumulate_meta(meta: dict[str, Any], response: Any) -> None:
        """
        Merge the model name and token usage from an embeddings `response` into `meta` in place.

        The model name is recorded once; usage counters are summed across batches.
        """
        if "model" not in meta:
            meta["model"] = response.model
        if "usage" not in meta:
            meta["usage"] = dict(response.usage)
        else:
            meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            meta["usage"]["total_tokens"] += response.usage.total_tokens

    @staticmethod
    def _apply_embeddings(
        documents: list[Document], doc_ids_to_embeddings: dict[str, list[float]]
    ) -> list[Document]:
        """
        Return copies of `documents` with embeddings attached where available.

        Documents whose embedding request failed are copied unchanged (embedding left as-is).
        """
        return [
            replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc)
            for doc in documents
        ]

    def _embed_batch(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :returns:
            A tuple of (doc_id -> embedding, meta). On a failed batch, either raises
            (if `raise_on_failure`) or logs and skips that batch's documents.
        """
        doc_ids_to_embeddings: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}
        for batch in tqdm(
            batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            args = self._build_request_args(batch)

            try:
                response = self.client.embeddings.create(**args)
            except APIError as exc:
                ids = ", ".join(b[0] for b in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                if self.raise_on_failure:
                    raise exc
                continue

            embeddings = [el.embedding for el in response.data]
            doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True)))

            self._accumulate_meta(meta, response)

        return doc_ids_to_embeddings, meta

    async def _embed_batch_async(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed a list of texts in batches asynchronously.

        Mirrors `_embed_batch` (same request arguments and error handling) using the async client.
        """
        doc_ids_to_embeddings: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}

        batches = list(batched(texts_to_embed.items(), batch_size))
        if self.progress_bar:
            batches = async_tqdm(batches, desc="Calculating embeddings")

        for batch in batches:
            args = self._build_request_args(batch)

            try:
                response = await self.async_client.embeddings.create(**args)
            except APIError as exc:
                ids = ", ".join(b[0] for b in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                if self.raise_on_failure:
                    raise exc
                continue

            embeddings = [el.embedding for el in response.data]
            doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True)))

            self._accumulate_meta(meta, response)

        return doc_ids_to_embeddings, meta

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    def run(self, documents: list[Document]) -> dict[str, Any]:
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        return {"documents": self._apply_embeddings(documents, doc_ids_to_embeddings), "meta": meta}

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    async def run_async(self, documents: list[Document]) -> dict[str, Any]:
        """
        Embeds a list of documents asynchronously.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        doc_ids_to_embeddings, meta = await self._embed_batch_async(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        return {"documents": self._apply_embeddings(documents, doc_ids_to_embeddings), "meta": meta}