# hugging_face_api_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import:
    from huggingface_hub import AsyncInferenceClient, InferenceClient

logger = logging.getLogger(__name__)


@component
class HuggingFaceAPITextEmbedder:
    """
    Embeds strings using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API
    <!-- test-ignore -->
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With paid inference endpoints
    <!-- test-ignore -->
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret
    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```

    #### With self-hosted text embeddings inference
    <!-- test-ignore -->
    ```python
    from haystack.components.embedders import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    ```
    """

    def __init__(
        self,
        api_type: HFEmbeddingAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool | None = True,
        normalize: bool | None = False,
    ) -> None:
        """
        Creates a HuggingFaceAPITextEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :raises ValueError:
            If a required `api_params` key is missing for the chosen `api_type`, the URL is
            invalid, or `api_type` is unknown.
        """
        huggingface_hub_import.check()

        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                raise ValueError(
                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                )
            check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                raise ValueError(f"Invalid URL: {url}")
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
        self._async_client = AsyncInferenceClient(model_or_url, token=token.resolve_value() if token else None)

    def _prepare_input(self, text: str) -> tuple[str, bool | None, bool | None]:
        """
        Validate the input text and resolve the effective truncate/normalize flags.

        Returns the text with prefix/suffix applied plus the flags to forward to the
        client. For the Serverless Inference API, `truncate` and `normalize` are not
        supported, so they are reset to `None` (with a warning) before the call.
        """
        if not isinstance(text, str):
            raise TypeError(
                "HuggingFaceAPITextEmbedder expects a string as an input."
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
            )

        truncate = self.truncate
        normalize = self.normalize

        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None

        text_to_embed = self.prefix + text + self.suffix

        return text_to_embed, truncate, normalize

    @staticmethod
    def _flatten_embedding(np_embedding: Any) -> list[float]:
        """
        Validate the shape of the raw embedding array and flatten it to a 1D list.

        Accepts `(embedding_dim,)` or `(1, embedding_dim)`; anything else raises.

        :raises ValueError:
            If the array has more than 2 dimensions or a 2D first axis != 1.
        """
        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
        if np_embedding.ndim > 2:
            raise ValueError(error_msg)
        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
            raise ValueError(error_msg)
        return np_embedding.flatten().tolist()

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            # Serialize the Secret to its dict form; embedding the Secret object
            # directly would break JSON round-tripping and could leak the token.
            token=self.token.to_dict() if self.token else None,
            truncate=self.truncate,
            normalize=self.normalize,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # Rebuild the Secret object from its serialized form before __init__ runs.
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        return default_from_dict(cls, data)

    @component.output_types(embedding=list[float])
    def run(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = self._client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._flatten_embedding(np_embedding)}

    @component.output_types(embedding=list[float])
    async def run_async(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string asynchronously.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = await self._async_client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._flatten_embedding(np_embedding)}