base.py
"""
SparseVectors module
"""

# Conditional import
try:
    from scipy.sparse import csr_matrix, vstack
    from sklearn.preprocessing import normalize
    from sklearn.utils.extmath import safe_sparse_dot

    SPARSE = True
except ImportError:
    SPARSE = False

from ...util import SparseArray
from ..base import Vectors


# pylint: disable=W0223
class SparseVectors(Vectors):
    """
    Base class for sparse vector models. Vector models transform input content into sparse arrays.
    """

    def __init__(self, config, scoring, models):
        """
        Creates a new sparse vectors instance.

        Args:
            config: vector configuration; an optional "normalize" key overrides defaultnormalize()
            scoring: passed through to the parent Vectors constructor
            models: passed through to the parent Vectors constructor

        Raises:
            ImportError: if the optional SciPy / scikit-learn dependencies are not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not SPARSE:
            raise ImportError('SparseVectors is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

        # Get normalization setting. Stays None (normalization disabled) when there is no config.
        self.isnormalize = self.config.get("normalize", self.defaultnormalize()) if self.config else None

    def encode(self, data, category=None):
        """
        Encodes data with the parent implementation and converts the result to a SciPy CSR matrix.

        Args:
            data: input data to encode
            category: optional category, passed through to the parent encode

        Returns:
            scipy.sparse.csr_matrix with the same shape as the encoded tensor
        """

        # Encode data to embeddings. Assumes the parent returns a sparse torch tensor - TODO confirm
        embeddings = super().encode(data, category)

        # Get sparse torch vector attributes. coalesce() merges duplicate entries so that
        # indices() / values() can be read.
        embeddings = embeddings.cpu().coalesce()
        indices = embeddings.indices().numpy()
        values = embeddings.values().numpy()

        # Return as SciPy CSR Matrix, built from (values, (row, column)) coordinates
        return csr_matrix((values, indices), shape=embeddings.size())

    def vectors(self, documents, batchsize=500, checkpoint=None, buffer=None, dtype=None):
        """
        Builds a sparse embeddings matrix for documents.

        Args:
            documents: documents to index, passed to index()
            batchsize: indexing batch size, passed to index()
            checkpoint: optional checkpoint, passed to index()
            buffer: not used by this implementation
            dtype: not used by this implementation

        Returns:
            (ids, dimensions, embeddings) where embeddings is all batches stacked vertically,
            or None when index() reported zero batches
        """

        # Run indexing
        ids, dimensions, batches, stream = self.index(documents, batchsize, checkpoint)

        # Read every batch back from the stream first, then stack once. Stacking inside the
        # loop would re-copy all previously accumulated rows on every iteration (quadratic).
        with open(stream, "rb") as queue:
            parts = [self.loadembeddings(queue) for _ in range(batches)]

        # Keep the historical results for the degenerate cases: None for no batches and the
        # single batch returned unchanged for one batch.
        if not parts:
            embeddings = None
        elif len(parts) == 1:
            embeddings = parts[0]
        else:
            embeddings = vstack(parts)

        # Return sparse array
        return (ids, dimensions, embeddings)

    def dot(self, queries, data):
        """
        Computes dot products between every query row and every data row.

        Args:
            queries: query embeddings
            data: data embeddings

        Returns:
            dense scores converted to a nested Python list, one row per query
        """

        return safe_sparse_dot(queries, data.T, dense_output=True).tolist()

    def loadembeddings(self, f):
        """
        Loads a batch of sparse embeddings from an open file using SparseArray.

        Args:
            f: open file handle to read from

        Returns:
            embeddings read by SparseArray().load
        """

        return SparseArray().load(f)

    def saveembeddings(self, f, embeddings):
        """
        Saves a batch of sparse embeddings to an open file using SparseArray.

        Args:
            f: open file handle to write to
            embeddings: sparse embeddings to write
        """

        SparseArray().save(f, embeddings)

    def truncate(self, embeddings):
        """
        Truncation is not supported for sparse vectors.

        Raises:
            ValueError: always
        """

        raise ValueError("Truncate is not supported for sparse vectors")

    def normalize(self, embeddings):
        """
        Optionally normalizes embeddings, depending on the isnormalize setting.

        Args:
            embeddings: sparse embeddings

        Returns:
            embeddings normalized in place (copy=False) when isnormalize is truthy,
            otherwise the input unchanged
        """

        # Inside this method the bare name `normalize` is the module-level scikit-learn function
        # (method names are not in function scope), which supports sparse inputs.
        return normalize(embeddings, copy=False) if self.isnormalize else embeddings

    def quantize(self, embeddings):
        """
        Quantization is not supported for sparse vectors.

        Raises:
            ValueError: always
        """

        raise ValueError("Quantize is not supported for sparse vectors")

    def defaultnormalize(self):
        """
        Returns the default normalization setting.

        Returns:
            default normalization setting
        """

        # Sparse vector embeddings typically perform better as unnormalized
        return False