# haystack/components/embedders/openai_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import os
  6  from dataclasses import replace
  7  from typing import Any
  8  
  9  from more_itertools import batched
 10  from openai import APIError, AsyncOpenAI, OpenAI
 11  from tqdm import tqdm
 12  from tqdm.asyncio import tqdm as async_tqdm
 13  
 14  from haystack import Document, component, default_from_dict, default_to_dict, logging
 15  from haystack.utils import Secret
 16  from haystack.utils.http_client import init_http_client
 17  
 18  logger = logging.getLogger(__name__)
 19  
 20  
 21  @component
 22  class OpenAIDocumentEmbedder:
 23      """
 24      Computes document embeddings using OpenAI models.
 25  
 26      ### Usage example
 27      <!-- test-ignore -->
 28      ```python
 29      from haystack import Document
 30      from haystack.components.embedders import OpenAIDocumentEmbedder
 31  
 32      doc = Document(content="I love pizza!")
 33      document_embedder = OpenAIDocumentEmbedder()
 34      result = document_embedder.run([doc])
 35  
 36      print(result['documents'][0].embedding)
 37  
 38      # [0.017020374536514282, -0.023255806416273117, ...]
 39      ```
 40      """
 41  
 42      def __init__(  # noqa: PLR0913 (too-many-arguments)
 43          self,
 44          api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
 45          model: str = "text-embedding-ada-002",
 46          dimensions: int | None = None,
 47          api_base_url: str | None = None,
 48          organization: str | None = None,
 49          prefix: str = "",
 50          suffix: str = "",
 51          batch_size: int = 32,
 52          progress_bar: bool = True,
 53          meta_fields_to_embed: list[str] | None = None,
 54          embedding_separator: str = "\n",
 55          timeout: float | None = None,
 56          max_retries: int | None = None,
 57          http_client_kwargs: dict[str, Any] | None = None,
 58          *,
 59          raise_on_failure: bool = False,
 60      ) -> None:
 61          """
 62          Creates an OpenAIDocumentEmbedder component.
 63  
 64          Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
 65          environment variables to override the `timeout` and `max_retries` parameters respectively
 66          in the OpenAI client.
 67  
 68          :param api_key:
 69              The OpenAI API key.
 70              You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
 71              during initialization.
 72          :param model:
 73              The name of the model to use for calculating embeddings.
 74              The default model is `text-embedding-ada-002`.
 75          :param dimensions:
 76              The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
 77              later models support this parameter.
 78          :param api_base_url:
 79              Overrides the default base URL for all HTTP requests.
 80          :param organization:
 81              Your OpenAI organization ID. See OpenAI's
 82              [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
 83              for more information.
 84          :param prefix:
 85              A string to add at the beginning of each text.
 86          :param suffix:
 87              A string to add at the end of each text.
 88          :param batch_size:
 89              Number of documents to embed at once.
 90          :param progress_bar:
 91              If `True`, shows a progress bar when running.
 92          :param meta_fields_to_embed:
 93              List of metadata fields to embed along with the document text.
 94          :param embedding_separator:
 95              Separator used to concatenate the metadata fields to the document text.
 96          :param timeout:
 97              Timeout for OpenAI client calls. If not set, it defaults to either the
 98              `OPENAI_TIMEOUT` environment variable, or 30 seconds.
 99          :param max_retries:
100              Maximum number of retries to contact OpenAI after an internal error.
101              If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
102          :param http_client_kwargs:
103              A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
104              For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
105          :param raise_on_failure:
106              Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
107              and continue processing the remaining documents. If `True`, it will raise an exception on failure.
108          """
109          self.api_key = api_key
110          self.model = model
111          self.dimensions = dimensions
112          self.api_base_url = api_base_url
113          self.organization = organization
114          self.prefix = prefix
115          self.suffix = suffix
116          self.batch_size = batch_size
117          self.progress_bar = progress_bar
118          self.meta_fields_to_embed = meta_fields_to_embed or []
119          self.embedding_separator = embedding_separator
120          self.timeout = timeout
121          self.max_retries = max_retries
122          self.http_client_kwargs = http_client_kwargs
123          self.raise_on_failure = raise_on_failure
124  
125          if timeout is None:
126              timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
127          if max_retries is None:
128              max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
129  
130          client_kwargs: dict[str, Any] = {
131              "api_key": api_key.resolve_value(),
132              "organization": organization,
133              "base_url": api_base_url,
134              "timeout": timeout,
135              "max_retries": max_retries,
136          }
137  
138          self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
139          self.async_client = AsyncOpenAI(
140              http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
141          )
142  
143      def _get_telemetry_data(self) -> dict[str, Any]:
144          """
145          Data that is sent to Posthog for usage analytics.
146          """
147          return {"model": self.model}
148  
149      def to_dict(self) -> dict[str, Any]:
150          """
151          Serializes the component to a dictionary.
152  
153          :returns:
154              Dictionary with serialized data.
155          """
156          return default_to_dict(
157              self,
158              api_key=self.api_key,
159              model=self.model,
160              dimensions=self.dimensions,
161              api_base_url=self.api_base_url,
162              organization=self.organization,
163              prefix=self.prefix,
164              suffix=self.suffix,
165              batch_size=self.batch_size,
166              progress_bar=self.progress_bar,
167              meta_fields_to_embed=self.meta_fields_to_embed,
168              embedding_separator=self.embedding_separator,
169              timeout=self.timeout,
170              max_retries=self.max_retries,
171              http_client_kwargs=self.http_client_kwargs,
172              raise_on_failure=self.raise_on_failure,
173          )
174  
175      @classmethod
176      def from_dict(cls, data: dict[str, Any]) -> "OpenAIDocumentEmbedder":
177          """
178          Deserializes the component from a dictionary.
179  
180          :param data:
181              Dictionary to deserialize from.
182          :returns:
183              Deserialized component.
184          """
185          return default_from_dict(cls, data)
186  
187      def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]:
188          """
189          Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
190          """
191          texts_to_embed = {}
192          for doc in documents:
193              meta_values_to_embed = [
194                  str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
195              ]
196  
197              texts_to_embed[doc.id] = (
198                  self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
199              )
200  
201          return texts_to_embed
202  
203      def _embed_batch(
204          self, texts_to_embed: dict[str, str], batch_size: int
205      ) -> tuple[dict[str, list[float]], dict[str, Any]]:
206          """
207          Embed a list of texts in batches.
208          """
209  
210          doc_ids_to_embeddings: dict[str, list[float]] = {}
211          meta: dict[str, Any] = {}
212          for batch in tqdm(
213              batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
214          ):
215              args: dict[str, Any] = {"model": self.model, "input": [b[1] for b in batch], "encoding_format": "float"}
216  
217              if self.dimensions is not None:
218                  args["dimensions"] = self.dimensions
219  
220              try:
221                  response = self.client.embeddings.create(**args)
222              except APIError as exc:
223                  ids = ", ".join(b[0] for b in batch)
224                  msg = "Failed embedding of documents {ids} caused by {exc}"
225                  logger.exception(msg, ids=ids, exc=exc)
226                  if self.raise_on_failure:
227                      raise exc
228                  continue
229  
230              embeddings = [el.embedding for el in response.data]
231              doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True)))
232  
233              if "model" not in meta:
234                  meta["model"] = response.model
235              if "usage" not in meta:
236                  meta["usage"] = dict(response.usage)
237              else:
238                  meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
239                  meta["usage"]["total_tokens"] += response.usage.total_tokens
240  
241          return doc_ids_to_embeddings, meta
242  
243      async def _embed_batch_async(
244          self, texts_to_embed: dict[str, str], batch_size: int
245      ) -> tuple[dict[str, list[float]], dict[str, Any]]:
246          """
247          Embed a list of texts in batches asynchronously.
248          """
249  
250          doc_ids_to_embeddings: dict[str, list[float]] = {}
251          meta: dict[str, Any] = {}
252  
253          batches = list(batched(texts_to_embed.items(), batch_size))
254          if self.progress_bar:
255              batches = async_tqdm(batches, desc="Calculating embeddings")
256  
257          for batch in batches:
258              args: dict[str, Any] = {"model": self.model, "input": [b[1] for b in batch]}
259  
260              if self.dimensions is not None:
261                  args["dimensions"] = self.dimensions
262  
263              try:
264                  response = await self.async_client.embeddings.create(**args)
265              except APIError as exc:
266                  ids = ", ".join(b[0] for b in batch)
267                  msg = "Failed embedding of documents {ids} caused by {exc}"
268                  logger.exception(msg, ids=ids, exc=exc)
269                  if self.raise_on_failure:
270                      raise exc
271                  continue
272  
273              embeddings = [el.embedding for el in response.data]
274              doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True)))
275  
276              if "model" not in meta:
277                  meta["model"] = response.model
278              if "usage" not in meta:
279                  meta["usage"] = dict(response.usage)
280              else:
281                  meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
282                  meta["usage"]["total_tokens"] += response.usage.total_tokens
283  
284          return doc_ids_to_embeddings, meta
285  
286      @component.output_types(documents=list[Document], meta=dict[str, Any])
287      def run(self, documents: list[Document]) -> dict[str, Any]:
288          """
289          Embeds a list of documents.
290  
291          :param documents:
292              A list of documents to embed.
293  
294          :returns:
295              A dictionary with the following keys:
296              - `documents`: A list of documents with embeddings.
297              - `meta`: Information about the usage of the model.
298          """
299          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
300              raise TypeError(
301                  "OpenAIDocumentEmbedder expects a list of Documents as input."
302                  "In case you want to embed a string, please use the OpenAITextEmbedder."
303              )
304  
305          texts_to_embed = self._prepare_texts_to_embed(documents=documents)
306  
307          doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
308  
309          new_documents = []
310          for doc in documents:
311              if doc.id in doc_ids_to_embeddings:
312                  new_documents.append(replace(doc, embedding=doc_ids_to_embeddings[doc.id]))
313              else:
314                  new_documents.append(replace(doc))
315  
316          return {"documents": new_documents, "meta": meta}
317  
318      @component.output_types(documents=list[Document], meta=dict[str, Any])
319      async def run_async(self, documents: list[Document]) -> dict[str, Any]:
320          """
321          Embeds a list of documents asynchronously.
322  
323          :param documents:
324              A list of documents to embed.
325  
326          :returns:
327              A dictionary with the following keys:
328              - `documents`: A list of documents with embeddings.
329              - `meta`: Information about the usage of the model.
330          """
331          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
332              raise TypeError(
333                  "OpenAIDocumentEmbedder expects a list of Documents as input. "
334                  "In case you want to embed a string, please use the OpenAITextEmbedder."
335              )
336  
337          texts_to_embed = self._prepare_texts_to_embed(documents=documents)
338  
339          doc_ids_to_embeddings, meta = await self._embed_batch_async(
340              texts_to_embed=texts_to_embed, batch_size=self.batch_size
341          )
342  
343          new_documents = []
344          for doc in documents:
345              if doc.id in doc_ids_to_embeddings:
346                  new_documents.append(replace(doc, embedding=doc_ids_to_embeddings[doc.id]))
347              else:
348                  new_documents.append(replace(doc))
349  
350          return {"documents": new_documents, "meta": meta}