hugging_face_api_document_embedder.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 from asyncio import Semaphore, gather 6 from dataclasses import replace 7 from itertools import chain 8 from typing import Any 9 10 from tqdm import tqdm 11 12 from haystack import component, default_from_dict, default_to_dict, logging 13 from haystack.dataclasses import Document 14 from haystack.lazy_imports import LazyImport 15 from haystack.utils import Secret 16 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model 17 from haystack.utils.url_validation import is_valid_http_url 18 19 with LazyImport(message="Run 'pip install \"huggingface_hub>=0.27.0\"'") as huggingface_hub_import: 20 from huggingface_hub import AsyncInferenceClient, InferenceClient 21 22 logger = logging.getLogger(__name__) 23 24 25 @component 26 class HuggingFaceAPIDocumentEmbedder: 27 """ 28 Embeds documents using Hugging Face APIs. 29 30 Use it with the following Hugging Face APIs: 31 - [Free Serverless Inference API](https://huggingface.co/inference-api) 32 - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints) 33 - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) 34 35 36 ### Usage examples 37 38 #### With free serverless inference API 39 <!-- test-ignore --> 40 ```python 41 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder 42 from haystack.utils import Secret 43 from haystack.dataclasses import Document 44 45 doc = Document(content="I love pizza!") 46 47 doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api", 48 api_params={"model": "BAAI/bge-small-en-v1.5"}, 49 token=Secret.from_token("<your-api-key>")) 50 51 result = document_embedder.run([doc]) 52 print(result["documents"][0].embedding) 53 54 # [0.017020374536514282, -0.023255806416273117, ...] 55 ``` 56 57 #### With paid inference endpoints 58 <!-- test-ignore --> 59 ```python 60 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder 61 from haystack.utils import Secret 62 from haystack.dataclasses import Document 63 64 doc = Document(content="I love pizza!") 65 66 doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints", 67 api_params={"url": "<your-inference-endpoint-url>"}, 68 token=Secret.from_token("<your-api-key>")) 69 70 result = document_embedder.run([doc]) 71 print(result["documents"][0].embedding) 72 73 # [0.017020374536514282, -0.023255806416273117, ...] 74 ``` 75 76 #### With self-hosted text embeddings inference 77 <!-- test-ignore --> 78 ```python 79 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder 80 from haystack.dataclasses import Document 81 82 doc = Document(content="I love pizza!") 83 84 doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference", 85 api_params={"url": "http://localhost:8080"}) 86 87 result = document_embedder.run([doc]) 88 print(result["documents"][0].embedding) 89 90 # [0.017020374536514282, -0.023255806416273117, ...] 91 ``` 92 """ 93 94 def __init__( 95 self, 96 api_type: HFEmbeddingAPIType | str, 97 api_params: dict[str, str], 98 token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), 99 prefix: str = "", 100 suffix: str = "", 101 truncate: bool | None = True, 102 normalize: bool | None = False, 103 batch_size: int = 32, 104 progress_bar: bool = True, 105 meta_fields_to_embed: list[str] | None = None, 106 embedding_separator: str = "\n", 107 concurrency_limit: int = 4, 108 ) -> None: 109 """ 110 Creates a HuggingFaceAPIDocumentEmbedder component. 111 112 :param api_type: 113 The type of Hugging Face API to use. 114 :param api_params: 115 A dictionary with the following keys: 116 - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`. 117 - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or 118 `TEXT_EMBEDDINGS_INFERENCE`. 119 :param token: The Hugging Face token to use as HTTP bearer authorization. 120 Check your HF token in your [account settings](https://huggingface.co/settings/tokens). 121 :param prefix: 122 A string to add at the beginning of each text. 123 :param suffix: 124 A string to add at the end of each text. 125 :param truncate: 126 Truncates the input text to the maximum length supported by the model. 127 Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` 128 if the backend uses Text Embeddings Inference. 129 If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. 130 :param normalize: 131 Normalizes the embeddings to unit length. 132 Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS` 133 if the backend uses Text Embeddings Inference. 134 If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored. 135 :param batch_size: 136 Number of documents to process at once. 137 :param progress_bar: 138 If `True`, shows a progress bar when running. 139 :param meta_fields_to_embed: 140 List of metadata fields to embed along with the document text. 141 :param embedding_separator: 142 Separator used to concatenate the metadata fields to the document text. 143 :param concurrency_limit: 144 The maximum number of requests that should be allowed to run concurrently. 145 This parameter is only used in the `run_async` method. 146 """ 147 huggingface_hub_import.check() 148 149 if isinstance(api_type, str): 150 api_type = HFEmbeddingAPIType.from_str(api_type) 151 152 api_params = api_params or {} 153 154 if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: 155 model = api_params.get("model") 156 if model is None: 157 raise ValueError( 158 "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`." 159 ) 160 check_valid_model(model, HFModelType.EMBEDDING, token) 161 model_or_url = model 162 elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]: 163 url = api_params.get("url") 164 if url is None: 165 msg = ( 166 "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` " 167 "parameter in `api_params`." 168 ) 169 raise ValueError(msg) 170 if not is_valid_http_url(url): 171 raise ValueError(f"Invalid URL: {url}") 172 model_or_url = url 173 else: 174 msg = f"Unknown api_type {api_type}" 175 raise ValueError(msg) 176 177 client_args: dict[str, Any] = {"model": model_or_url, "token": token.resolve_value() if token else None} 178 179 self.api_type = api_type 180 self.api_params = api_params 181 self.token = token 182 self.prefix = prefix 183 self.suffix = suffix 184 self.truncate = truncate 185 self.normalize = normalize 186 self.batch_size = batch_size 187 self.progress_bar = progress_bar 188 self.meta_fields_to_embed = meta_fields_to_embed or [] 189 self.embedding_separator = embedding_separator 190 self.concurrency_limit = concurrency_limit 191 self._client = InferenceClient(**client_args) 192 self._async_client = AsyncInferenceClient(**client_args) 193 194 def to_dict(self) -> dict[str, Any]: 195 """ 196 Serializes the component to a dictionary. 197 198 :returns: 199 Dictionary with serialized data. 200 """ 201 return default_to_dict( 202 self, 203 api_type=str(self.api_type), 204 api_params=self.api_params, 205 prefix=self.prefix, 206 suffix=self.suffix, 207 token=self.token, 208 truncate=self.truncate, 209 normalize=self.normalize, 210 batch_size=self.batch_size, 211 progress_bar=self.progress_bar, 212 meta_fields_to_embed=self.meta_fields_to_embed, 213 embedding_separator=self.embedding_separator, 214 ) 215 216 @classmethod 217 def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder": 218 """ 219 Deserializes the component from a dictionary. 220 221 :param data: 222 Dictionary to deserialize from. 223 :returns: 224 Deserialized component. 225 """ 226 return default_from_dict(cls, data) 227 228 def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]: 229 """ 230 Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. 231 """ 232 texts_to_embed = [] 233 for doc in documents: 234 meta_values_to_embed = [ 235 str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None 236 ] 237 238 text_to_embed = ( 239 self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix 240 ) 241 242 texts_to_embed.append(text_to_embed) 243 return texts_to_embed 244 245 @staticmethod 246 def _adjust_api_parameters( 247 truncate: bool | None, normalize: bool | None, api_type: HFEmbeddingAPIType 248 ) -> tuple[bool | None, bool | None]: 249 """ 250 Adjust the truncate and normalize parameters based on the API type. 251 """ 252 if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: 253 if truncate is not None: 254 msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored." 255 logger.warning(msg) 256 truncate = None 257 if normalize is not None: 258 msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored." 259 logger.warning(msg) 260 normalize = None 261 return truncate, normalize 262 263 def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]: 264 """ 265 Embed a list of texts in batches. 266 """ 267 truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type) 268 269 all_embeddings: list = [] 270 for i in tqdm( 271 range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" 272 ): 273 batch = texts_to_embed[i : i + batch_size] 274 275 np_embeddings = self._client.feature_extraction( 276 # this method does not officially support list of strings, but works as expected 277 text=batch, # type: ignore[arg-type] 278 truncate=truncate, 279 normalize=normalize, 280 ) 281 282 if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): 283 raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}") 284 285 all_embeddings.extend(np_embeddings.tolist()) 286 287 return all_embeddings 288 289 async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]: 290 """ 291 Embed a list of texts in batches asynchronously. 292 """ 293 truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type) 294 sem = Semaphore(max(1, self.concurrency_limit)) 295 num_batches = (len(texts_to_embed) + batch_size - 1) // batch_size 296 pbar = tqdm(total=num_batches, disable=not self.progress_bar, desc="Calculating embeddings") 297 298 async def _runner(batch: list[str]) -> list[list[float]]: 299 async with sem: 300 np_embeddings = await self._async_client.feature_extraction( 301 # this method does not officially support list of strings, but works as expected 302 text=batch, # type: ignore[arg-type] 303 truncate=truncate, 304 normalize=normalize, 305 ) 306 307 if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch): 308 raise ValueError( 309 f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}" 310 ) 311 312 pbar.update(1) 313 return np_embeddings.tolist() 314 315 try: 316 all_embeddings = [ 317 *chain( 318 *await gather( 319 *[ 320 _runner(texts_to_embed[i : i + batch_size]) 321 for i in range(0, len(texts_to_embed), batch_size) 322 ] 323 ) 324 ) 325 ] 326 finally: 327 pbar.close() 328 329 return all_embeddings 330 331 @component.output_types(documents=list[Document]) 332 def run(self, documents: list[Document]) -> dict[str, list[Document]]: 333 """ 334 Embeds a list of documents. 335 336 :param documents: 337 Documents to embed. 338 339 :returns: 340 A dictionary with the following keys: 341 - `documents`: A list of documents with embeddings. 342 """ 343 if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): 344 raise TypeError( 345 "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input." 346 " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder." 347 ) 348 349 texts_to_embed = self._prepare_texts_to_embed(documents=documents) 350 351 embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) 352 353 new_documents = [] 354 for doc, emb in zip(documents, embeddings, strict=True): 355 new_documents.append(replace(doc, embedding=emb)) 356 357 return {"documents": new_documents} 358 359 @component.output_types(documents=list[Document]) 360 async def run_async(self, documents: list[Document]) -> dict[str, list[Document]]: 361 """ 362 Embeds a list of documents asynchronously. 363 364 :param documents: 365 Documents to embed. 366 367 :returns: 368 A dictionary with the following keys: 369 - `documents`: A list of documents with embeddings. 370 """ 371 if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): 372 raise TypeError( 373 "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input." 374 " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder." 375 ) 376 377 texts_to_embed = self._prepare_texts_to_embed(documents=documents) 378 379 embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size) 380 381 new_documents = [] 382 for doc, emb in zip(documents, embeddings, strict=True): 383 new_documents.append(replace(doc, embedding=emb)) 384 385 return {"documents": new_documents}