# haystack/components/embedders/sentence_transformers_sparse_text_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from typing import Any, Literal
  6  
  7  from haystack import component, default_from_dict, default_to_dict
  8  from haystack.components.embedders.backends.sentence_transformers_sparse_backend import (
  9      _SentenceTransformersSparseEmbeddingBackendFactory,
 10      _SentenceTransformersSparseEncoderEmbeddingBackend,
 11  )
 12  from haystack.dataclasses.sparse_embedding import SparseEmbedding
 13  from haystack.utils import ComponentDevice, Secret
 14  from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
 15  
 16  
 17  @component
 18  class SentenceTransformersSparseTextEmbedder:
 19      """
 20      Embeds strings using sparse embedding models from Sentence Transformers.
 21  
 22      You can use it to embed user query and send it to a sparse embedding retriever.
 23  
 24      Usage example:
 25      <!-- test-ignore -->
 26      ```python
 27      from haystack.components.embedders import SentenceTransformersSparseTextEmbedder
 28  
 29      text_to_embed = "I love pizza!"
 30  
 31      text_embedder = SentenceTransformersSparseTextEmbedder()
 32  
 33      print(text_embedder.run(text_to_embed))
 34  
 35      # {'sparse_embedding': SparseEmbedding(indices=[999, 1045, ...], values=[0.918, 0.867, ...])}
 36      ```
 37      """
 38  
 39      def __init__(  # noqa: PLR0913
 40          self,
 41          *,
 42          model: str = "prithivida/Splade_PP_en_v2",
 43          device: ComponentDevice | None = None,
 44          token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
 45          prefix: str = "",
 46          suffix: str = "",
 47          trust_remote_code: bool = False,
 48          local_files_only: bool = False,
 49          model_kwargs: dict[str, Any] | None = None,
 50          tokenizer_kwargs: dict[str, Any] | None = None,
 51          config_kwargs: dict[str, Any] | None = None,
 52          backend: Literal["torch", "onnx", "openvino"] = "torch",
 53          revision: str | None = None,
 54      ) -> None:
 55          """
 56          Create a SentenceTransformersSparseTextEmbedder component.
 57  
 58          :param model:
 59              The model to use for calculating sparse embeddings.
 60              Specify the path to a local model or the ID of the model on Hugging Face.
 61          :param device:
 62              Overrides the default device used to load the model.
 63          :param token:
 64              An API token to use private models from Hugging Face.
 65          :param prefix:
 66              A string to add at the beginning of each text to be embedded.
 67          :param suffix:
 68              A string to add at the end of each text to embed.
 69          :param trust_remote_code:
 70              If `False`, permits only Hugging Face verified model architectures.
 71              If `True`, permits custom models and scripts.
 72          :param local_files_only:
 73              If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
 74          :param model_kwargs:
 75              Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
 76              when loading the model. Refer to specific model documentation for available kwargs.
 77          :param tokenizer_kwargs:
 78              Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
 79              Refer to specific model documentation for available kwargs.
 80          :param config_kwargs:
 81              Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
 82          :param backend:
 83              The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
 84              Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
 85              for more information on acceleration and quantization options.
 86          :param revision:
 87              The specific model version to use. It can be a branch name, a tag name, or a commit id,
 88              for a stored model on Hugging Face.
 89          """
 90  
 91          self.model = model
 92          self.device = ComponentDevice.resolve_device(device)
 93          self.token = token
 94          self.prefix = prefix
 95          self.suffix = suffix
 96          self.trust_remote_code = trust_remote_code
 97          self.revision = revision
 98          self.local_files_only = local_files_only
 99          self.model_kwargs = model_kwargs
100          self.tokenizer_kwargs = tokenizer_kwargs
101          self.config_kwargs = config_kwargs
102          self.embedding_backend: _SentenceTransformersSparseEncoderEmbeddingBackend | None = None
103          self.backend = backend
104  
105      def _get_telemetry_data(self) -> dict[str, Any]:
106          """
107          Data that is sent to Posthog for usage analytics.
108          """
109          return {"model": self.model}
110  
111      def to_dict(self) -> dict[str, Any]:
112          """
113          Serializes the component to a dictionary.
114  
115          :returns:
116              Dictionary with serialized data.
117          """
118          serialization_dict = default_to_dict(
119              self,
120              model=self.model,
121              device=self.device,
122              token=self.token,
123              prefix=self.prefix,
124              suffix=self.suffix,
125              trust_remote_code=self.trust_remote_code,
126              revision=self.revision,
127              local_files_only=self.local_files_only,
128              model_kwargs=self.model_kwargs,
129              tokenizer_kwargs=self.tokenizer_kwargs,
130              config_kwargs=self.config_kwargs,
131              backend=self.backend,
132          )
133          if serialization_dict["init_parameters"].get("model_kwargs") is not None:
134              serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
135          return serialization_dict
136  
137      @classmethod
138      def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersSparseTextEmbedder":
139          """
140          Deserializes the component from a dictionary.
141  
142          :param data:
143              Dictionary to deserialize from.
144          :returns:
145              Deserialized component.
146          """
147          init_params = data["init_parameters"]
148          if init_params.get("model_kwargs") is not None:
149              deserialize_hf_model_kwargs(init_params["model_kwargs"])
150          return default_from_dict(cls, data)
151  
152      def warm_up(self) -> None:
153          """
154          Initializes the component.
155          """
156          if self.embedding_backend is None:
157              self.embedding_backend = _SentenceTransformersSparseEmbeddingBackendFactory.get_embedding_backend(
158                  model=self.model,
159                  device=self.device.to_torch_str(),
160                  auth_token=self.token,
161                  trust_remote_code=self.trust_remote_code,
162                  revision=self.revision,
163                  local_files_only=self.local_files_only,
164                  model_kwargs=self.model_kwargs,
165                  tokenizer_kwargs=self.tokenizer_kwargs,
166                  config_kwargs=self.config_kwargs,
167                  backend=self.backend,
168              )
169              if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
170                  self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
171  
172      @component.output_types(sparse_embedding=SparseEmbedding)
173      def run(self, text: str) -> dict[str, Any]:
174          """
175          Embed a single string.
176  
177          :param text:
178              Text to embed.
179  
180          :returns:
181              A dictionary with the following keys:
182              - `sparse_embedding`: The sparse embedding of the input text.
183          """
184          if not isinstance(text, str):
185              raise TypeError(
186                  "SentenceTransformersSparseTextEmbedder expects a string as input."
187                  "In case you want to embed a list of Documents, please use the"
188                  "SentenceTransformersSparseDocumentEmbedder."
189              )
190          if self.embedding_backend is None:
191              self.warm_up()
192  
193          text_to_embed = self.prefix + text + self.suffix
194  
195          # mypy doesn't know this is set in warm_up
196          sparse_embedding = self.embedding_backend.embed(data=[text_to_embed])[0]  # type: ignore[union-attr]
197  
198          return {"sparse_embedding": sparse_embedding}