# sentence_transformers_sparse_backend.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, Literal

from haystack.dataclasses.sparse_embedding import SparseEmbedding
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret

with LazyImport(message="Run 'pip install \"sentence-transformers>=5.0.0\"'") as sentence_transformers_import:
    from sentence_transformers import SparseEncoder


class _SentenceTransformersSparseEmbeddingBackendFactory:
    """
    Creates sparse embedding backends and caches them, so that at most one backend
    exists per unique combination of construction parameters.
    """

    # Cache of already-built backends, keyed by a JSON serialization of the parameters.
    _instances: dict[str, "_SentenceTransformersSparseEncoderEmbeddingBackend"] = {}

    @staticmethod
    def get_embedding_backend(
        *,
        model: str,
        device: str | None = None,
        auth_token: Secret | None = None,
        trust_remote_code: bool = False,
        revision: str | None = None,
        local_files_only: bool = False,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ) -> "_SentenceTransformersSparseEncoderEmbeddingBackend":
        """
        Return a sparse embedding backend for the given parameters.

        A cached instance is reused when one was already created with identical
        parameters; otherwise a new backend is built and stored in the cache.
        """
        key_fields = {
            "model": model,
            "device": device,
            "auth_token": auth_token,
            "trust_remote_code": trust_remote_code,
            "revision": revision,
            "local_files_only": local_files_only,
            "model_kwargs": model_kwargs,
            "tokenizer_kwargs": tokenizer_kwargs,
            "config_kwargs": config_kwargs,
            "backend": backend,
        }
        # default=str covers values that are not natively JSON-serializable (e.g. Secret).
        cache_key = json.dumps(key_fields, sort_keys=True, default=str)

        instances = _SentenceTransformersSparseeEmbeddingBackendFactory_instances = (
            _SentenceTransformersSparseEmbeddingBackendFactory._instances
        )
        if cache_key not in instances:
            instances[cache_key] = _SentenceTransformersSparseEncoderEmbeddingBackend(
                model=model,
                device=device,
                auth_token=auth_token,
                trust_remote_code=trust_remote_code,
                revision=revision,
                local_files_only=local_files_only,
                model_kwargs=model_kwargs,
                tokenizer_kwargs=tokenizer_kwargs,
                config_kwargs=config_kwargs,
                backend=backend,
            )
        return instances[cache_key]


class _SentenceTransformersSparseEncoderEmbeddingBackend:
    """
    Wraps a sentence-transformers SparseEncoder and converts its sparse tensor
    output into Haystack SparseEmbedding objects.
    """

    def __init__(
        self,
        *,
        model: str,
        device: str | None = None,
        auth_token: Secret | None = None,
        trust_remote_code: bool = False,
        revision: str | None = None,
        local_files_only: bool = False,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ) -> None:
        # Fail early with an actionable message if sentence-transformers is not installed.
        sentence_transformers_import.check()

        self.model = SparseEncoder(
            model_name_or_path=model,
            device=device,
            token=auth_token.resolve_value() if auth_token else None,
            trust_remote_code=trust_remote_code,
            revision=revision,
            local_files_only=local_files_only,
            model_kwargs=model_kwargs,
            tokenizer_kwargs=tokenizer_kwargs,
            config_kwargs=config_kwargs,
            backend=backend,
        )

    def embed(self, *, data: list[str], **kwargs: Any) -> list[SparseEmbedding]:
        """
        Encode the given texts and return one SparseEmbedding per input string.

        Extra keyword arguments are forwarded to ``SparseEncoder.encode``.
        """
        raw_tensors = self.model.encode(
            data,
            convert_to_tensor=False,  # output is a list of individual tensors
            convert_to_sparse_tensor=True,
            **kwargs,
        )

        results: list[SparseEmbedding] = []
        for raw in raw_tensors:
            coalesced = raw.coalesce()
            # For a 1-D sparse COO tensor, indices() is shaped (1, nnz);
            # row 0 holds the column positions of the non-zero values.
            results.append(
                SparseEmbedding(
                    indices=coalesced.indices()[0].tolist(),
                    values=coalesced.values().tolist(),
                )
            )
        return results