/ haystack / components / embedders / sentence_transformers_document_embedder.py
sentence_transformers_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from dataclasses import replace
  6  from typing import Any, Literal
  7  
  8  from haystack import Document, component, default_from_dict, default_to_dict
  9  from haystack.components.embedders.backends.sentence_transformers_backend import (
 10      _SentenceTransformersEmbeddingBackend,
 11      _SentenceTransformersEmbeddingBackendFactory,
 12  )
 13  from haystack.utils import ComponentDevice, Secret
 14  from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 15  
 16  
 17  @component
 18  class SentenceTransformersDocumentEmbedder:
 19      """
 20      Calculates document embeddings using Sentence Transformers models.
 21  
 22      It stores the embeddings in the `embedding` metadata field of each document.
 23      You can also embed documents' metadata.
 24      Use this component in indexing pipelines to embed input documents
 25      and send them to DocumentWriter to write into a Document Store.
 26  
 27      ### Usage example:
 28      <!-- test-ignore -->
 29      ```python
 30      from haystack import Document
 31      from haystack.components.embedders import SentenceTransformersDocumentEmbedder
 32      doc = Document(content="I love pizza!")
 33      doc_embedder = SentenceTransformersDocumentEmbedder()
 34  
 35      result = doc_embedder.run([doc])
 36      print(result['documents'][0].embedding)
 37  
 38      # [-0.07804739475250244, 0.1498992145061493, ...]
 39      ```
 40      """
 41  
 42      def __init__(  # noqa: PLR0913
 43          self,
 44          model: str = "sentence-transformers/all-mpnet-base-v2",
 45          device: ComponentDevice | None = None,
 46          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 47          prefix: str = "",
 48          suffix: str = "",
 49          batch_size: int = 32,
 50          progress_bar: bool = True,
 51          normalize_embeddings: bool = False,
 52          meta_fields_to_embed: list[str] | None = None,
 53          embedding_separator: str = "\n",
 54          trust_remote_code: bool = False,
 55          local_files_only: bool = False,
 56          truncate_dim: int | None = None,
 57          model_kwargs: dict[str, Any] | None = None,
 58          tokenizer_kwargs: dict[str, Any] | None = None,
 59          config_kwargs: dict[str, Any] | None = None,
 60          precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
 61          encode_kwargs: dict[str, Any] | None = None,
 62          backend: Literal["torch", "onnx", "openvino"] = "torch",
 63          revision: str | None = None,
 64      ) -> None:
 65          """
 66          Creates a SentenceTransformersDocumentEmbedder component.
 67  
 68          :param model:
 69              The model to use for calculating embeddings.
 70              Pass a local path or ID of the model on Hugging Face.
 71          :param device:
 72              The device to use for loading the model.
 73              Overrides the default device.
 74          :param token:
 75              The API token to download private models from Hugging Face.
 76          :param prefix:
 77              A string to add at the beginning of each document text.
 78              Can be used to prepend the text with an instruction, as required by some embedding models,
 79              such as E5 and bge.
 80          :param suffix:
 81              A string to add at the end of each document text.
 82          :param batch_size:
 83              Number of documents to embed at once.
 84          :param progress_bar:
 85              If `True`, shows a progress bar when embedding documents.
 86          :param normalize_embeddings:
 87              If `True`, the embeddings are normalized using L2 normalization, so that each embedding has a norm of 1.
 88          :param meta_fields_to_embed:
 89              List of metadata fields to embed along with the document text.
 90          :param embedding_separator:
 91              Separator used to concatenate the metadata fields to the document text.
 92          :param trust_remote_code:
 93              If `False`, allows only Hugging Face verified model architectures.
 94              If `True`, allows custom models and scripts.
 95          :param local_files_only:
 96              If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
 97          :param truncate_dim:
 98              The dimension to truncate sentence embeddings to. `None` does no truncation.
 99              If the model wasn't trained with Matryoshka Representation Learning,
100              truncating embeddings can significantly affect performance.
101          :param model_kwargs:
102              Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
103              when loading the model. Refer to specific model documentation for available kwargs.
104          :param tokenizer_kwargs:
105              Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
106              Refer to specific model documentation for available kwargs.
107          :param config_kwargs:
108              Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
109          :param precision:
110              The precision to use for the embeddings.
111              All non-float32 precisions are quantized embeddings.
112              Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
113              They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
114          :param encode_kwargs:
115              Additional keyword arguments for `SentenceTransformer.encode` when embedding documents.
116              This parameter is provided for fine customization. Be careful not to clash with already set parameters and
117              avoid passing parameters that change the output type.
118          :param backend:
119              The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
120              Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
121              for more information on acceleration and quantization options.
122          :param revision:
123              The specific model version to use. It can be a branch name, a tag name, or a commit id,
124              for a stored model on Hugging Face.
125          """
126  
127          self.model = model
128          self.device = ComponentDevice.resolve_device(device)
129          self.token = token
130          self.prefix = prefix
131          self.suffix = suffix
132          self.batch_size = batch_size
133          self.progress_bar = progress_bar
134          self.normalize_embeddings = normalize_embeddings
135          self.meta_fields_to_embed = meta_fields_to_embed or []
136          self.embedding_separator = embedding_separator
137          self.trust_remote_code = trust_remote_code
138          self.revision = revision
139          self.local_files_only = local_files_only
140          self.truncate_dim = truncate_dim
141          self.model_kwargs = model_kwargs
142          self.tokenizer_kwargs = tokenizer_kwargs
143          self.config_kwargs = config_kwargs
144          self.encode_kwargs = encode_kwargs
145          self.embedding_backend: _SentenceTransformersEmbeddingBackend | None = None
146          self.precision = precision
147          self.backend = backend
148  
149      def _get_telemetry_data(self) -> dict[str, Any]:
150          """
151          Data that is sent to Posthog for usage analytics.
152          """
153          return {"model": self.model}
154  
155      def to_dict(self) -> dict[str, Any]:
156          """
157          Serializes the component to a dictionary.
158  
159          :returns:
160              Dictionary with serialized data.
161          """
162          serialization_dict = default_to_dict(
163              self,
164              model=self.model,
165              device=self.device,
166              token=self.token,
167              prefix=self.prefix,
168              suffix=self.suffix,
169              batch_size=self.batch_size,
170              progress_bar=self.progress_bar,
171              normalize_embeddings=self.normalize_embeddings,
172              meta_fields_to_embed=self.meta_fields_to_embed,
173              embedding_separator=self.embedding_separator,
174              trust_remote_code=self.trust_remote_code,
175              revision=self.revision,
176              local_files_only=self.local_files_only,
177              truncate_dim=self.truncate_dim,
178              model_kwargs=self.model_kwargs,
179              tokenizer_kwargs=self.tokenizer_kwargs,
180              config_kwargs=self.config_kwargs,
181              precision=self.precision,
182              encode_kwargs=self.encode_kwargs,
183              backend=self.backend,
184          )
185          if serialization_dict["init_parameters"].get("model_kwargs") is not None:
186              serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
187          return serialization_dict
188  
189      @classmethod
190      def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
191          """
192          Deserializes the component from a dictionary.
193  
194          :param data:
195              Dictionary to deserialize from.
196          :returns:
197              Deserialized component.
198          """
199          init_params = data["init_parameters"]
200          if init_params.get("model_kwargs") is not None:
201              deserialize_hf_model_kwargs(init_params["model_kwargs"])
202          return default_from_dict(cls, data)
203  
204      def warm_up(self) -> None:
205          """
206          Initializes the component.
207          """
208          if self.embedding_backend is None:
209              self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
210                  model=self.model,
211                  device=self.device.to_torch_str(),
212                  auth_token=self.token,
213                  trust_remote_code=self.trust_remote_code,
214                  revision=self.revision,
215                  local_files_only=self.local_files_only,
216                  truncate_dim=self.truncate_dim,
217                  model_kwargs=self.model_kwargs,
218                  tokenizer_kwargs=self.tokenizer_kwargs,
219                  config_kwargs=self.config_kwargs,
220                  backend=self.backend,
221              )
222              if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
223                  self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
224  
225      @component.output_types(documents=list[Document])
226      def run(self, documents: list[Document]) -> dict[str, list[Document]]:
227          """
228          Embed a list of documents.
229  
230          :param documents:
231              Documents to embed.
232  
233          :returns:
234              A dictionary with the following keys:
235              - `documents`: Documents with embeddings.
236          """
237          if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
238              raise TypeError(
239                  "SentenceTransformersDocumentEmbedder expects a list of Documents as input."
240                  "In case you want to embed a string, please use the SentenceTransformersTextEmbedder."
241              )
242          if self.embedding_backend is None:
243              self.warm_up()
244  
245          texts_to_embed = []
246          for doc in documents:
247              meta_values_to_embed = [
248                  str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
249              ]
250              text_to_embed = (
251                  self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
252              )
253              texts_to_embed.append(text_to_embed)
254  
255          # # mypy doesn't know this is set in warm_up
256          embeddings = self.embedding_backend.embed(  # type: ignore[union-attr]
257              texts_to_embed,
258              batch_size=self.batch_size,
259              show_progress_bar=self.progress_bar,
260              normalize_embeddings=self.normalize_embeddings,
261              precision=self.precision,
262              **(self.encode_kwargs if self.encode_kwargs else {}),
263          )
264  
265          new_documents = []
266          for doc, emb in zip(documents, embeddings, strict=True):
267              new_documents.append(replace(doc, embedding=emb))
268  
269          return {"documents": new_documents}