# haystack/components/embedders/hugging_face_api_text_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
from typing import Any

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import AsyncInferenceClient, InferenceClient
 15  
 16  logger = logging.getLogger(__name__)
 17  
 18  
 19  @component
 20  class HuggingFaceAPITextEmbedder:
 21      """
 22      Embeds strings using Hugging Face APIs.
 23  
 24      Use it with the following Hugging Face APIs:
 25      - [Free Serverless Inference API](https://huggingface.co/inference-api)
 26      - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
 27      - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
 28  
 29      ### Usage examples
 30  
 31      #### With free serverless inference API
 32      <!-- test-ignore -->
 33      ```python
 34      from haystack.components.embedders import HuggingFaceAPITextEmbedder
 35      from haystack.utils import Secret
 36  
 37      text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
 38                                                 api_params={"model": "BAAI/bge-small-en-v1.5"},
 39                                                 token=Secret.from_token("<your-api-key>"))
 40  
 41      print(text_embedder.run("I love pizza!"))
 42  
 43      # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
 44      ```
 45  
 46      #### With paid inference endpoints
 47      <!-- test-ignore -->
 48      ```python
 49      from haystack.components.embedders import HuggingFaceAPITextEmbedder
 50      from haystack.utils import Secret
 51      text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
 52                                                 api_params={"model": "BAAI/bge-small-en-v1.5"},
 53                                                 token=Secret.from_token("<your-api-key>"))
 54  
 55      print(text_embedder.run("I love pizza!"))
 56  
 57      # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
 58      ```
 59  
 60      #### With self-hosted text embeddings inference
 61      <!-- test-ignore -->
 62      ```python
 63      from haystack.components.embedders import HuggingFaceAPITextEmbedder
 64      from haystack.utils import Secret
 65  
 66      text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
 67                                                 api_params={"url": "http://localhost:8080"})
 68  
 69      print(text_embedder.run("I love pizza!"))
 70  
 71      # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
 72      ```
 73      """
 74  
 75      def __init__(
 76          self,
 77          api_type: HFEmbeddingAPIType | str,
 78          api_params: dict[str, str],
 79          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 80          prefix: str = "",
 81          suffix: str = "",
 82          truncate: bool | None = True,
 83          normalize: bool | None = False,
 84      ) -> None:
 85          """
 86          Creates a HuggingFaceAPITextEmbedder component.
 87  
 88          :param api_type:
 89              The type of Hugging Face API to use.
 90          :param api_params:
 91              A dictionary with the following keys:
 92              - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
 93              - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
 94              `TEXT_EMBEDDINGS_INFERENCE`.
 95          :param token: The Hugging Face token to use as HTTP bearer authorization.
 96              Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
 97          :param prefix:
 98              A string to add at the beginning of each text.
 99          :param suffix:
100              A string to add at the end of each text.
101          :param truncate:
102              Truncates the input text to the maximum length supported by the model.
103              Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
104              if the backend uses Text Embeddings Inference.
105              If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
106          :param normalize:
107              Normalizes the embeddings to unit length.
108              Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
109              if the backend uses Text Embeddings Inference.
110              If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
111          """
112          huggingface_hub_import.check()
113  
114          if isinstance(api_type, str):
115              api_type = HFEmbeddingAPIType.from_str(api_type)
116  
117          if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
118              model = api_params.get("model")
119              if model is None:
120                  raise ValueError(
121                      "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
122                  )
123              check_valid_model(model, HFModelType.EMBEDDING, token)
124              model_or_url = model
125          elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
126              url = api_params.get("url")
127              if url is None:
128                  msg = (
129                      "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
130                      "parameter in `api_params`."
131                  )
132                  raise ValueError(msg)
133              if not is_valid_http_url(url):
134                  raise ValueError(f"Invalid URL: {url}")
135              model_or_url = url
136          else:
137              msg = f"Unknown api_type {api_type}"
138              raise ValueError(msg)
139  
140          self.api_type = api_type
141          self.api_params = api_params
142          self.token = token
143          self.prefix = prefix
144          self.suffix = suffix
145          self.truncate = truncate
146          self.normalize = normalize
147          self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
148          self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)
149  
150      def _prepare_input(self, text: str) -> tuple[str, bool | None, bool | None]:
151          if not isinstance(text, str):
152              raise TypeError(
153                  "HuggingFaceAPITextEmbedder expects a string as an input."
154                  "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
155              )
156  
157          truncate = self.truncate
158          normalize = self.normalize
159  
160          if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
161              if truncate is not None:
162                  msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
163                  logger.warning(msg)
164                  truncate = None
165              if normalize is not None:
166                  msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
167                  logger.warning(msg)
168                  normalize = None
169  
170          text_to_embed = self.prefix + text + self.suffix
171  
172          return text_to_embed, truncate, normalize
173  
174      def to_dict(self) -> dict[str, Any]:
175          """
176          Serializes the component to a dictionary.
177  
178          :returns:
179              Dictionary with serialized data.
180          """
181          return default_to_dict(
182              self,
183              api_type=str(self.api_type),
184              api_params=self.api_params,
185              prefix=self.prefix,
186              suffix=self.suffix,
187              token=self.token,
188              truncate=self.truncate,
189              normalize=self.normalize,
190          )
191  
192      @classmethod
193      def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
194          """
195          Deserializes the component from a dictionary.
196  
197          :param data:
198              Dictionary to deserialize from.
199          :returns:
200              Deserialized component.
201          """
202          return default_from_dict(cls, data)
203  
204      @component.output_types(embedding=list[float])
205      def run(self, text: str) -> dict[str, Any]:
206          """
207          Embeds a single string.
208  
209          :param text:
210              Text to embed.
211  
212          :returns:
213              A dictionary with the following keys:
214              - `embedding`: The embedding of the input text.
215          """
216          text_to_embed, truncate_val, normalize_val = self._prepare_input(text)
217  
218          np_embedding = self._client.feature_extraction(
219              text=text_to_embed, truncate=truncate_val, normalize=normalize_val
220          )
221  
222          error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
223          if np_embedding.ndim > 2:
224              raise ValueError(error_msg)
225          if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
226              raise ValueError(error_msg)
227  
228          embedding = np_embedding.flatten().tolist()
229  
230          return {"embedding": embedding}
231  
232      @component.output_types(embedding=list[float])
233      async def run_async(self, text: str) -> dict[str, Any]:
234          """
235          Embeds a single string asynchronously.
236  
237          :param text:
238              Text to embed.
239  
240          :returns:
241              A dictionary with the following keys:
242              - `embedding`: The embedding of the input text.
243          """
244          text_to_embed, truncate_val, normalize_val = self._prepare_input(text)
245  
246          np_embedding = await self._async_client.feature_extraction(
247              text=text_to_embed, truncate=truncate_val, normalize=normalize_val
248          )
249  
250          error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
251          if np_embedding.ndim > 2:
252              raise ValueError(error_msg)
253          if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
254              raise ValueError(error_msg)
255  
256          embedding = np_embedding.flatten().tolist()
257  
258          return {"embedding": embedding}