# haystack/components/embedders/backends/sentence_transformers_sparse_backend.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  from typing import Any, Literal
  7  
  8  from haystack.dataclasses.sparse_embedding import SparseEmbedding
  9  from haystack.lazy_imports import LazyImport
 10  from haystack.utils.auth import Secret
 11  
 12  with LazyImport(message="Run 'pip install \"sentence-transformers>=5.0.0\"'") as sentence_transformers_import:
 13      from sentence_transformers import SparseEncoder
 14  
 15  
 16  class _SentenceTransformersSparseEmbeddingBackendFactory:
 17      """
 18      Factory class to create instances of Sentence Transformers embedding backends.
 19      """
 20  
 21      _instances: dict[str, "_SentenceTransformersSparseEncoderEmbeddingBackend"] = {}
 22  
 23      @staticmethod
 24      def get_embedding_backend(
 25          *,
 26          model: str,
 27          device: str | None = None,
 28          auth_token: Secret | None = None,
 29          trust_remote_code: bool = False,
 30          revision: str | None = None,
 31          local_files_only: bool = False,
 32          model_kwargs: dict[str, Any] | None = None,
 33          tokenizer_kwargs: dict[str, Any] | None = None,
 34          config_kwargs: dict[str, Any] | None = None,
 35          backend: Literal["torch", "onnx", "openvino"] = "torch",
 36      ) -> "_SentenceTransformersSparseEncoderEmbeddingBackend":
 37          cache_params = {
 38              "model": model,
 39              "device": device,
 40              "auth_token": auth_token,
 41              "trust_remote_code": trust_remote_code,
 42              "revision": revision,
 43              "local_files_only": local_files_only,
 44              "model_kwargs": model_kwargs,
 45              "tokenizer_kwargs": tokenizer_kwargs,
 46              "config_kwargs": config_kwargs,
 47              "backend": backend,
 48          }
 49  
 50          embedding_backend_id = json.dumps(cache_params, sort_keys=True, default=str)
 51  
 52          if embedding_backend_id in _SentenceTransformersSparseEmbeddingBackendFactory._instances:
 53              return _SentenceTransformersSparseEmbeddingBackendFactory._instances[embedding_backend_id]
 54  
 55          embedding_backend = _SentenceTransformersSparseEncoderEmbeddingBackend(
 56              model=model,
 57              device=device,
 58              auth_token=auth_token,
 59              trust_remote_code=trust_remote_code,
 60              revision=revision,
 61              local_files_only=local_files_only,
 62              model_kwargs=model_kwargs,
 63              tokenizer_kwargs=tokenizer_kwargs,
 64              config_kwargs=config_kwargs,
 65              backend=backend,
 66          )
 67  
 68          _SentenceTransformersSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
 69          return embedding_backend
 70  
 71  
 72  class _SentenceTransformersSparseEncoderEmbeddingBackend:
 73      """
 74      Class to manage Sparse embeddings from Sentence Transformers.
 75      """
 76  
 77      def __init__(
 78          self,
 79          *,
 80          model: str,
 81          device: str | None = None,
 82          auth_token: Secret | None = None,
 83          trust_remote_code: bool = False,
 84          revision: str | None = None,
 85          local_files_only: bool = False,
 86          model_kwargs: dict[str, Any] | None = None,
 87          tokenizer_kwargs: dict[str, Any] | None = None,
 88          config_kwargs: dict[str, Any] | None = None,
 89          backend: Literal["torch", "onnx", "openvino"] = "torch",
 90      ) -> None:
 91          sentence_transformers_import.check()
 92  
 93          self.model = SparseEncoder(
 94              model_name_or_path=model,
 95              device=device,
 96              token=auth_token.resolve_value() if auth_token else None,
 97              trust_remote_code=trust_remote_code,
 98              revision=revision,
 99              local_files_only=local_files_only,
100              model_kwargs=model_kwargs,
101              tokenizer_kwargs=tokenizer_kwargs,
102              config_kwargs=config_kwargs,
103              backend=backend,
104          )
105  
106      def embed(self, *, data: list[str], **kwargs: Any) -> list[SparseEmbedding]:
107          embeddings_list = self.model.encode(
108              data,
109              convert_to_tensor=False,  # output is a list of individual tensors
110              convert_to_sparse_tensor=True,
111              **kwargs,
112          )
113  
114          sparse_embeddings: list[SparseEmbedding] = []
115          for embedding_tensor in embeddings_list:
116              embedding_tensor = embedding_tensor.coalesce()
117              indices = embedding_tensor.indices()[0].tolist()  # Only column indices
118              values = embedding_tensor.values().tolist()
119              sparse_embeddings.append(SparseEmbedding(indices=indices, values=values))
120  
121          return sparse_embeddings