# src/python/txtai/vectors/sparse/base.py
 1  """
 2  SparseVectors module
 3  """
 4  
 5  # Conditional import
 6  try:
 7      from scipy.sparse import csr_matrix, vstack
 8      from sklearn.preprocessing import normalize
 9      from sklearn.utils.extmath import safe_sparse_dot
10  
11      SPARSE = True
12  except ImportError:
13      SPARSE = False
14  
15  from ...util import SparseArray
16  from ..base import Vectors
17  
18  
# pylint: disable=W0223
class SparseVectors(Vectors):
    """
    Base class for sparse vector models. Vector models transform input content into sparse arrays.
    """

    def __init__(self, config, scoring, models):
        """
        Creates a SparseVectors instance.

        Args:
            config: vector configuration
            scoring: scoring instance
            models: models cache

        Raises:
            ImportError: if the sparse vector dependencies are not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not SPARSE:
            raise ImportError('SparseVectors is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

        # Get normalization setting, falling back to the model default when not configured
        self.isnormalize = self.config.get("normalize", self.defaultnormalize()) if self.config else None

    def encode(self, data, category=None):
        """
        Encodes a batch of data and converts the sparse torch tensor output into a SciPy CSR matrix.

        Args:
            data: input data batch
            category: optional category for instruction-based embeddings

        Returns:
            embeddings as a scipy.sparse.csr_matrix
        """

        # Encode data to embeddings
        embeddings = super().encode(data, category)

        # Get sparse torch vector attributes - coalesce merges duplicate indices
        embeddings = embeddings.cpu().coalesce()
        indices = embeddings.indices().numpy()
        values = embeddings.values().numpy()

        # Return as SciPy CSR Matrix
        return csr_matrix((values, indices), shape=embeddings.size())

    def vectors(self, documents, batchsize=500, checkpoint=None, buffer=None, dtype=None):
        """
        Bulk encodes documents into sparse vectors.

        Args:
            documents: iterable collection of documents
            batchsize: encode batch size
            checkpoint: optional checkpoint directory
            buffer: unused, kept for interface compatibility
            dtype: unused, kept for interface compatibility

        Returns:
            (ids, dimensions, embeddings) where embeddings is a stacked sparse matrix,
            or None when there are no batches
        """

        # Run indexing
        ids, dimensions, batches, stream = self.index(documents, batchsize, checkpoint)

        # Read each batch once and stack a single time at the end - stacking inside
        # the loop would re-copy the accumulated matrix each iteration (quadratic cost)
        data = []
        with open(stream, "rb") as queue:
            for _ in range(batches):
                # Read in array batch
                data.append(self.loadembeddings(queue))

        # Return sparse array, preserving None when no batches were written
        return (ids, dimensions, vstack(data) if data else None)

    def dot(self, queries, data):
        """
        Calculates the dot product similarity between query and data matrices.

        Args:
            queries: sparse query matrix
            data: sparse data matrix

        Returns:
            dense similarity scores as a nested list
        """

        return safe_sparse_dot(queries, data.T, dense_output=True).tolist()

    def loadembeddings(self, f):
        """
        Loads a sparse embeddings array from the input file handle.

        Args:
            f: input file handle

        Returns:
            sparse embeddings array
        """

        return SparseArray().load(f)

    def saveembeddings(self, f, embeddings):
        """
        Saves a sparse embeddings array to the output file handle.

        Args:
            f: output file handle
            embeddings: sparse embeddings array
        """

        SparseArray().save(f, embeddings)

    def truncate(self, embeddings):
        """
        Truncation is undefined for sparse vectors.

        Raises:
            ValueError: always
        """

        raise ValueError("Truncate is not supported for sparse vectors")

    def normalize(self, embeddings):
        """
        Optionally normalizes embeddings using a method that supports sparse vectors.

        Args:
            embeddings: sparse embeddings array

        Returns:
            normalized embeddings when normalization is enabled, otherwise input unchanged
        """

        # sklearn normalize operates on sparse matrices in place when copy=False
        return normalize(embeddings, copy=False) if self.isnormalize else embeddings

    def quantize(self, embeddings):
        """
        Quantization is undefined for sparse vectors.

        Raises:
            ValueError: always
        """

        raise ValueError("Quantize is not supported for sparse vectors")

    def defaultnormalize(self):
        """
        Returns the default normalization setting.

        Returns:
            default normalization setting
        """

        # Sparse vector embeddings typically perform better as unnormalized
        return False