# sentence_transformers_sparse_text_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Literal

from haystack import component, default_from_dict, default_to_dict
from haystack.components.embedders.backends.sentence_transformers_sparse_backend import (
    _SentenceTransformersSparseEmbeddingBackendFactory,
    _SentenceTransformersSparseEncoderEmbeddingBackend,
)
from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs


@component
class SentenceTransformersSparseTextEmbedder:
    """
    Embeds strings using sparse embedding models from Sentence Transformers.

    You can use it to embed user query and send it to a sparse embedding retriever.

    Usage example:
    <!-- test-ignore -->
    ```python
    from haystack.components.embedders import SentenceTransformersSparseTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersSparseTextEmbedder()

    print(text_embedder.run(text_to_embed))

    # {'sparse_embedding': SparseEmbedding(indices=[999, 1045, ...], values=[0.918, 0.867, ...])}
    ```
    """

    def __init__(  # noqa: PLR0913
        self,
        *,
        model: str = "prithivida/Splade_PP_en_v2",
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        trust_remote_code: bool = False,
        local_files_only: bool = False,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
        revision: str | None = None,
    ) -> None:
        """
        Create a SentenceTransformersSparseTextEmbedder component.

        :param model:
            The model to use for calculating sparse embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
        :param suffix:
            A string to add at the end of each text to embed.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        :param local_files_only:
            If `True`, does not attempt to download the model from Hugging Face Hub and only looks at local files.
        :param model_kwargs:
            Additional keyword arguments for `AutoModelForSequenceClassification.from_pretrained`
            when loading the model. Refer to specific model documentation for available kwargs.
        :param tokenizer_kwargs:
            Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
            Refer to specific model documentation for available kwargs.
        :param config_kwargs:
            Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
        :param backend:
            The backend to use for the Sentence Transformers model. Choose from "torch", "onnx", or "openvino".
            Refer to the [Sentence Transformers documentation](https://sbert.net/docs/sentence_transformer/usage/efficiency.html)
            for more information on acceleration and quantization options.
        :param revision:
            The specific model version to use. It can be a branch name, a tag name, or a commit id,
            for a stored model on Hugging Face.
        """

        self.model = model
        # Resolve to a concrete device now so warm_up() and serialization see a definite value.
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.local_files_only = local_files_only
        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.config_kwargs = config_kwargs
        # The heavy model is loaded lazily in warm_up(), not in the constructor.
        self.embedding_backend: _SentenceTransformersSparseEncoderEmbeddingBackend | None = None
        self.backend = backend

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model,
            # ComponentDevice and Secret are not JSON-serializable as-is; store their
            # dict representations so the result can round-trip through from_dict.
            device=self.device.to_dict(),
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
            local_files_only=self.local_files_only,
            model_kwargs=self.model_kwargs,
            tokenizer_kwargs=self.tokenizer_kwargs,
            config_kwargs=self.config_kwargs,
            backend=self.backend,
        )
        # model_kwargs may contain non-serializable objects (e.g. torch dtypes);
        # convert them to a serializable form in place.
        if serialization_dict["init_parameters"].get("model_kwargs") is not None:
            serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SentenceTransformersSparseTextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        # Rebuild the rich objects that to_dict() flattened into plain dicts.
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])
        return default_from_dict(cls, data)

    def warm_up(self) -> None:
        """
        Initializes the component.
        """
        if self.embedding_backend is None:
            self.embedding_backend = _SentenceTransformersSparseEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
                revision=self.revision,
                local_files_only=self.local_files_only,
                model_kwargs=self.model_kwargs,
                tokenizer_kwargs=self.tokenizer_kwargs,
                config_kwargs=self.config_kwargs,
                backend=self.backend,
            )
            # Honor an explicit tokenizer max length by propagating it to the model.
            if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
                self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]

    @component.output_types(sparse_embedding=SparseEmbedding)
    def run(self, text: str) -> dict[str, Any]:
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `sparse_embedding`: The sparse embedding of the input text.
        :raises TypeError:
            If the input is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(
                "SentenceTransformersSparseTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the "
                "SentenceTransformersSparseDocumentEmbedder."
            )
        if self.embedding_backend is None:
            self.warm_up()

        text_to_embed = self.prefix + text + self.suffix

        # mypy doesn't know this is set in warm_up
        sparse_embedding = self.embedding_backend.embed(data=[text_to_embed])[0]  # type: ignore[union-attr]

        return {"sparse_embedding": sparse_embedding}