sentence_transformers_document_embedder.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 from dataclasses import replace 6 from typing import Any, Literal 7 8 from haystack import Document, component, default_from_dict, default_to_dict 9 from haystack.components.embedders.backends.sentence_transformers_backend import ( 10 _SentenceTransformersEmbeddingBackend, 11 _SentenceTransformersEmbeddingBackendFactory, 12 ) 13 from haystack.utils import ComponentDevice, Secret 14 from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs 15 16 17 @component 18 class SentenceTransformersDocumentEmbedder: 19 """ 20 Calculates document embeddings using Sentence Transformers models. 21 22 It stores the embeddings in the `embedding` metadata field of each document. 23 You can also embed documents' metadata. 24 Use this component in indexing pipelines to embed input documents 25 and send them to DocumentWriter to write into a Document Store. 26 27 ### Usage example: 28 <!-- test-ignore --> 29 ```python 30 from haystack import Document 31 from haystack.components.embedders import SentenceTransformersDocumentEmbedder 32 doc = Document(content="I love pizza!") 33 doc_embedder = SentenceTransformersDocumentEmbedder() 34 35 result = doc_embedder.run([doc]) 36 print(result['documents'][0].embedding) 37 38 # [-0.07804739475250244, 0.1498992145061493, ...] 39 ``` 40 """ 41 42 def __init__( # noqa: PLR0913 43 self, 44 model: str = "sentence-transformers/all-mpnet-base-v2", 45 device: ComponentDevice | None = None, 46 token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), 47 prefix: str = "", 48 suffix: str = "", 49 batch_size: int = 32, 50 progress_bar: bool = True, 51 normalize_embeddings: bool = False, 52 meta_fields_to_embed: list[str] | None = None, 53 embedding_separator: str = "\n", 54 trust_remote_code: bool = False, 55 local_files_only: bool = False, 56 truncate_dim: int | None = None, 57 model_kwargs: dict[str, Any] | None = None, 58 tokenizer_kwargs: dict[str, Any] | None = None, 59 config_kwargs: dict[str, Any] | None = None, 60 precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", 61 encode_kwargs: dict[str, Any] | None = None, 62 backend: Literal["torch", "onnx", "openvino"] = "torch", 63 revision: str | None = None, 64 ) -> None: 65 """ 66 Creates a SentenceTransformersDocumentEmbedder component. 67 68 :param model: 69 The model to use for calculating embeddings. 70 Pass a local path or ID of the model on Hugging Face. 71 :param device: 72 The device to use for loading the model. 73 Overrides the default device. 74 :param token: 75 The API token to download private models from Hugging Face. 76 :param prefix: 77 A string to add at the beginning of each document text. 78 Can be used to prepend the text with an instruction, as required by some embedding models, 79 such as E5 and bge. 80 :param suffix: 81 A string to add at the end of each document text. 82 :param batch_size: 83 Number of documents to embed at once. 84 :param progress_bar: 85 If `True`, shows a progress bar when embedding documents. 86 :param normalize_embeddings: 87 If `True`, the embeddings are normalized using L2 normalization, so that each embedding has a norm of 1. 88 :param meta_fields_to_embed: 89 List of metadata fields to embed along with the document text. 90 :param embedding_separator: 91 Separator used to concatenate the metadata fields to the document text. 92 :param trust_remote_code: 93 If `False`, allows only Hugging Face verified model architectures. 94 If `True`, allows custom models and scripts. 95 :param local_files_only: 96 If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files. 97 :param truncate_dim: 98 The dimension to truncate sentence embeddings to. `None` does no truncation. 99 If the model wasn't trained with Matryoshka Representation Learning, 100 truncating embeddings can significantly affect performance. 101 :param model_kwargs: 102 Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained` 103 when loading the model. Refer to specific model documentation for available kwargs. 104 :param tokenizer_kwargs: 105 Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer. 106 Refer to specific model documentation for available kwargs. 107 :param config_kwargs: 108 Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration. 109 :param precision: 110 The precision to use for the embeddings. 111 All non-float32 precisions are quantized embeddings. 112 Quantized embeddings are smaller and faster to compute, but may have a lower accuracy. 113 They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. 114 :param encode_kwargs: 115 Additional keyword arguments for `SentenceTransformer.encode` when embedding documents. 116 This parameter is provided for fine customization. Be careful not to clash with already set parameters and 117 avoid passing parameters that change the output type. 118 :param backend: 119 The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino". 120 Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html) 121 for more information on acceleration and quantization options. 122 :param revision: 123 The specific model version to use. It can be a branch name, a tag name, or a commit id, 124 for a stored model on Hugging Face. 125 """ 126 127 self.model = model 128 self.device = ComponentDevice.resolve_device(device) 129 self.token = token 130 self.prefix = prefix 131 self.suffix = suffix 132 self.batch_size = batch_size 133 self.progress_bar = progress_bar 134 self.normalize_embeddings = normalize_embeddings 135 self.meta_fields_to_embed = meta_fields_to_embed or [] 136 self.embedding_separator = embedding_separator 137 self.trust_remote_code = trust_remote_code 138 self.revision = revision 139 self.local_files_only = local_files_only 140 self.truncate_dim = truncate_dim 141 self.model_kwargs = model_kwargs 142 self.tokenizer_kwargs = tokenizer_kwargs 143 self.config_kwargs = config_kwargs 144 self.encode_kwargs = encode_kwargs 145 self.embedding_backend: _SentenceTransformersEmbeddingBackend | None = None 146 self.precision = precision 147 self.backend = backend 148 149 def _get_telemetry_data(self) -> dict[str, Any]: 150 """ 151 Data that is sent to Posthog for usage analytics. 152 """ 153 return {"model": self.model} 154 155 def to_dict(self) -> dict[str, Any]: 156 """ 157 Serializes the component to a dictionary. 158 159 :returns: 160 Dictionary with serialized data. 161 """ 162 serialization_dict = default_to_dict( 163 self, 164 model=self.model, 165 device=self.device, 166 token=self.token, 167 prefix=self.prefix, 168 suffix=self.suffix, 169 batch_size=self.batch_size, 170 progress_bar=self.progress_bar, 171 normalize_embeddings=self.normalize_embeddings, 172 meta_fields_to_embed=self.meta_fields_to_embed, 173 embedding_separator=self.embedding_separator, 174 trust_remote_code=self.trust_remote_code, 175 revision=self.revision, 176 local_files_only=self.local_files_only, 177 truncate_dim=self.truncate_dim, 178 model_kwargs=self.model_kwargs, 179 tokenizer_kwargs=self.tokenizer_kwargs, 180 config_kwargs=self.config_kwargs, 181 precision=self.precision, 182 encode_kwargs=self.encode_kwargs, 183 backend=self.backend, 184 ) 185 if serialization_dict["init_parameters"].get("model_kwargs") is not None: 186 serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) 187 return serialization_dict 188 189 @classmethod 190 def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersDocumentEmbedder": 191 """ 192 Deserializes the component from a dictionary. 193 194 :param data: 195 Dictionary to deserialize from. 196 :returns: 197 Deserialized component. 198 """ 199 init_params = data["init_parameters"] 200 if init_params.get("model_kwargs") is not None: 201 deserialize_hf_model_kwargs(init_params["model_kwargs"]) 202 return default_from_dict(cls, data) 203 204 def warm_up(self) -> None: 205 """ 206 Initializes the component. 207 """ 208 if self.embedding_backend is None: 209 self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend( 210 model=self.model, 211 device=self.device.to_torch_str(), 212 auth_token=self.token, 213 trust_remote_code=self.trust_remote_code, 214 revision=self.revision, 215 local_files_only=self.local_files_only, 216 truncate_dim=self.truncate_dim, 217 model_kwargs=self.model_kwargs, 218 tokenizer_kwargs=self.tokenizer_kwargs, 219 config_kwargs=self.config_kwargs, 220 backend=self.backend, 221 ) 222 if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"): 223 self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"] 224 225 @component.output_types(documents=list[Document]) 226 def run(self, documents: list[Document]) -> dict[str, list[Document]]: 227 """ 228 Embed a list of documents. 229 230 :param documents: 231 Documents to embed. 232 233 :returns: 234 A dictionary with the following keys: 235 - `documents`: Documents with embeddings. 236 """ 237 if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): 238 raise TypeError( 239 "SentenceTransformersDocumentEmbedder expects a list of Documents as input." 240 "In case you want to embed a string, please use the SentenceTransformersTextEmbedder." 241 ) 242 if self.embedding_backend is None: 243 self.warm_up() 244 245 texts_to_embed = [] 246 for doc in documents: 247 meta_values_to_embed = [ 248 str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] 249 ] 250 text_to_embed = ( 251 self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix 252 ) 253 texts_to_embed.append(text_to_embed) 254 255 # # mypy doesn't know this is set in warm_up 256 embeddings = self.embedding_backend.embed( # type: ignore[union-attr] 257 texts_to_embed, 258 batch_size=self.batch_size, 259 show_progress_bar=self.progress_bar, 260 normalize_embeddings=self.normalize_embeddings, 261 precision=self.precision, 262 **(self.encode_kwargs if self.encode_kwargs else {}), 263 ) 264 265 new_documents = [] 266 for doc, emb in zip(documents, embeddings, strict=True): 267 new_documents.append(replace(doc, embedding=emb)) 268 269 return {"documents": new_documents}