# haystack/components/embedders/hugging_face_api_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
from asyncio import Semaphore, gather
from dataclasses import replace
from itertools import chain
from typing import Any

from tqdm import tqdm

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url
 18  
 19  with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
 20      from huggingface_hub import AsyncInferenceClient, InferenceClient
 21  
 22  logger = logging.getLogger(__name__)
 23  
 24  
 25  @component
 26  class HuggingFaceAPIDocumentEmbedder:
 27      """
 28      Embeds documents using Hugging Face APIs.
 29  
 30      Use it with the following Hugging Face APIs:
 31      - [Free Serverless Inference API](https://huggingface.co/inference-api)
 32      - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
 33      - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
 34  
 35  
 36      ### Usage examples
 37  
 38      #### With free serverless inference API
 39      <!-- test-ignore -->
 40      ```python
 41      from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 42      from haystack.utils import Secret
 43      from haystack.dataclasses import Document
 44  
 45      doc = Document(content="I love pizza!")
 46  
 47      doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
 48                                                    api_params={"model": "BAAI/bge-small-en-v1.5"},
 49                                                    token=Secret.from_token("<your-api-key>"))
 50  
 51      result = document_embedder.run([doc])
 52      print(result["documents"][0].embedding)
 53  
 54      # [0.017020374536514282, -0.023255806416273117, ...]
 55      ```
 56  
 57      #### With paid inference endpoints
 58      <!-- test-ignore -->
 59      ```python
 60      from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 61      from haystack.utils import Secret
 62      from haystack.dataclasses import Document
 63  
 64      doc = Document(content="I love pizza!")
 65  
 66      doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
 67                                                    api_params={"url": "<your-inference-endpoint-url>"},
 68                                                    token=Secret.from_token("<your-api-key>"))
 69  
 70      result = document_embedder.run([doc])
 71      print(result["documents"][0].embedding)
 72  
 73      # [0.017020374536514282, -0.023255806416273117, ...]
 74      ```
 75  
 76      #### With self-hosted text embeddings inference
 77      <!-- test-ignore -->
 78      ```python
 79      from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 80      from haystack.dataclasses import Document
 81  
 82      doc = Document(content="I love pizza!")
 83  
 84      doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
 85                                                    api_params={"url": "http://localhost:8080"})
 86  
 87      result = document_embedder.run([doc])
 88      print(result["documents"][0].embedding)
 89  
 90      # [0.017020374536514282, -0.023255806416273117, ...]
 91      ```
 92      """
 93  
 94      def __init__(
 95          self,
 96          api_type: HFEmbeddingAPIType | str,
 97          api_params: dict[str, str],
 98          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 99          prefix: str = "",
100          suffix: str = "",
101          truncate: bool | None = True,
102          normalize: bool | None = False,
103          batch_size: int = 32,
104          progress_bar: bool = True,
105          meta_fields_to_embed: list[str] | None = None,
106          embedding_separator: str = "\n",
107          concurrency_limit: int = 4,
108      ) -> None:
109          """
110          Creates a HuggingFaceAPIDocumentEmbedder component.
111  
112          :param api_type:
113              The type of Hugging Face API to use.
114          :param api_params:
115              A dictionary with the following keys:
116              - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
117              - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
118              `TEXT_EMBEDDINGS_INFERENCE`.
119          :param token: The Hugging Face token to use as HTTP bearer authorization.
120              Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
121          :param prefix:
122              A string to add at the beginning of each text.
123          :param suffix:
124              A string to add at the end of each text.
125          :param truncate:
126              Truncates the input text to the maximum length supported by the model.
127              Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
128              if the backend uses Text Embeddings Inference.
129              If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
130          :param normalize:
131              Normalizes the embeddings to unit length.
132              Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
133              if the backend uses Text Embeddings Inference.
134              If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
135          :param batch_size:
136              Number of documents to process at once.
137          :param progress_bar:
138              If `True`, shows a progress bar when running.
139          :param meta_fields_to_embed:
140              List of metadata fields to embed along with the document text.
141          :param embedding_separator:
142              Separator used to concatenate the metadata fields to the document text.
143          :param concurrency_limit:
144              The maximum number of requests that should be allowed to run concurrently.
145              This parameter is only used in the `run_async` method.
146          """
147          huggingface_hub_import.check()
148  
149          if isinstance(api_type, str):
150              api_type = HFEmbeddingAPIType.from_str(api_type)
151  
152          api_params = api_params or {}
153  
154          if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
155              model = api_params.get("model")
156              if model is None:
157                  raise ValueError(
158                      "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
159                  )
160              check_valid_model(model, HFModelType.EMBEDDING, token)
161              model_or_url = model
162          elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
163              url = api_params.get("url")
164              if url is None:
165                  msg = (
166                      "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
167                      "parameter in `api_params`."
168                  )
169                  raise ValueError(msg)
170              if not is_valid_http_url(url):
171                  raise ValueError(f"Invalid URL: {url}")
172              model_or_url = url
173          else:
174              msg = f"Unknown api_type {api_type}"
175              raise ValueError(msg)
176  
177          client_args: dict[str, Any] = {"model": model_or_url, "token": token.resolve_value() if token else None}
178  
179          self.api_type = api_type
180          self.api_params = api_params
181          self.token = token
182          self.prefix = prefix
183          self.suffix = suffix
184          self.truncate = truncate
185          self.normalize = normalize
186          self.batch_size = batch_size
187          self.progress_bar = progress_bar
188          self.meta_fields_to_embed = meta_fields_to_embed or []
189          self.embedding_separator = embedding_separator
190          self.concurrency_limit = concurrency_limit
191          self._client = InferenceClient(**client_args)
192          self._async_client = AsyncInferenceClient(**client_args)
193  
194      def to_dict(self) -> dict[str, Any]:
195          """
196          Serializes the component to a dictionary.
197  
198          :returns:
199              Dictionary with serialized data.
200          """
201          return default_to_dict(
202              self,
203              api_type=str(self.api_type),
204              api_params=self.api_params,
205              prefix=self.prefix,
206              suffix=self.suffix,
207              token=self.token,
208              truncate=self.truncate,
209              normalize=self.normalize,
210              batch_size=self.batch_size,
211              progress_bar=self.progress_bar,
212              meta_fields_to_embed=self.meta_fields_to_embed,
213              embedding_separator=self.embedding_separator,
214          )
215  
216      @classmethod
217      def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
218          """
219          Deserializes the component from a dictionary.
220  
221          :param data:
222              Dictionary to deserialize from.
223          :returns:
224              Deserialized component.
225          """
226          return default_from_dict(cls, data)
227  
228      def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
229          """
230          Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
231          """
232          texts_to_embed = []
233          for doc in documents:
234              meta_values_to_embed = [
235                  str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
236              ]
237  
238              text_to_embed = (
239                  self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
240              )
241  
242              texts_to_embed.append(text_to_embed)
243          return texts_to_embed
244  
245      @staticmethod
246      def _adjust_api_parameters(
247          truncate: bool | None, normalize: bool | None, api_type: HFEmbeddingAPIType
248      ) -> tuple[bool | None, bool | None]:
249          """
250          Adjust the truncate and normalize parameters based on the API type.
251          """
252          if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
253              if truncate is not None:
254                  msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
255                  logger.warning(msg)
256                  truncate = None
257              if normalize is not None:
258                  msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
259                  logger.warning(msg)
260                  normalize = None
261          return truncate, normalize
262  
263      def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
264          """
265          Embed a list of texts in batches.
266          """
267          truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)
268  
269          all_embeddings: list = []
270          for i in tqdm(
271              range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
272          ):
273              batch = texts_to_embed[i : i + batch_size]
274  
275              np_embeddings = self._client.feature_extraction(
276                  # this method does not officially support list of strings, but works as expected
277                  text=batch,  # type: ignore[arg-type]
278                  truncate=truncate,
279                  normalize=normalize,
280              )
281  
282              if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
283                  raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
284  
285              all_embeddings.extend(np_embeddings.tolist())
286  
287          return all_embeddings
288  
289      async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
290          """
291          Embed a list of texts in batches asynchronously.
292          """
293          truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)
294          sem = Semaphore(max(1, self.concurrency_limit))
295          num_batches = (len(texts_to_embed) + batch_size - 1) // batch_size
296          pbar = tqdm(total=num_batches, disable=not self.progress_bar, desc="Calculating embeddings")
297  
298          async def _runner(batch: list[str]) -> list[list[float]]:
299              async with sem:
300                  np_embeddings = await self._async_client.feature_extraction(
301                      # this method does not officially support list of strings, but works as expected
302                      text=batch,  # type: ignore[arg-type]
303                      truncate=truncate,
304                      normalize=normalize,
305                  )
306  
307                  if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
308                      raise ValueError(
309                          f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}"
310                      )
311  
312                  pbar.update(1)
313                  return np_embeddings.tolist()
314  
315          try:
316              all_embeddings = [
317                  *chain(
318                      *await gather(
319                          *[
320                              _runner(texts_to_embed[i : i + batch_size])
321                              for i in range(0, len(texts_to_embed), batch_size)
322                          ]
323                      )
324                  )
325              ]
326          finally:
327              pbar.close()
328  
329          return all_embeddings
330  
331      @component.output_types(documents=list[Document])
332      def run(self, documents: list[Document]) -> dict[str, list[Document]]:
333          """
334          Embeds a list of documents.
335  
336          :param documents:
337              Documents to embed.
338  
339          :returns:
340              A dictionary with the following keys:
341              - `documents`: A list of documents with embeddings.
342          """
343          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
344              raise TypeError(
345                  "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
346                  " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
347              )
348  
349          texts_to_embed = self._prepare_texts_to_embed(documents=documents)
350  
351          embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
352  
353          new_documents = []
354          for doc, emb in zip(documents, embeddings, strict=True):
355              new_documents.append(replace(doc, embedding=emb))
356  
357          return {"documents": new_documents}
358  
359      @component.output_types(documents=list[Document])
360      async def run_async(self, documents: list[Document]) -> dict[str, list[Document]]:
361          """
362          Embeds a list of documents asynchronously.
363  
364          :param documents:
365              Documents to embed.
366  
367          :returns:
368              A dictionary with the following keys:
369              - `documents`: A list of documents with embeddings.
370          """
371          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
372              raise TypeError(
373                  "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
374                  " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
375              )
376  
377          texts_to_embed = self._prepare_texts_to_embed(documents=documents)
378  
379          embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
380  
381          new_documents = []
382          for doc, emb in zip(documents, embeddings, strict=True):
383              new_documents.append(replace(doc, embedding=emb))
384  
385          return {"documents": new_documents}