# test/components/embedders/test_sentence_transformers_document_embedder.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import random
  6  from unittest.mock import MagicMock, patch
  7  
  8  import pytest
  9  import torch
 10  
 11  from haystack import Document
 12  from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
 13  from haystack.utils import ComponentDevice, Secret
 14  
 15  
 16  class TestSentenceTransformersDocumentEmbedder:
 17      def test_init_default(self):
 18          embedder = SentenceTransformersDocumentEmbedder(model="model")
 19          assert embedder.model == "model"
 20          assert embedder.device == ComponentDevice.resolve_device(None)
 21          assert embedder.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
 22          assert embedder.prefix == ""
 23          assert embedder.suffix == ""
 24          assert embedder.batch_size == 32
 25          assert embedder.progress_bar is True
 26          assert embedder.normalize_embeddings is False
 27          assert embedder.meta_fields_to_embed == []
 28          assert embedder.embedding_separator == "\n"
 29          assert embedder.trust_remote_code is False
 30          assert embedder.revision is None
 31          assert embedder.local_files_only is False
 32          assert embedder.truncate_dim is None
 33          assert embedder.precision == "float32"
 34  
 35      def test_init_with_parameters(self):
 36          embedder = SentenceTransformersDocumentEmbedder(
 37              model="model",
 38              device=ComponentDevice.from_str("cuda:0"),
 39              token=Secret.from_token("fake-api-token"),
 40              prefix="prefix",
 41              suffix="suffix",
 42              batch_size=64,
 43              progress_bar=False,
 44              normalize_embeddings=True,
 45              meta_fields_to_embed=["test_field"],
 46              embedding_separator=" | ",
 47              trust_remote_code=True,
 48              revision="v1.0",
 49              local_files_only=True,
 50              truncate_dim=256,
 51              precision="int8",
 52          )
 53          assert embedder.model == "model"
 54          assert embedder.device == ComponentDevice.from_str("cuda:0")
 55          assert embedder.token == Secret.from_token("fake-api-token")
 56          assert embedder.prefix == "prefix"
 57          assert embedder.suffix == "suffix"
 58          assert embedder.batch_size == 64
 59          assert embedder.progress_bar is False
 60          assert embedder.normalize_embeddings is True
 61          assert embedder.meta_fields_to_embed == ["test_field"]
 62          assert embedder.embedding_separator == " | "
 63          assert embedder.trust_remote_code
 64          assert embedder.revision == "v1.0"
 65          assert embedder.local_files_only
 66          assert embedder.truncate_dim == 256
 67          assert embedder.precision == "int8"
 68  
 69      def test_to_dict(self):
 70          component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
 71          data = component.to_dict()
 72          assert data == {
 73              "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
 74              "init_parameters": {
 75                  "model": "model",
 76                  "device": ComponentDevice.from_str("cpu").to_dict(),
 77                  "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
 78                  "prefix": "",
 79                  "suffix": "",
 80                  "batch_size": 32,
 81                  "progress_bar": True,
 82                  "normalize_embeddings": False,
 83                  "embedding_separator": "\n",
 84                  "meta_fields_to_embed": [],
 85                  "trust_remote_code": False,
 86                  "revision": None,
 87                  "local_files_only": False,
 88                  "truncate_dim": None,
 89                  "model_kwargs": None,
 90                  "tokenizer_kwargs": None,
 91                  "encode_kwargs": None,
 92                  "config_kwargs": None,
 93                  "precision": "float32",
 94                  "backend": "torch",
 95              },
 96          }
 97  
 98      def test_to_dict_with_custom_init_parameters(self):
 99          component = SentenceTransformersDocumentEmbedder(
100              model="model",
101              device=ComponentDevice.from_str("cuda:0"),
102              token=Secret.from_env_var("ENV_VAR", strict=False),
103              prefix="prefix",
104              suffix="suffix",
105              batch_size=64,
106              progress_bar=False,
107              normalize_embeddings=True,
108              meta_fields_to_embed=["meta_field"],
109              embedding_separator=" - ",
110              trust_remote_code=True,
111              local_files_only=True,
112              truncate_dim=256,
113              model_kwargs={"torch_dtype": torch.float32},
114              tokenizer_kwargs={"model_max_length": 512},
115              config_kwargs={"use_memory_efficient_attention": True},
116              precision="int8",
117              encode_kwargs={"task": "clustering"},
118          )
119          data = component.to_dict()
120  
121          assert data == {
122              "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
123              "init_parameters": {
124                  "model": "model",
125                  "device": ComponentDevice.from_str("cuda:0").to_dict(),
126                  "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
127                  "prefix": "prefix",
128                  "suffix": "suffix",
129                  "batch_size": 64,
130                  "progress_bar": False,
131                  "normalize_embeddings": True,
132                  "embedding_separator": " - ",
133                  "trust_remote_code": True,
134                  "revision": None,
135                  "local_files_only": True,
136                  "meta_fields_to_embed": ["meta_field"],
137                  "truncate_dim": 256,
138                  "model_kwargs": {"torch_dtype": "torch.float32"},
139                  "tokenizer_kwargs": {"model_max_length": 512},
140                  "config_kwargs": {"use_memory_efficient_attention": True},
141                  "precision": "int8",
142                  "encode_kwargs": {"task": "clustering"},
143                  "backend": "torch",
144              },
145          }
146  
147      def test_from_dict(self):
148          init_parameters = {
149              "model": "model",
150              "device": ComponentDevice.from_str("cuda:0").to_dict(),
151              "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
152              "prefix": "prefix",
153              "suffix": "suffix",
154              "batch_size": 64,
155              "progress_bar": False,
156              "normalize_embeddings": True,
157              "embedding_separator": " - ",
158              "meta_fields_to_embed": ["meta_field"],
159              "trust_remote_code": True,
160              "revision": "v1.0",
161              "local_files_only": True,
162              "truncate_dim": 256,
163              "model_kwargs": {"torch_dtype": "torch.float32"},
164              "tokenizer_kwargs": {"model_max_length": 512},
165              "config_kwargs": {"use_memory_efficient_attention": True},
166              "precision": "int8",
167          }
168          component = SentenceTransformersDocumentEmbedder.from_dict(
169              {
170                  "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
171                  "init_parameters": init_parameters,
172              }
173          )
174          assert component.model == "model"
175          assert component.device == ComponentDevice.from_str("cuda:0")
176          assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
177          assert component.prefix == "prefix"
178          assert component.suffix == "suffix"
179          assert component.batch_size == 64
180          assert component.progress_bar is False
181          assert component.normalize_embeddings is True
182          assert component.embedding_separator == " - "
183          assert component.trust_remote_code
184          assert component.revision == "v1.0"
185          assert component.local_files_only
186          assert component.meta_fields_to_embed == ["meta_field"]
187          assert component.truncate_dim == 256
188          assert component.model_kwargs == {"torch_dtype": torch.float32}
189          assert component.tokenizer_kwargs == {"model_max_length": 512}
190          assert component.config_kwargs == {"use_memory_efficient_attention": True}
191          assert component.precision == "int8"
192  
193      def test_from_dict_no_default_parameters(self):
194          component = SentenceTransformersDocumentEmbedder.from_dict(
195              {
196                  "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
197                  "init_parameters": {},
198              }
199          )
200          assert component.model == "sentence-transformers/all-mpnet-base-v2"
201          assert component.device == ComponentDevice.resolve_device(None)
202          assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
203          assert component.prefix == ""
204          assert component.suffix == ""
205          assert component.batch_size == 32
206          assert component.progress_bar is True
207          assert component.normalize_embeddings is False
208          assert component.embedding_separator == "\n"
209          assert component.trust_remote_code is False
210          assert component.revision is None
211          assert component.local_files_only is False
212          assert component.meta_fields_to_embed == []
213          assert component.truncate_dim is None
214          assert component.precision == "float32"
215  
216      def test_from_dict_none_device(self):
217          init_parameters = {
218              "model": "model",
219              "device": None,
220              "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
221              "prefix": "prefix",
222              "suffix": "suffix",
223              "batch_size": 64,
224              "progress_bar": False,
225              "normalize_embeddings": True,
226              "embedding_separator": " - ",
227              "meta_fields_to_embed": ["meta_field"],
228              "trust_remote_code": True,
229              "local_files_only": False,
230              "truncate_dim": None,
231              "precision": "float32",
232          }
233          component = SentenceTransformersDocumentEmbedder.from_dict(
234              {
235                  "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
236                  "init_parameters": init_parameters,
237              }
238          )
239          assert component.model == "model"
240          assert component.device == ComponentDevice.resolve_device(None)
241          assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
242          assert component.prefix == "prefix"
243          assert component.suffix == "suffix"
244          assert component.batch_size == 64
245          assert component.progress_bar is False
246          assert component.normalize_embeddings is True
247          assert component.embedding_separator == " - "
248          assert component.trust_remote_code
249          assert component.revision is None
250          assert component.local_files_only is False
251          assert component.meta_fields_to_embed == ["meta_field"]
252          assert component.truncate_dim is None
253          assert component.precision == "float32"
254  
255      @patch(
256          "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
257      )
258      def test_warmup(self, mocked_factory):
259          embedder = SentenceTransformersDocumentEmbedder(
260              model="model",
261              token=None,
262              device=ComponentDevice.from_str("cpu"),
263              tokenizer_kwargs={"model_max_length": 512},
264              config_kwargs={"use_memory_efficient_attention": True},
265          )
266          mocked_factory.get_embedding_backend.assert_not_called()
267          embedder.warm_up()
268          embedder.embedding_backend.model.max_seq_length = 512
269          mocked_factory.get_embedding_backend.assert_called_once_with(
270              model="model",
271              device="cpu",
272              auth_token=None,
273              trust_remote_code=False,
274              revision=None,
275              local_files_only=False,
276              truncate_dim=None,
277              model_kwargs=None,
278              tokenizer_kwargs={"model_max_length": 512},
279              config_kwargs={"use_memory_efficient_attention": True},
280              backend="torch",
281          )
282  
283      @patch(
284          "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
285      )
286      def test_warmup_doesnt_reload(self, mocked_factory):
287          embedder = SentenceTransformersDocumentEmbedder(model="model")
288          mocked_factory.get_embedding_backend.assert_not_called()
289          embedder.warm_up()
290          embedder.warm_up()
291          mocked_factory.get_embedding_backend.assert_called_once()
292  
293      def test_run(self):
294          embedder = SentenceTransformersDocumentEmbedder(model="model")
295          embedder.embedding_backend = MagicMock()
296          embedder.embedding_backend.embed = lambda x, **_: [[random.random() for _ in range(16)] for _ in range(len(x))]
297  
298          documents = [Document(content=f"document number {i}") for i in range(5)]
299  
300          result = embedder.run(documents=documents)
301  
302          assert isinstance(result["documents"], list)
303          assert len(result["documents"]) == len(documents)
304          for doc, new_doc in zip(documents, result["documents"], strict=True):
305              assert new_doc is not doc
306              assert doc.embedding is None
307              assert isinstance(new_doc, Document)
308              assert isinstance(new_doc.embedding, list)
309              assert isinstance(new_doc.embedding[0], float)
310  
311      def test_run_wrong_input_format(self):
312          embedder = SentenceTransformersDocumentEmbedder(model="model")
313  
314          string_input = "text"
315          list_integers_input = [1, 2, 3]
316  
317          with pytest.raises(
318              TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
319          ):
320              embedder.run(documents=string_input)
321  
322          with pytest.raises(
323              TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
324          ):
325              embedder.run(documents=list_integers_input)
326  
327      def test_embed_metadata(self):
328          embedder = SentenceTransformersDocumentEmbedder(
329              model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
330          )
331          embedder.embedding_backend = MagicMock()
332          embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
333          documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
334          embedder.run(documents=documents)
335          embedder.embedding_backend.embed.assert_called_once_with(
336              [
337                  "meta_value 0\ndocument number 0",
338                  "meta_value 1\ndocument number 1",
339                  "meta_value 2\ndocument number 2",
340                  "meta_value 3\ndocument number 3",
341                  "meta_value 4\ndocument number 4",
342              ],
343              batch_size=32,
344              show_progress_bar=True,
345              normalize_embeddings=False,
346              precision="float32",
347          )
348  
349      def test_embed_encode_kwargs(self):
350          embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
351          embedder.embedding_backend = MagicMock()
352          embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
353          documents = [Document(content=f"document number {i}") for i in range(5)]
354          embedder.run(documents=documents)
355          embedder.embedding_backend.embed.assert_called_once_with(
356              ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
357              batch_size=32,
358              show_progress_bar=True,
359              normalize_embeddings=False,
360              precision="float32",
361              task="retrieval.passage",
362          )
363  
364      def test_prefix_suffix(self):
365          embedder = SentenceTransformersDocumentEmbedder(
366              model="model",
367              prefix="my_prefix ",
368              suffix=" my_suffix",
369              meta_fields_to_embed=["meta_field"],
370              embedding_separator="\n",
371          )
372          embedder.embedding_backend = MagicMock()
373          embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
374          documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
375          embedder.run(documents=documents)
376          embedder.embedding_backend.embed.assert_called_once_with(
377              [
378                  "my_prefix meta_value 0\ndocument number 0 my_suffix",
379                  "my_prefix meta_value 1\ndocument number 1 my_suffix",
380                  "my_prefix meta_value 2\ndocument number 2 my_suffix",
381                  "my_prefix meta_value 3\ndocument number 3 my_suffix",
382                  "my_prefix meta_value 4\ndocument number 4 my_suffix",
383              ],
384              batch_size=32,
385              show_progress_bar=True,
386              normalize_embeddings=False,
387              precision="float32",
388          )
389  
390      @patch(
391          "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
392      )
393      def test_model_onnx_backend(self, mocked_factory):
394          onnx_embedder = SentenceTransformersDocumentEmbedder(
395              model="sentence-transformers/all-MiniLM-L6-v2",
396              token=None,
397              device=ComponentDevice.from_str("cpu"),
398              # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
399              # a HF warning
400              model_kwargs={"file_name": "onnx/model.onnx"},
401              backend="onnx",
402          )
403          onnx_embedder.warm_up()
404  
405          mocked_factory.get_embedding_backend.assert_called_once_with(
406              model="sentence-transformers/all-MiniLM-L6-v2",
407              device="cpu",
408              auth_token=None,
409              trust_remote_code=False,
410              revision=None,
411              local_files_only=False,
412              truncate_dim=None,
413              model_kwargs={"file_name": "onnx/model.onnx"},
414              tokenizer_kwargs=None,
415              config_kwargs=None,
416              backend="onnx",
417          )
418  
419      @patch(
420          "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
421      )
422      def test_model_openvino_backend(self, mocked_factory):
423          openvino_embedder = SentenceTransformersDocumentEmbedder(
424              model="sentence-transformers/all-MiniLM-L6-v2",
425              token=None,
426              device=ComponentDevice.from_str("cpu"),
427              # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is
428              # to prevent a HF warning
429              model_kwargs={"file_name": "openvino/openvino_model.xml"},
430              backend="openvino",
431          )
432          openvino_embedder.warm_up()
433  
434          mocked_factory.get_embedding_backend.assert_called_once_with(
435              model="sentence-transformers/all-MiniLM-L6-v2",
436              device="cpu",
437              auth_token=None,
438              trust_remote_code=False,
439              revision=None,
440              local_files_only=False,
441              truncate_dim=None,
442              model_kwargs={"file_name": "openvino/openvino_model.xml"},
443              tokenizer_kwargs=None,
444              config_kwargs=None,
445              backend="openvino",
446          )
447  
448      @patch(
449          "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
450      )
451      @pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
452      def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
453          torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
454              model="sentence-transformers/all-MiniLM-L6-v2",
455              token=None,
456              device=ComponentDevice.from_str("cuda:0"),
457              model_kwargs=model_kwargs,
458          )
459          torch_dtype_embedder.warm_up()
460  
461          mocked_factory.get_embedding_backend.assert_called_once_with(
462              model="sentence-transformers/all-MiniLM-L6-v2",
463              device="cuda:0",
464              auth_token=None,
465              trust_remote_code=False,
466              revision=None,
467              local_files_only=False,
468              truncate_dim=None,
469              model_kwargs=model_kwargs,
470              tokenizer_kwargs=None,
471              config_kwargs=None,
472              backend="torch",
473          )