"""
IVFSparse module
"""

import math
import os

from multiprocessing.pool import ThreadPool

import numpy as np

# Conditional import
try:
    from scipy.sparse import csr_matrix, vstack
    from scipy.sparse.linalg import norm
    from sklearn.cluster import MiniBatchKMeans
    from sklearn.metrics import pairwise_distances_argmin_min
    from sklearn.utils.extmath import safe_sparse_dot

    IVFSPARSE = True
except ImportError:
    IVFSPARSE = False

from ...serialize import SerializeFactory
from ...util import SparseArray
from ..base import ANN


class IVFSparse(ANN):
    """
    Inverted file (IVF) index with flat vector file storage and sparse array support.

    IVFSparse builds an IVF index and enables approximate nearest neighbor (ANN) search.

    This index is modeled after Faiss and supports many of the same parameters.

    See this link for more: https://github.com/facebookresearch/faiss/wiki/Faster-search
    """
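
    # Settings read via self.setting in this module: sample, nlist, nprobe,
    # nfeatures (default 25) and minpoints (default 39)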

    def __init__(self, config):
        super().__init__(config)

        if not IVFSPARSE:
            raise ImportError('IVFSparse is not available - install "ann" extra to enable')

        # Cluster centroids, if computed
        self.centroids = None

        # Cluster id mapping
        self.ids = None

        # Cluster data blocks - can be a single block when no centroids are computed
        self.blocks = None

        # Deleted ids
        self.deletes = None

    def index(self, embeddings):
        # Compute model training size
        train, sample = embeddings, self.setting("sample")
        if sample:
            # Get sample for training
            rng = np.random.default_rng(0)
            indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False))
            train = train[indices]

        # Get the number of clusters. Note that the final number of clusters could be lower due to
        # filtering of duplicate centroids and pruning of small clusters
        clusters = self.nlist(embeddings.shape[0], train.shape[0])

        # Build cluster centroids if approximate search is enabled
        # A single cluster performs exact search
        self.centroids = self.build(train, clusters) if clusters > 1 else None

        # Sort into clusters
        ids = self.aggregate(embeddings)

        # Prune small clusters (fewer than minpoints members) and rebuild
        indices = sorted(k for k, v in ids.items() if len(v) >= self.minpoints())
        if len(indices) > 0 and len(ids) > 1 and len(indices) != len(ids.keys()):
            self.centroids = self.centroids[indices]
            ids = self.aggregate(embeddings)

        # Sort clusters by id
        self.ids = dict(sorted(ids.items(), key=lambda x: x[0]))

        # Create cluster data blocks
        self.blocks = {k: embeddings[v] for k, v in self.ids.items()}

        # Calculate block max summary vectors and use as centroids
        self.centroids = vstack([csr_matrix(x.max(axis=0)) for x in self.blocks.values()]) if self.centroids is not None else None
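        # With non-negative term weights (typical for learned sparse vectors), a block's element-wise
        # max vector upper bounds the query dot product for any row in the block, making these summary
        # vectors a reasonable signal for choosing which blocks to probe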

        # Initialize deletes
        self.deletes = []

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata({"clusters": len(self.blocks)})

    def append(self, embeddings):
        # Get offset
        offset = self.size()

        # Sort into clusters and merge
        for cluster, ids in self.aggregate(embeddings).items():
            # Add new ids
            self.ids[cluster].extend([x + offset for x in ids])

            # Add new data
            self.blocks[cluster] = vstack([self.blocks[cluster], embeddings[ids]])

        # Update id offset and index metadata
        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        # Set index ids as deleted
        self.deletes.extend(ids)

    def search(self, queries, limit):
        # Calculate number of threads using a thread batch size of 32
        threads = queries.shape[0] // 32
        threads = min(max(threads, 1), os.cpu_count())
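        # For example, 256 queries run across min(256 // 32, os.cpu_count()) = up to 8 threads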

        # Approximate search
        blockids = self.topn(queries, self.centroids, self.nprobe())[0] if self.centroids is not None else None

        # This scan can run across multiple threads since many of the underlying numpy/scipy calls release the GIL
        results = []
        with ThreadPool(threads) as pool:
            for result in pool.starmap(self.scan, [(x, limit, blockids[i] if blockids is not None else None) for i, x in enumerate(queries)]):
                results.append(result)

        return results

    def count(self):
        return self.size() - len(self.deletes)

    def load(self, path):
        # Create streaming serializer and limit reads to one byte at a time to ensure
        # only the msgpack data is consumed
        serializer = SerializeFactory.create("msgpack", streaming=True, read_size=1)

        with open(path, "rb") as f:
            # Read header
            unpacker = serializer.loadstream(f)
            header = next(unpacker)

            # Read cluster centroids, if available
            self.centroids = SparseArray().load(f) if header["centroids"] else None

            # Read cluster ids
            self.ids = dict(next(unpacker))

            # Read cluster data blocks
            self.blocks = {}
            for key in self.ids:
                self.blocks[key] = SparseArray().load(f)

            # Read deletes
            self.deletes = next(unpacker)

    def save(self, path):
        # IVFSparse storage format:
        #    - header msgpack
        #    - centroids sparse array (optional based on header parameters)
        #    - cluster ids msgpack
        #    - cluster data blocks list of sparse arrays
        #    - deletes msgpack
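        #
        # Cluster data blocks are written in self.ids key order, which is the same order load
        # reads them back in when rebuilding the id to block mapping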

        # Create message pack serializer
        serializer = SerializeFactory.create("msgpack")

        with open(path, "wb") as f:
            # Write header
            serializer.savestream({"centroids": self.centroids is not None, "count": self.count(), "blocks": len(self.blocks)}, f)

            # Write cluster centroids, if available
            if self.centroids is not None:
                SparseArray().save(f, self.centroids)

            # Write cluster id mapping
            serializer.savestream(list(self.ids.items()), f)

            # Write cluster data blocks
            for block in self.blocks.values():
                SparseArray().save(f, block)

            # Write deletes
            serializer.savestream(self.deletes, f)

    def build(self, train, clusters):
        """
        Builds a k-means clustering model to calculate centroid points for aggregating data blocks.

        Args:
            train: training data
            clusters: number of clusters to create

        Returns:
            cluster centroids
        """

        # Select the top n features that contribute most to the L2 vector norm
        indices = np.argsort(-norm(train, axis=0))[: self.setting("nfeatures", 25)]
        data = train[:, indices]

        # Cluster data using k-means
        kmeans = MiniBatchKMeans(n_clusters=clusters, random_state=0, n_init=5).fit(data)

        # Find closest points to each cluster center and use those as centroids
        indices, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data, metric="l2")
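        # Snapping each k-means center to its closest data point keeps the centroids sparse,
        # since the raw cluster means would generally be dense vectors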

        # Filter out duplicate centroids and return cluster centroids
        return train[np.unique(indices)]

    def aggregate(self, data):
        """
        Aggregates input data array into clusters. This method assigns each data element to the
        cluster with the closest centroid by L2 distance.

        Args:
            data: input data

        Returns:
            {cluster: ids}
        """

        # Exact search when only a single cluster
        if self.centroids is None:
            return {0: list(range(data.shape[0]))}

        # Map data to closest centroids
        indices, _ = pairwise_distances_argmin_min(data, self.centroids, metric="l2")

        # Sort into clusters
        ids = {}
        for x, cluster in enumerate(indices.tolist()):
            if cluster not in ids:
                ids[cluster] = []

            # Save id
            ids[cluster].append(x)

        return ids

    def topn(self, queries, data, limit, deletes=None):
        """
        Gets the top n most similar data elements for each query.

        Args:
            queries: queries array
            data: data array
            limit: top n
            deletes: optional list of deletes to filter from results

        Returns:
            (indices, scores) arrays with the top n matches per query
        """

        # Dot product similarity
        scores = safe_sparse_dot(queries, data.T, dense_output=True)

        # Zero out scores for deleted ids
        if deletes is not None:
            scores[:, deletes] = 0

        # Get top n matching indices and scores
        indices = np.argpartition(-scores, limit if limit < scores.shape[1] else scores.shape[1] - 1)[:, :limit]
        scores = np.take_along_axis(scores, indices, axis=1)

        # Sort matches by score descending since argpartition does not order the partition
        order = np.argsort(-scores, axis=1)
        indices, scores = np.take_along_axis(indices, order, axis=1), np.take_along_axis(scores, order, axis=1)

        return indices, scores

    def scan(self, query, limit, blockids):
        """
        Scans a list of blocks for the top n ids that match the query.

        Args:
            query: input query
            limit: top n
            blockids: block ids to scan

        Returns:
            list of (id, score)
        """

        if self.centroids is not None:
            # Stack into single ids list
            ids = np.concatenate([self.ids[x] for x in blockids if x in self.ids])

            # Stack data rows
            data = vstack([self.blocks[x] for x in blockids if x in self.blocks])
        else:
            # Exact search
            ids, data = np.array(self.ids[0]), self.blocks[0]

        # Get positions of deleted ids
        deletes = np.argwhere(np.isin(ids, self.deletes)).ravel()

        # Calculate similarity
        indices, scores = self.topn(query, data, limit, deletes)
        indices, scores = indices[0], scores[0]

        # Map data ids and return
        return list(zip(ids[indices].tolist(), scores.tolist()))

    def nlist(self, count, train):
        """
        Calculates the number of clusters for this IVFSparse index. Note that the final number of clusters
        could be lower as duplicate cluster centroids are filtered out.

        Args:
            count: initial dataset size
            train: number of rows used to train

        Returns:
            number of clusters
        """

        # Default to a single cluster (exact search) for small datasets, otherwise derive from the training size
        default = 1 if count <= 5000 else self.cells(train)

        # Number of clusters to create
        return self.setting("nlist", default)

    def nprobe(self):
        """
        Gets or derives the nprobe search parameter.

        Returns:
            nprobe setting
        """

        # Get size of embeddings index
        size = self.size()

        default = 6 if size <= 5000 else self.cells(size) // 16
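        # For example, a 100,000 row index defaults to cells(100000) // 16 = 1265 // 16 = 79 probes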
        return self.setting("nprobe", default)

    def cells(self, count):
        """
        Calculates the number of IVF cells for an IVFSparse index.

        Args:
            count: number of rows

        Returns:
            number of IVF cells
        """

        # Calculate the number of IVF cells as max(min(4 * sqrt(count), count / minpoints), 1)
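        # For example, cells(100000) = max(min(round(4 * sqrt(100000)), 100000 // 39), 1) = min(1265, 2564) = 1265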
        return max(min(round(4 * math.sqrt(count)), int(count / self.minpoints())), 1)

    def size(self):
        """
        Gets the total size of this index including deletes.

        Returns:
            size
        """

        return sum(len(x) for x in self.ids.values())

    def minpoints(self):
        """
        Gets the minimum number of points per cluster.

        Returns:
            minimum points per cluster
        """

        # Minimum number of points per cluster
        # Match the Faiss default that requires at least 39 points per cluster
        return self.setting("minpoints", 39)