test_sentence_transformers_document_embedder.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import random
from unittest.mock import MagicMock, patch

import pytest
import torch

from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.utils import ComponentDevice, Secret


class TestSentenceTransformersDocumentEmbedder:
    """Unit tests for SentenceTransformersDocumentEmbedder.

    Covers constructor defaults and overrides, to_dict/from_dict round-tripping,
    warm-up behavior (with a mocked embedding-backend factory), run() input
    validation, and how prefix/suffix/meta fields are folded into the texts
    passed to the backend.
    """

    def test_init_default(self):
        """All constructor defaults are stored unchanged on the instance."""
        embedder = SentenceTransformersDocumentEmbedder(model="model")
        assert embedder.model == "model"
        assert embedder.device == ComponentDevice.resolve_device(None)
        assert embedder.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert embedder.prefix == ""
        assert embedder.suffix == ""
        assert embedder.batch_size == 32
        assert embedder.progress_bar is True
        assert embedder.normalize_embeddings is False
        assert embedder.meta_fields_to_embed == []
        assert embedder.embedding_separator == "\n"
        assert embedder.trust_remote_code is False
        assert embedder.revision is None
        assert embedder.local_files_only is False
        assert embedder.truncate_dim is None
        assert embedder.precision == "float32"

    def test_init_with_parameters(self):
        """Explicit constructor arguments are stored unchanged on the instance."""
        embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_token("fake-api-token"),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
            meta_fields_to_embed=["test_field"],
            embedding_separator=" | ",
            trust_remote_code=True,
            revision="v1.0",
            local_files_only=True,
            truncate_dim=256,
            precision="int8",
        )
        assert embedder.model == "model"
        assert embedder.device == ComponentDevice.from_str("cuda:0")
        assert embedder.token == Secret.from_token("fake-api-token")
        assert embedder.prefix == "prefix"
        assert embedder.suffix == "suffix"
        assert embedder.batch_size == 64
        assert embedder.progress_bar is False
        assert embedder.normalize_embeddings is True
        assert embedder.meta_fields_to_embed == ["test_field"]
        assert embedder.embedding_separator == " | "
        assert embedder.trust_remote_code
        assert embedder.revision == "v1.0"
        assert embedder.local_files_only
        assert embedder.truncate_dim == 256
        assert embedder.precision == "int8"

    def test_to_dict(self):
        """to_dict() serializes a default-configured component fully and exactly."""
        component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cpu").to_dict(),
                "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
                "prefix": "",
                "suffix": "",
                "batch_size": 32,
                "progress_bar": True,
                "normalize_embeddings": False,
                "embedding_separator": "\n",
                "meta_fields_to_embed": [],
                "trust_remote_code": False,
                "revision": None,
                "local_files_only": False,
                "truncate_dim": None,
                "model_kwargs": None,
                "tokenizer_kwargs": None,
                "encode_kwargs": None,
                "config_kwargs": None,
                "precision": "float32",
                "backend": "torch",
            },
        }

    def test_to_dict_with_custom_init_parameters(self):
        """to_dict() serializes every overridden init parameter, including torch dtypes as strings."""
        component = SentenceTransformersDocumentEmbedder(
            model="model",
            device=ComponentDevice.from_str("cuda:0"),
            token=Secret.from_env_var("ENV_VAR", strict=False),
            prefix="prefix",
            suffix="suffix",
            batch_size=64,
            progress_bar=False,
            normalize_embeddings=True,
            meta_fields_to_embed=["meta_field"],
            embedding_separator=" - ",
            trust_remote_code=True,
            local_files_only=True,
            truncate_dim=256,
            model_kwargs={"torch_dtype": torch.float32},
            tokenizer_kwargs={"model_max_length": 512},
            config_kwargs={"use_memory_efficient_attention": True},
            precision="int8",
            encode_kwargs={"task": "clustering"},
        )
        data = component.to_dict()

        assert data == {
            "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
            "init_parameters": {
                "model": "model",
                "device": ComponentDevice.from_str("cuda:0").to_dict(),
                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                "prefix": "prefix",
                "suffix": "suffix",
                "batch_size": 64,
                "progress_bar": False,
                "normalize_embeddings": True,
                "embedding_separator": " - ",
                "trust_remote_code": True,
                "revision": None,
                "local_files_only": True,
                "meta_fields_to_embed": ["meta_field"],
                "truncate_dim": 256,
                # torch.float32 is serialized to its string form
                "model_kwargs": {"torch_dtype": "torch.float32"},
                "tokenizer_kwargs": {"model_max_length": 512},
                "config_kwargs": {"use_memory_efficient_attention": True},
                "precision": "int8",
                "encode_kwargs": {"task": "clustering"},
                "backend": "torch",
            },
        }

    def test_from_dict(self):
        """from_dict() restores every field, deserializing dtype strings back to torch dtypes."""
        init_parameters = {
            "model": "model",
            "device": ComponentDevice.from_str("cuda:0").to_dict(),
            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "embedding_separator": " - ",
            "meta_fields_to_embed": ["meta_field"],
            "trust_remote_code": True,
            "revision": "v1.0",
            "local_files_only": True,
            "truncate_dim": 256,
            "model_kwargs": {"torch_dtype": "torch.float32"},
            "tokenizer_kwargs": {"model_max_length": 512},
            "config_kwargs": {"use_memory_efficient_attention": True},
            "precision": "int8",
        }
        component = SentenceTransformersDocumentEmbedder.from_dict(
            {
                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                "init_parameters": init_parameters,
            }
        )
        assert component.model == "model"
        assert component.device == ComponentDevice.from_str("cuda:0")
        assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert component.prefix == "prefix"
        assert component.suffix == "suffix"
        assert component.batch_size == 64
        assert component.progress_bar is False
        assert component.normalize_embeddings is True
        assert component.embedding_separator == " - "
        assert component.trust_remote_code
        assert component.revision == "v1.0"
        assert component.local_files_only
        assert component.meta_fields_to_embed == ["meta_field"]
        assert component.truncate_dim == 256
        # the string "torch.float32" round-trips back to the real dtype object
        assert component.model_kwargs == {"torch_dtype": torch.float32}
        assert component.tokenizer_kwargs == {"model_max_length": 512}
        assert component.config_kwargs == {"use_memory_efficient_attention": True}
        assert component.precision == "int8"

    def test_from_dict_no_default_parameters(self):
        """from_dict() with empty init_parameters falls back to the constructor defaults."""
        component = SentenceTransformersDocumentEmbedder.from_dict(
            {
                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                "init_parameters": {},
            }
        )
        assert component.model == "sentence-transformers/all-mpnet-base-v2"
        assert component.device == ComponentDevice.resolve_device(None)
        assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
        assert component.prefix == ""
        assert component.suffix == ""
        assert component.batch_size == 32
        assert component.progress_bar is True
        assert component.normalize_embeddings is False
        assert component.embedding_separator == "\n"
        assert component.trust_remote_code is False
        assert component.revision is None
        assert component.local_files_only is False
        assert component.meta_fields_to_embed == []
        assert component.truncate_dim is None
        assert component.precision == "float32"

    def test_from_dict_none_device(self):
        """from_dict() with device=None resolves to the default device."""
        init_parameters = {
            "model": "model",
            "device": None,
            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
            "prefix": "prefix",
            "suffix": "suffix",
            "batch_size": 64,
            "progress_bar": False,
            "normalize_embeddings": True,
            "embedding_separator": " - ",
            "meta_fields_to_embed": ["meta_field"],
            "trust_remote_code": True,
            "local_files_only": False,
            "truncate_dim": None,
            "precision": "float32",
        }
        component = SentenceTransformersDocumentEmbedder.from_dict(
            {
                "type": "haystack.components.embedders.sentence_transformers_document_embedder.SentenceTransformersDocumentEmbedder",  # noqa: E501
                "init_parameters": init_parameters,
            }
        )
        assert component.model == "model"
        assert component.device == ComponentDevice.resolve_device(None)
        assert component.token == Secret.from_env_var("ENV_VAR", strict=False)
        assert component.prefix == "prefix"
        assert component.suffix == "suffix"
        assert component.batch_size == 64
        assert component.progress_bar is False
        assert component.normalize_embeddings is True
        assert component.embedding_separator == " - "
        assert component.trust_remote_code
        assert component.revision is None
        assert component.local_files_only is False
        assert component.meta_fields_to_embed == ["meta_field"]
        assert component.truncate_dim is None
        assert component.precision == "float32"

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup(self, mocked_factory):
        """warm_up() builds the backend exactly once with the configured parameters."""
        embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            token=None,
            device=ComponentDevice.from_str("cpu"),
            tokenizer_kwargs={"model_max_length": 512},
            config_kwargs={"use_memory_efficient_attention": True},
        )
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        # BUGFIX: this line used to ASSIGN (`... = 512`) instead of ASSERT, so it
        # verified nothing. warm_up() is expected to propagate
        # tokenizer_kwargs["model_max_length"] onto the backend model's
        # max_seq_length; assert that it actually did.
        assert embedder.embedding_backend.model.max_seq_length == 512
        mocked_factory.get_embedding_backend.assert_called_once_with(
            model="model",
            device="cpu",
            auth_token=None,
            trust_remote_code=False,
            revision=None,
            local_files_only=False,
            truncate_dim=None,
            model_kwargs=None,
            tokenizer_kwargs={"model_max_length": 512},
            config_kwargs={"use_memory_efficient_attention": True},
            backend="torch",
        )

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_warmup_doesnt_reload(self, mocked_factory):
        """A second warm_up() call must not rebuild the embedding backend."""
        embedder = SentenceTransformersDocumentEmbedder(model="model")
        mocked_factory.get_embedding_backend.assert_not_called()
        embedder.warm_up()
        embedder.warm_up()
        mocked_factory.get_embedding_backend.assert_called_once()

    def test_run(self):
        """run() returns new Document copies with float-list embeddings; inputs stay untouched."""
        embedder = SentenceTransformersDocumentEmbedder(model="model")
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed = lambda x, **_: [[random.random() for _ in range(16)] for _ in range(len(x))]

        documents = [Document(content=f"document number {i}") for i in range(5)]

        result = embedder.run(documents=documents)

        assert isinstance(result["documents"], list)
        assert len(result["documents"]) == len(documents)
        for doc, new_doc in zip(documents, result["documents"], strict=True):
            assert new_doc is not doc
            assert doc.embedding is None
            assert isinstance(new_doc, Document)
            assert isinstance(new_doc.embedding, list)
            assert isinstance(new_doc.embedding[0], float)

    def test_run_wrong_input_format(self):
        """run() rejects non-list and non-Document inputs with a TypeError."""
        embedder = SentenceTransformersDocumentEmbedder(model="model")

        string_input = "text"
        list_integers_input = [1, 2, 3]

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            embedder.run(documents=string_input)

        with pytest.raises(
            TypeError, match="SentenceTransformersDocumentEmbedder expects a list of Documents as input"
        ):
            embedder.run(documents=list_integers_input)

    def test_embed_metadata(self):
        """Configured meta fields are prepended to the content, joined by the separator."""
        embedder = SentenceTransformersDocumentEmbedder(
            model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
        )
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
        documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
        embedder.run(documents=documents)
        embedder.embedding_backend.embed.assert_called_once_with(
            [
                "meta_value 0\ndocument number 0",
                "meta_value 1\ndocument number 1",
                "meta_value 2\ndocument number 2",
                "meta_value 3\ndocument number 3",
                "meta_value 4\ndocument number 4",
            ],
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
            precision="float32",
        )

    def test_embed_encode_kwargs(self):
        """encode_kwargs are forwarded verbatim to the backend's embed() call."""
        embedder = SentenceTransformersDocumentEmbedder(model="model", encode_kwargs={"task": "retrieval.passage"})
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
        documents = [Document(content=f"document number {i}") for i in range(5)]
        embedder.run(documents=documents)
        embedder.embedding_backend.embed.assert_called_once_with(
            ["document number 0", "document number 1", "document number 2", "document number 3", "document number 4"],
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
            precision="float32",
            task="retrieval.passage",
        )

    def test_prefix_suffix(self):
        """prefix/suffix wrap the meta-augmented text passed to the backend."""
        embedder = SentenceTransformersDocumentEmbedder(
            model="model",
            prefix="my_prefix ",
            suffix=" my_suffix",
            meta_fields_to_embed=["meta_field"],
            embedding_separator="\n",
        )
        embedder.embedding_backend = MagicMock()
        embedder.embedding_backend.embed.return_value = [[random.random() for _ in range(16)] for _ in range(5)]
        documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
        embedder.run(documents=documents)
        embedder.embedding_backend.embed.assert_called_once_with(
            [
                "my_prefix meta_value 0\ndocument number 0 my_suffix",
                "my_prefix meta_value 1\ndocument number 1 my_suffix",
                "my_prefix meta_value 2\ndocument number 2 my_suffix",
                "my_prefix meta_value 3\ndocument number 3 my_suffix",
                "my_prefix meta_value 4\ndocument number 4 my_suffix",
            ],
            batch_size=32,
            show_progress_bar=True,
            normalize_embeddings=False,
            precision="float32",
        )

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_model_onnx_backend(self, mocked_factory):
        """backend="onnx" and its model_kwargs reach the backend factory unchanged."""
        onnx_embedder = SentenceTransformersDocumentEmbedder(
            model="sentence-transformers/all-MiniLM-L6-v2",
            token=None,
            device=ComponentDevice.from_str("cpu"),
            # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to prevent
            # a HF warning
            model_kwargs={"file_name": "onnx/model.onnx"},
            backend="onnx",
        )
        onnx_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once_with(
            model="sentence-transformers/all-MiniLM-L6-v2",
            device="cpu",
            auth_token=None,
            trust_remote_code=False,
            revision=None,
            local_files_only=False,
            truncate_dim=None,
            model_kwargs={"file_name": "onnx/model.onnx"},
            tokenizer_kwargs=None,
            config_kwargs=None,
            backend="onnx",
        )

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    def test_model_openvino_backend(self, mocked_factory):
        """backend="openvino" and its model_kwargs reach the backend factory unchanged."""
        openvino_embedder = SentenceTransformersDocumentEmbedder(
            model="sentence-transformers/all-MiniLM-L6-v2",
            token=None,
            device=ComponentDevice.from_str("cpu"),
            # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this is
            # to prevent a HF warning
            model_kwargs={"file_name": "openvino/openvino_model.xml"},
            backend="openvino",
        )
        openvino_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once_with(
            model="sentence-transformers/all-MiniLM-L6-v2",
            device="cpu",
            auth_token=None,
            trust_remote_code=False,
            revision=None,
            local_files_only=False,
            truncate_dim=None,
            model_kwargs={"file_name": "openvino/openvino_model.xml"},
            tokenizer_kwargs=None,
            config_kwargs=None,
            backend="openvino",
        )

    @patch(
        "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
    )
    @pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}])
    def test_dtype_on_gpu(self, mocked_factory, model_kwargs):
        """torch_dtype strings in model_kwargs are forwarded unchanged for GPU devices."""
        torch_dtype_embedder = SentenceTransformersDocumentEmbedder(
            model="sentence-transformers/all-MiniLM-L6-v2",
            token=None,
            device=ComponentDevice.from_str("cuda:0"),
            model_kwargs=model_kwargs,
        )
        torch_dtype_embedder.warm_up()

        mocked_factory.get_embedding_backend.assert_called_once_with(
            model="sentence-transformers/all-MiniLM-L6-v2",
            device="cuda:0",
            auth_token=None,
            trust_remote_code=False,
            revision=None,
            local_files_only=False,
            truncate_dim=None,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=None,
            config_kwargs=None,
            backend="torch",
        )