# test/components/embedders/test_sentence_transformers_sparse_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from unittest.mock import MagicMock, patch
  6  
  7  import pytest
  8  import torch
  9  
 10  from haystack import Document
 11  from haystack.components.embedders.sentence_transformers_sparse_document_embedder import (
 12      SentenceTransformersSparseDocumentEmbedder,
 13  )
 14  from haystack.dataclasses.sparse_embedding import SparseEmbedding
 15  from haystack.utils import ComponentDevice, Secret
 16  
 17  TYPE_NAME = (
 18      "haystack.components.embedders.sentence_transformers_sparse_document_embedder."
 19      "SentenceTransformersSparseDocumentEmbedder"
 20  )
 21  
 22  
 23  class TestSentenceTransformersDocumentEmbedder:
 24      def test_init_default(self):
 25          embedder = SentenceTransformersSparseDocumentEmbedder(model="model")
 26          assert embedder.model == "model"
 27          assert embedder.device == ComponentDevice.resolve_device(None)
 28          assert embedder.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
 29          assert embedder.prefix == ""
 30          assert embedder.suffix == ""
 31          assert embedder.batch_size == 32
 32          assert embedder.progress_bar is True
 33          assert embedder.meta_fields_to_embed == []
 34          assert embedder.embedding_separator == "\n"
 35          assert embedder.trust_remote_code is False
 36          assert embedder.revision is None
 37          assert embedder.local_files_only is False
 38  
 39      def test_init_with_parameters(self):
 40          embedder = SentenceTransformersSparseDocumentEmbedder(
 41              model="model",
 42              device=ComponentDevice.from_str("cuda:0"),
 43              token=Secret.from_token("fake-api-token"),
 44              prefix="prefix",
 45              suffix="suffix",
 46              batch_size=64,
 47              progress_bar=False,
 48              meta_fields_to_embed=["test_field"],
 49              embedding_separator=" | ",
 50              trust_remote_code=True,
 51              revision="v1.0",
 52              local_files_only=True,
 53          )
 54          assert embedder.model == "model"
 55          assert embedder.device == ComponentDevice.from_str("cuda:0")
 56          assert embedder.token == Secret.from_token("fake-api-token")
 57          assert embedder.prefix == "prefix"
 58          assert embedder.suffix == "suffix"
 59          assert embedder.batch_size == 64
 60          assert embedder.progress_bar is False
 61          assert embedder.meta_fields_to_embed == ["test_field"]
 62          assert embedder.embedding_separator == " | "
 63          assert embedder.trust_remote_code
 64          assert embedder.revision == "v1.0"
 65          assert embedder.local_files_only
 66  
 67      def test_to_dict(self):
 68          component = SentenceTransformersSparseDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
 69          data = component.to_dict()
 70          assert data == {
 71              "type": TYPE_NAME,
 72              "init_parameters": {
 73                  "model": "model",
 74                  "device": ComponentDevice.from_str("cpu").to_dict(),
 75                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 76                  "prefix": "",
 77                  "suffix": "",
 78                  "batch_size": 32,
 79                  "progress_bar": True,
 80                  "embedding_separator": "\n",
 81                  "meta_fields_to_embed": [],
 82                  "trust_remote_code": False,
 83                  "revision": None,
 84                  "local_files_only": False,
 85                  "model_kwargs": None,
 86                  "tokenizer_kwargs": None,
 87                  "config_kwargs": None,
 88                  "backend": "torch",
 89              },
 90          }
 91  
 92      def test_to_dict_with_custom_init_parameters(self):
 93          component = SentenceTransformersSparseDocumentEmbedder(
 94              model="model",
 95              device=ComponentDevice.from_str("cuda:0"),
 96              token=Secret.from_env_var("ENV_VAR", strict=False),
 97              prefix="prefix",
 98              suffix="suffix",
 99              batch_size=64,
100              progress_bar=False,
101              meta_fields_to_embed=["meta_field"],
102              embedding_separator=" - ",
103              trust_remote_code=True,
104              local_files_only=True,
105              model_kwargs={"torch_dtype": torch.float32},
106              tokenizer_kwargs={"model_max_length": 512},
107              config_kwargs={"use_memory_efficient_attention": True},
108          )
109          data = component.to_dict()
110  
111          assert data == {
112              "type": TYPE_NAME,
113              "init_parameters": {
114                  "model": "model",
115                  "device": ComponentDevice.from_str("cuda:0").to_dict(),
116                  "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
117                  "prefix": "prefix",
118                  "suffix": "suffix",
119                  "batch_size": 64,
120                  "progress_bar": False,
121                  "embedding_separator": " - ",
122                  "trust_remote_code": True,
123                  "revision": None,
124                  "local_files_only": True,
125                  "meta_fields_to_embed": ["meta_field"],
126                  "model_kwargs": {"torch_dtype": "torch.float32"},
127                  "tokenizer_kwargs": {"model_max_length": 512},
128                  "config_kwargs": {"use_memory_efficient_attention": True},
129                  "backend": "torch",
130              },
131          }
132  
133      def test_from_dict(self):
134          init_parameters = {
135              "model": "model",
136              "device": ComponentDevice.from_str("cuda:0").to_dict(),
137              "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
138              "prefix": "prefix",
139              "suffix": "suffix",
140              "batch_size": 64,
141              "progress_bar": False,
142              "embedding_separator": " - ",
143              "meta_fields_to_embed": ["meta_field"],
144              "trust_remote_code": True,
145              "revision": "v1.0",
146              "local_files_only": True,
147              "model_kwargs": {"torch_dtype": "torch.float32"},
148              "tokenizer_kwargs": {"model_max_length": 512},
149              "config_kwargs": {"use_memory_efficient_attention": True},
150          }
151          component = SentenceTransformersSparseDocumentEmbedder.from_dict(
152              {"type": TYPE_NAME, "init_parameters": init_parameters}
153          )
154          assert component.model == "model"
155          assert component.device == ComponentDevice.from_str("cuda:0")
156          assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
157          assert component.prefix == "prefix"
158          assert component.suffix == "suffix"
159          assert component.batch_size == 64
160          assert component.progress_bar is False
161          assert component.embedding_separator == " - "
162          assert component.trust_remote_code
163          assert component.revision == "v1.0"
164          assert component.local_files_only
165          assert component.meta_fields_to_embed == ["meta_field"]
166          assert component.model_kwargs == {"torch_dtype": torch.float32}
167          assert component.tokenizer_kwargs == {"model_max_length": 512}
168          assert component.config_kwargs == {"use_memory_efficient_attention": True}
169  
170      def test_from_dict_no_default_parameters(self):
171          component = SentenceTransformersSparseDocumentEmbedder.from_dict({"type": TYPE_NAME, "init_parameters": {}})
172          assert component.model == "prithivida/Splade_PP_en_v2"
173          assert component.device == ComponentDevice.resolve_device(None)
174          assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
175          assert component.prefix == ""
176          assert component.suffix == ""
177          assert component.batch_size == 32
178          assert component.progress_bar is True
179          assert component.embedding_separator == "\n"
180          assert component.trust_remote_code is False
181          assert component.revision is None
182          assert component.local_files_only is False
183          assert component.meta_fields_to_embed == []
184  
185      def test_from_dict_none_device(self):
186          init_parameters = {
187              "model": "model",
188              "device": None,
189              "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
190              "prefix": "prefix",
191              "suffix": "suffix",
192              "batch_size": 64,
193              "progress_bar": False,
194              "embedding_separator": " - ",
195              "meta_fields_to_embed": ["meta_field"],
196              "trust_remote_code": True,
197              "local_files_only": False,
198          }
199          component = SentenceTransformersSparseDocumentEmbedder.from_dict(
200              {"type": TYPE_NAME, "init_parameters": init_parameters}
201          )
202          assert component.model == "model"
203          assert component.device == ComponentDevice.resolve_device(None)
204          assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
205          assert component.prefix == "prefix"
206          assert component.suffix == "suffix"
207          assert component.batch_size == 64
208          assert component.progress_bar is False
209          assert component.embedding_separator == " - "
210          assert component.trust_remote_code
211          assert component.revision is None
212          assert component.local_files_only is False
213          assert component.meta_fields_to_embed == ["meta_field"]
214  
215      @patch(
216          "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
217      )
218      def test_warmup(self, mocked_factory):
219          embedder = SentenceTransformersSparseDocumentEmbedder(
220              model="model",
221              token=None,
222              device=ComponentDevice.from_str("cpu"),
223              tokenizer_kwargs={"model_max_length": 512},
224              config_kwargs={"use_memory_efficient_attention": True},
225          )
226          mocked_factory.get_embedding_backend.assert_not_called()
227          embedder.warm_up()
228          embedder.embedding_backend.model.max_seq_length = 512
229          mocked_factory.get_embedding_backend.assert_called_once_with(
230              model="model",
231              device="cpu",
232              auth_token=None,
233              trust_remote_code=False,
234              revision=None,
235              local_files_only=False,
236              model_kwargs=None,
237              tokenizer_kwargs={"model_max_length": 512},
238              config_kwargs={"use_memory_efficient_attention": True},
239              backend="torch",
240          )
241  
242      @patch(
243          "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
244      )
245      def test_warmup_doesnt_reload(self, mocked_factory):
246          embedder = SentenceTransformersSparseDocumentEmbedder(model="model")
247          mocked_factory.get_embedding_backend.assert_not_called()
248          embedder.warm_up()
249          embedder.warm_up()
250          mocked_factory.get_embedding_backend.assert_called_once()
251  
252      def test_run(self):
253          embedder = SentenceTransformersSparseDocumentEmbedder(model="model")
254          embedder.embedding_backend = MagicMock()
255  
256          def fake_embed(data, **kwargs):
257              return [SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(len(data))]
258  
259          embedder.embedding_backend.embed = fake_embed
260  
261          documents = [Document(content=f"document number {i}") for i in range(5)]
262  
263          result = embedder.run(documents=documents)
264  
265          assert isinstance(result["documents"], list)
266          assert len(result["documents"]) == len(documents)
267          for doc in result["documents"]:
268              assert isinstance(doc, Document)
269              assert isinstance(doc.sparse_embedding, SparseEmbedding)
270              assert isinstance(doc.sparse_embedding.indices[0], int)
271              assert isinstance(doc.sparse_embedding.values[0], float)
272  
273      def test_run_wrong_input_format(self):
274          embedder = SentenceTransformersSparseDocumentEmbedder(model="model")
275  
276          string_input = "text"
277          list_integers_input = [1, 2, 3]
278  
279          with pytest.raises(
280              TypeError, match="SentenceTransformersSparseDocumentEmbedder expects a list of Documents as input"
281          ):
282              embedder.run(documents=string_input)
283  
284          with pytest.raises(
285              TypeError, match="SentenceTransformersSparseDocumentEmbedder expects a list of Documents as input"
286          ):
287              embedder.run(documents=list_integers_input)
288  
289      def test_embed_metadata(self):
290          embedder = SentenceTransformersSparseDocumentEmbedder(
291              model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
292          )
293          embedder.embedding_backend = MagicMock()
294          embedder.embedding_backend.embed.return_value = [
295              SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(5)
296          ]
297          documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
298          embedder.run(documents=documents)
299          embedder.embedding_backend.embed.assert_called_once_with(
300              data=[
301                  "meta_value 0\ndocument number 0",
302                  "meta_value 1\ndocument number 1",
303                  "meta_value 2\ndocument number 2",
304                  "meta_value 3\ndocument number 3",
305                  "meta_value 4\ndocument number 4",
306              ],
307              batch_size=32,
308              show_progress_bar=True,
309          )
310  
311      def test_prefix_suffix(self):
312          embedder = SentenceTransformersSparseDocumentEmbedder(
313              model="model",
314              prefix="my_prefix ",
315              suffix=" my_suffix",
316              meta_fields_to_embed=["meta_field"],
317              embedding_separator="\n",
318          )
319          embedder.embedding_backend = MagicMock()
320          embedder.embedding_backend.embed.return_value = [
321              SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(5)
322          ]
323          documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
324          embedder.run(documents=documents)
325          embedder.embedding_backend.embed.assert_called_once_with(
326              data=[
327                  "my_prefix meta_value 0\ndocument number 0 my_suffix",
328                  "my_prefix meta_value 1\ndocument number 1 my_suffix",
329                  "my_prefix meta_value 2\ndocument number 2 my_suffix",
330                  "my_prefix meta_value 3\ndocument number 3 my_suffix",
331                  "my_prefix meta_value 4\ndocument number 4 my_suffix",
332              ],
333              batch_size=32,
334              show_progress_bar=True,
335          )
336  
337      @patch(
338          "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
339      )
340      def test_model_onnx_backend(self, mocked_factory):
341          onnx_embedder = SentenceTransformersSparseDocumentEmbedder(
342              model="prithivida/Splade_PP_en_v2",
343              token=None,
344              device=ComponentDevice.from_str("cpu"),
345              model_kwargs={
346                  "file_name": "onnx/model.onnx"
347              },  # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to
348              # prevent a HF warning
349              backend="onnx",
350          )
351          onnx_embedder.warm_up()
352  
353          mocked_factory.get_embedding_backend.assert_called_once_with(
354              model="prithivida/Splade_PP_en_v2",
355              device="cpu",
356              auth_token=None,
357              trust_remote_code=False,
358              revision=None,
359              local_files_only=False,
360              model_kwargs={"file_name": "onnx/model.onnx"},
361              tokenizer_kwargs=None,
362              config_kwargs=None,
363              backend="onnx",
364          )
365  
366      @patch(
367          "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
368      )
369      def test_model_openvino_backend(self, mocked_factory):
370          openvino_embedder = SentenceTransformersSparseDocumentEmbedder(
371              model="prithivida/Splade_PP_en_v2",
372              token=None,
373              device=ComponentDevice.from_str("cpu"),
374              model_kwargs={
375                  "file_name": "openvino/openvino_model.xml"
376              },  # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this
377              # is to prevent a HF warning
378              backend="openvino",
379          )
380          openvino_embedder.warm_up()
381  
382          mocked_factory.get_embedding_backend.assert_called_once_with(
383              model="prithivida/Splade_PP_en_v2",
384              device="cpu",
385              auth_token=None,
386              trust_remote_code=False,
387              revision=None,
388              local_files_only=False,
389              model_kwargs={"file_name": "openvino/openvino_model.xml"},
390              tokenizer_kwargs=None,
391              config_kwargs=None,
392              backend="openvino",
393          )
394  
395      @patch(
396          "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory"
397      )
398      @pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
399      def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
400          torch_dtype_embedder = SentenceTransformersSparseDocumentEmbedder(
401              model="prithivida/Splade_PP_en_v2",
402              token=None,
403              device=ComponentDevice.from_str("cuda:0"),
404              model_kwargs=model_kwargs,
405          )
406          torch_dtype_embedder.warm_up()
407  
408          mocked_factory.get_embedding_backend.assert_called_once_with(
409              model="prithivida/Splade_PP_en_v2",
410              device="cuda:0",
411              auth_token=None,
412              trust_remote_code=False,
413              revision=None,
414              local_files_only=False,
415              model_kwargs=model_kwargs,
416              tokenizer_kwargs=None,
417              config_kwargs=None,
418              backend="torch",
419          )
420  
421      @pytest.mark.integration
422      @pytest.mark.slow
423      @pytest.mark.flaky(reruns=3, reruns_delay=10)
424      def test_live_run_sparse_document_embedder(self, del_hf_env_vars):
425          docs = [
426              Document(content="I love cheese", meta={"topic": "Cuisine"}),
427              Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
428          ]
429  
430          embedder = SentenceTransformersSparseDocumentEmbedder(
431              model="sparse-encoder-testing/splade-bert-tiny-nq",
432              meta_fields_to_embed=["topic"],
433              embedding_separator=" | ",
434              device=ComponentDevice.from_str("cpu"),
435          )
436          result = embedder.run(documents=docs)
437          documents_with_embeddings = result["documents"]
438  
439          assert isinstance(documents_with_embeddings, list)
440          assert len(documents_with_embeddings) == len(docs)
441          for doc in documents_with_embeddings:
442              assert isinstance(doc, Document)
443              assert hasattr(doc, "sparse_embedding")
444              assert isinstance(doc.sparse_embedding, SparseEmbedding)
445              assert isinstance(doc.sparse_embedding.indices, list)
446              assert isinstance(doc.sparse_embedding.values, list)
447              assert len(doc.sparse_embedding.indices) == len(doc.sparse_embedding.values)
448              # Expect at least one non-zero entry
449              assert len(doc.sparse_embedding.indices) > 0