test_sentence_transformers_sparse_document_embedder.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 from unittest.mock import MagicMock, patch 6 7 import pytest 8 import torch 9 10 from haystack import Document 11 from haystack.components.embedders.sentence_transformers_sparse_document_embedder import ( 12 SentenceTransformersSparseDocumentEmbedder, 13 ) 14 from haystack.dataclasses.sparse_embedding import SparseEmbedding 15 from haystack.utils import ComponentDevice, Secret 16 17 TYPE_NAME = ( 18 "haystack.components.embedders.sentence_transformers_sparse_document_embedder." 19 "SentenceTransformersSparseDocumentEmbedder" 20 ) 21 22 23 class TestSentenceTransformersDocumentEmbedder: 24 def test_init_default(self): 25 embedder = SentenceTransformersSparseDocumentEmbedder(model="model") 26 assert embedder.model == "model" 27 assert embedder.device == ComponentDevice.resolve_device(None) 28 assert embedder.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) 29 assert embedder.prefix == "" 30 assert embedder.suffix == "" 31 assert embedder.batch_size == 32 32 assert embedder.progress_bar is True 33 assert embedder.meta_fields_to_embed == [] 34 assert embedder.embedding_separator == "\n" 35 assert embedder.trust_remote_code is False 36 assert embedder.revision is None 37 assert embedder.local_files_only is False 38 39 def test_init_with_parameters(self): 40 embedder = SentenceTransformersSparseDocumentEmbedder( 41 model="model", 42 device=ComponentDevice.from_str("cuda:0"), 43 token=Secret.from_token("fake-api-token"), 44 prefix="prefix", 45 suffix="suffix", 46 batch_size=64, 47 progress_bar=False, 48 meta_fields_to_embed=["test_field"], 49 embedding_separator=" | ", 50 trust_remote_code=True, 51 revision="v1.0", 52 local_files_only=True, 53 ) 54 assert embedder.model == "model" 55 assert embedder.device == ComponentDevice.from_str("cuda:0") 56 assert embedder.token == Secret.from_token("fake-api-token") 57 assert embedder.prefix == "prefix" 58 assert embedder.suffix == "suffix" 59 assert embedder.batch_size == 64 60 assert embedder.progress_bar is False 61 assert embedder.meta_fields_to_embed == ["test_field"] 62 assert embedder.embedding_separator == " | " 63 assert embedder.trust_remote_code 64 assert embedder.revision == "v1.0" 65 assert embedder.local_files_only 66 67 def test_to_dict(self): 68 component = SentenceTransformersSparseDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu")) 69 data = component.to_dict() 70 assert data == { 71 "type": TYPE_NAME, 72 "init_parameters": { 73 "model": "model", 74 "device": ComponentDevice.from_str("cpu").to_dict(), 75 "token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"}, 76 "prefix": "", 77 "suffix": "", 78 "batch_size": 32, 79 "progress_bar": True, 80 "embedding_separator": "\n", 81 "meta_fields_to_embed": [], 82 "trust_remote_code": False, 83 "revision": None, 84 "local_files_only": False, 85 "model_kwargs": None, 86 "tokenizer_kwargs": None, 87 "config_kwargs": None, 88 "backend": "torch", 89 }, 90 } 91 92 def test_to_dict_with_custom_init_parameters(self): 93 component = SentenceTransformersSparseDocumentEmbedder( 94 model="model", 95 device=ComponentDevice.from_str("cuda:0"), 96 token=Secret.from_env_var("ENV_VAR", strict=False), 97 prefix="prefix", 98 suffix="suffix", 99 batch_size=64, 100 progress_bar=False, 101 meta_fields_to_embed=["meta_field"], 102 embedding_separator=" - ", 103 trust_remote_code=True, 104 local_files_only=True, 105 model_kwargs={"torch_dtype": torch.float32}, 106 tokenizer_kwargs={"model_max_length": 512}, 107 config_kwargs={"use_memory_efficient_attention": True}, 108 ) 109 data = component.to_dict() 110 111 assert data == { 112 "type": TYPE_NAME, 113 "init_parameters": { 114 "model": "model", 115 "device": ComponentDevice.from_str("cuda:0").to_dict(), 116 "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, 117 "prefix": "prefix", 118 "suffix": "suffix", 119 "batch_size": 64, 120 "progress_bar": False, 121 "embedding_separator": " - ", 122 "trust_remote_code": True, 123 "revision": None, 124 "local_files_only": True, 125 "meta_fields_to_embed": ["meta_field"], 126 "model_kwargs": {"torch_dtype": "torch.float32"}, 127 "tokenizer_kwargs": {"model_max_length": 512}, 128 "config_kwargs": {"use_memory_efficient_attention": True}, 129 "backend": "torch", 130 }, 131 } 132 133 def test_from_dict(self): 134 init_parameters = { 135 "model": "model", 136 "device": ComponentDevice.from_str("cuda:0").to_dict(), 137 "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, 138 "prefix": "prefix", 139 "suffix": "suffix", 140 "batch_size": 64, 141 "progress_bar": False, 142 "embedding_separator": " - ", 143 "meta_fields_to_embed": ["meta_field"], 144 "trust_remote_code": True, 145 "revision": "v1.0", 146 "local_files_only": True, 147 "model_kwargs": {"torch_dtype": "torch.float32"}, 148 "tokenizer_kwargs": {"model_max_length": 512}, 149 "config_kwargs": {"use_memory_efficient_attention": True}, 150 } 151 component = SentenceTransformersSparseDocumentEmbedder.from_dict( 152 {"type": TYPE_NAME, "init_parameters": init_parameters} 153 ) 154 assert component.model == "model" 155 assert component.device == ComponentDevice.from_str("cuda:0") 156 assert component.token == Secret.from_env_var("ENV_VAR", strict=False) 157 assert component.prefix == "prefix" 158 assert component.suffix == "suffix" 159 assert component.batch_size == 64 160 assert component.progress_bar is False 161 assert component.embedding_separator == " - " 162 assert component.trust_remote_code 163 assert component.revision == "v1.0" 164 assert component.local_files_only 165 assert component.meta_fields_to_embed == ["meta_field"] 166 assert component.model_kwargs == {"torch_dtype": torch.float32} 167 assert component.tokenizer_kwargs == {"model_max_length": 512} 168 assert component.config_kwargs == {"use_memory_efficient_attention": True} 169 170 def test_from_dict_no_default_parameters(self): 171 component = SentenceTransformersSparseDocumentEmbedder.from_dict({"type": TYPE_NAME, "init_parameters": {}}) 172 assert component.model == "prithivida/Splade_PP_en_v2" 173 assert component.device == ComponentDevice.resolve_device(None) 174 assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False) 175 assert component.prefix == "" 176 assert component.suffix == "" 177 assert component.batch_size == 32 178 assert component.progress_bar is True 179 assert component.embedding_separator == "\n" 180 assert component.trust_remote_code is False 181 assert component.revision is None 182 assert component.local_files_only is False 183 assert component.meta_fields_to_embed == [] 184 185 def test_from_dict_none_device(self): 186 init_parameters = { 187 "model": "model", 188 "device": None, 189 "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, 190 "prefix": "prefix", 191 "suffix": "suffix", 192 "batch_size": 64, 193 "progress_bar": False, 194 "embedding_separator": " - ", 195 "meta_fields_to_embed": ["meta_field"], 196 "trust_remote_code": True, 197 "local_files_only": False, 198 } 199 component = SentenceTransformersSparseDocumentEmbedder.from_dict( 200 {"type": TYPE_NAME, "init_parameters": init_parameters} 201 ) 202 assert component.model == "model" 203 assert component.device == ComponentDevice.resolve_device(None) 204 assert component.token == Secret.from_env_var("ENV_VAR", strict=False) 205 assert component.prefix == "prefix" 206 assert component.suffix == "suffix" 207 assert component.batch_size == 64 208 assert component.progress_bar is False 209 assert component.embedding_separator == " - " 210 assert component.trust_remote_code 211 assert component.revision is None 212 assert component.local_files_only is False 213 assert component.meta_fields_to_embed == ["meta_field"] 214 215 @patch( 216 "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory" 217 ) 218 def test_warmup(self, mocked_factory): 219 embedder = SentenceTransformersSparseDocumentEmbedder( 220 model="model", 221 token=None, 222 device=ComponentDevice.from_str("cpu"), 223 tokenizer_kwargs={"model_max_length": 512}, 224 config_kwargs={"use_memory_efficient_attention": True}, 225 ) 226 mocked_factory.get_embedding_backend.assert_not_called() 227 embedder.warm_up() 228 embedder.embedding_backend.model.max_seq_length = 512 229 mocked_factory.get_embedding_backend.assert_called_once_with( 230 model="model", 231 device="cpu", 232 auth_token=None, 233 trust_remote_code=False, 234 revision=None, 235 local_files_only=False, 236 model_kwargs=None, 237 tokenizer_kwargs={"model_max_length": 512}, 238 config_kwargs={"use_memory_efficient_attention": True}, 239 backend="torch", 240 ) 241 242 @patch( 243 "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory" 244 ) 245 def test_warmup_doesnt_reload(self, mocked_factory): 246 embedder = SentenceTransformersSparseDocumentEmbedder(model="model") 247 mocked_factory.get_embedding_backend.assert_not_called() 248 embedder.warm_up() 249 embedder.warm_up() 250 mocked_factory.get_embedding_backend.assert_called_once() 251 252 def test_run(self): 253 embedder = SentenceTransformersSparseDocumentEmbedder(model="model") 254 embedder.embedding_backend = MagicMock() 255 256 def fake_embed(data, **kwargs): 257 return [SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(len(data))] 258 259 embedder.embedding_backend.embed = fake_embed 260 261 documents = [Document(content=f"document number {i}") for i in range(5)] 262 263 result = embedder.run(documents=documents) 264 265 assert isinstance(result["documents"], list) 266 assert len(result["documents"]) == len(documents) 267 for doc in result["documents"]: 268 assert isinstance(doc, Document) 269 assert isinstance(doc.sparse_embedding, SparseEmbedding) 270 assert isinstance(doc.sparse_embedding.indices[0], int) 271 assert isinstance(doc.sparse_embedding.values[0], float) 272 273 def test_run_wrong_input_format(self): 274 embedder = SentenceTransformersSparseDocumentEmbedder(model="model") 275 276 string_input = "text" 277 list_integers_input = [1, 2, 3] 278 279 with pytest.raises( 280 TypeError, match="SentenceTransformersSparseDocumentEmbedder expects a list of Documents as input" 281 ): 282 embedder.run(documents=string_input) 283 284 with pytest.raises( 285 TypeError, match="SentenceTransformersSparseDocumentEmbedder expects a list of Documents as input" 286 ): 287 embedder.run(documents=list_integers_input) 288 289 def test_embed_metadata(self): 290 embedder = SentenceTransformersSparseDocumentEmbedder( 291 model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n" 292 ) 293 embedder.embedding_backend = MagicMock() 294 embedder.embedding_backend.embed.return_value = [ 295 SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(5) 296 ] 297 documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] 298 embedder.run(documents=documents) 299 embedder.embedding_backend.embed.assert_called_once_with( 300 data=[ 301 "meta_value 0\ndocument number 0", 302 "meta_value 1\ndocument number 1", 303 "meta_value 2\ndocument number 2", 304 "meta_value 3\ndocument number 3", 305 "meta_value 4\ndocument number 4", 306 ], 307 batch_size=32, 308 show_progress_bar=True, 309 ) 310 311 def test_prefix_suffix(self): 312 embedder = SentenceTransformersSparseDocumentEmbedder( 313 model="model", 314 prefix="my_prefix ", 315 suffix=" my_suffix", 316 meta_fields_to_embed=["meta_field"], 317 embedding_separator="\n", 318 ) 319 embedder.embedding_backend = MagicMock() 320 embedder.embedding_backend.embed.return_value = [ 321 SparseEmbedding(indices=[0, 2, 5], values=[0.1, 0.2, 0.3]) for _ in range(5) 322 ] 323 documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] 324 embedder.run(documents=documents) 325 embedder.embedding_backend.embed.assert_called_once_with( 326 data=[ 327 "my_prefix meta_value 0\ndocument number 0 my_suffix", 328 "my_prefix meta_value 1\ndocument number 1 my_suffix", 329 "my_prefix meta_value 2\ndocument number 2 my_suffix", 330 "my_prefix meta_value 3\ndocument number 3 my_suffix", 331 "my_prefix meta_value 4\ndocument number 4 my_suffix", 332 ], 333 batch_size=32, 334 show_progress_bar=True, 335 ) 336 337 @patch( 338 "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory" 339 ) 340 def test_model_onnx_backend(self, mocked_factory): 341 onnx_embedder = SentenceTransformersSparseDocumentEmbedder( 342 model="prithivida/Splade_PP_en_v2", 343 token=None, 344 device=ComponentDevice.from_str("cpu"), 345 model_kwargs={ 346 "file_name": "onnx/model.onnx" 347 }, # setting the path isn't necessary if the repo contains a "onnx/model.onnx" file but this is to 348 # prevent a HF warning 349 backend="onnx", 350 ) 351 onnx_embedder.warm_up() 352 353 mocked_factory.get_embedding_backend.assert_called_once_with( 354 model="prithivida/Splade_PP_en_v2", 355 device="cpu", 356 auth_token=None, 357 trust_remote_code=False, 358 revision=None, 359 local_files_only=False, 360 model_kwargs={"file_name": "onnx/model.onnx"}, 361 tokenizer_kwargs=None, 362 config_kwargs=None, 363 backend="onnx", 364 ) 365 366 @patch( 367 "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory" 368 ) 369 def test_model_openvino_backend(self, mocked_factory): 370 openvino_embedder = SentenceTransformersSparseDocumentEmbedder( 371 model="prithivida/Splade_PP_en_v2", 372 token=None, 373 device=ComponentDevice.from_str("cpu"), 374 model_kwargs={ 375 "file_name": "openvino/openvino_model.xml" 376 }, # setting the path isn't necessary if the repo contains a "openvino/openvino_model.xml" file but this 377 # is to prevent a HF warning 378 backend="openvino", 379 ) 380 openvino_embedder.warm_up() 381 382 mocked_factory.get_embedding_backend.assert_called_once_with( 383 model="prithivida/Splade_PP_en_v2", 384 device="cpu", 385 auth_token=None, 386 trust_remote_code=False, 387 revision=None, 388 local_files_only=False, 389 model_kwargs={"file_name": "openvino/openvino_model.xml"}, 390 tokenizer_kwargs=None, 391 config_kwargs=None, 392 backend="openvino", 393 ) 394 395 @patch( 396 "haystack.components.embedders.sentence_transformers_sparse_document_embedder._SentenceTransformersSparseEmbeddingBackendFactory" 397 ) 398 @pytest.mark.parametrize("model_kwargs", [{"torch_dtype": "bfloat16"}, {"torch_dtype": "float16"}]) 399 def test_dtype_on_gpu(self, mocked_factory, model_kwargs): 400 torch_dtype_embedder = SentenceTransformersSparseDocumentEmbedder( 401 model="prithivida/Splade_PP_en_v2", 402 token=None, 403 device=ComponentDevice.from_str("cuda:0"), 404 model_kwargs=model_kwargs, 405 ) 406 torch_dtype_embedder.warm_up() 407 408 mocked_factory.get_embedding_backend.assert_called_once_with( 409 model="prithivida/Splade_PP_en_v2", 410 device="cuda:0", 411 auth_token=None, 412 trust_remote_code=False, 413 revision=None, 414 local_files_only=False, 415 model_kwargs=model_kwargs, 416 tokenizer_kwargs=None, 417 config_kwargs=None, 418 backend="torch", 419 ) 420 421 @pytest.mark.integration 422 @pytest.mark.slow 423 @pytest.mark.flaky(reruns=3, reruns_delay=10) 424 def test_live_run_sparse_document_embedder(self, del_hf_env_vars): 425 docs = [ 426 Document(content="I love cheese", meta={"topic": "Cuisine"}), 427 Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), 428 ] 429 430 embedder = SentenceTransformersSparseDocumentEmbedder( 431 model="sparse-encoder-testing/splade-bert-tiny-nq", 432 meta_fields_to_embed=["topic"], 433 embedding_separator=" | ", 434 device=ComponentDevice.from_str("cpu"), 435 ) 436 result = embedder.run(documents=docs) 437 documents_with_embeddings = result["documents"] 438 439 assert isinstance(documents_with_embeddings, list) 440 assert len(documents_with_embeddings) == len(docs) 441 for doc in documents_with_embeddings: 442 assert isinstance(doc, Document) 443 assert hasattr(doc, "sparse_embedding") 444 assert isinstance(doc.sparse_embedding, SparseEmbedding) 445 assert isinstance(doc.sparse_embedding.indices, list) 446 assert isinstance(doc.sparse_embedding.values, list) 447 assert len(doc.sparse_embedding.indices) == len(doc.sparse_embedding.values) 448 # Expect at least one non-zero entry 449 assert len(doc.sparse_embedding.indices) > 0