ivfsparse.py
1 """ 2 IVFSparse module 3 """ 4 5 import math 6 import os 7 8 from multiprocessing.pool import ThreadPool 9 10 import numpy as np 11 12 # Conditional import 13 try: 14 from scipy.sparse import csr_matrix, vstack 15 from scipy.sparse.linalg import norm 16 from sklearn.cluster import MiniBatchKMeans 17 from sklearn.metrics import pairwise_distances_argmin_min 18 from sklearn.utils.extmath import safe_sparse_dot 19 20 IVFSPARSE = True 21 except ImportError: 22 IVFSPARSE = False 23 24 from ...serialize import SerializeFactory 25 from ...util import SparseArray 26 from ..base import ANN 27 28 29 class IVFSparse(ANN): 30 """ 31 Inverted file (IVF) index with flat vector file storage and sparse array support. 32 33 IVFSparse builds an IVF index and enables approximate nearest neighbor (ANN) search. 34 35 This index is modeled after Faiss and supports many of the same parameters. 36 37 See this link for more: https://github.com/facebookresearch/faiss/wiki/Faster-search 38 """ 39 40 def __init__(self, config): 41 super().__init__(config) 42 43 if not IVFSPARSE: 44 raise ImportError('IVFSparse is not available - install "ann" extra to enable') 45 46 # Cluster centroids, if computed 47 self.centroids = None 48 49 # Cluster id mapping 50 self.ids = None 51 52 # Cluster data blocks - can be a single block with no computed centroids 53 self.blocks = None 54 55 # Deleted ids 56 self.deletes = None 57 58 def index(self, embeddings): 59 # Compute model training size 60 train, sample = embeddings, self.setting("sample") 61 if sample: 62 # Get sample for training 63 rng = np.random.default_rng(0) 64 indices = sorted(rng.choice(train.shape[0], int(sample * train.shape[0]), replace=False, shuffle=False)) 65 train = train[indices] 66 67 # Get number of clusters. Note that final number of clusters could be lower due to filtering duplicate centroids 68 # and pruning of small clusters 69 clusters = self.nlist(embeddings.shape[0], train.shape[0]) 70 71 # Build cluster centroids if approximate search is enabled 72 # A single cluster performs exact search 73 self.centroids = self.build(train, clusters) if clusters > 1 else None 74 75 # Sort into clusters 76 ids = self.aggregate(embeddings) 77 78 # Prune small clusters (less than minpoints parameter) and rebuild 79 indices = sorted(k for k, v in ids.items() if len(v) >= self.minpoints()) 80 if len(indices) > 0 and len(ids) > 1 and len(indices) != len(ids.keys()): 81 self.centroids = self.centroids[indices] 82 ids = self.aggregate(embeddings) 83 84 # Sort clusters by id 85 self.ids = dict(sorted(ids.items(), key=lambda x: x[0])) 86 87 # Create cluster data blocks 88 self.blocks = {k: embeddings[v] for k, v in self.ids.items()} 89 90 # Calculate block max summary vectors and use as centroids 91 self.centroids = vstack([csr_matrix(x.max(axis=0)) for x in self.blocks.values()]) if self.centroids is not None else None 92 93 # Initialize deletes 94 self.deletes = [] 95 96 # Add id offset and index build metadata 97 self.config["offset"] = embeddings.shape[0] 98 self.metadata({"clusters": len(self.blocks)}) 99 100 def append(self, embeddings): 101 # Get offset 102 offset = self.size() 103 104 # Sort into clusters and merge 105 for cluster, ids in self.aggregate(embeddings).items(): 106 # Add new ids 107 self.ids[cluster].extend([x + offset for x in ids]) 108 109 # Add new data 110 self.blocks[cluster] = vstack([self.blocks[cluster], embeddings[ids]]) 111 112 # Update id offset and index metadata 113 self.config["offset"] += embeddings.shape[0] 114 self.metadata() 115 116 

    def delete(self, ids):
        # Set index ids as deleted
        self.deletes.extend(ids)

    def search(self, queries, limit):
        # Calculate number of threads using a thread batch size of 32
        threads = queries.shape[0] // 32
        threads = min(max(threads, 1), os.cpu_count())

        # Approximate search
        blockids = self.topn(queries, self.centroids, self.nprobe())[0] if self.centroids is not None else None

        # This method is able to run as multiple threads due to a number of numpy/scipy method calls that drop the GIL.
        results = []
        with ThreadPool(threads) as pool:
            for result in pool.starmap(self.scan, [(x, limit, blockids[i] if blockids is not None else None) for i, x in enumerate(queries)]):
                results.append(result)

        return results

    def count(self):
        return self.size() - len(self.deletes)

    def load(self, path):
        # Create streaming serializer and limit read size to a byte at a time to ensure
        # only msgpack data is consumed
        serializer = SerializeFactory.create("msgpack", streaming=True, read_size=1)

        with open(path, "rb") as f:
            # Read header
            unpacker = serializer.loadstream(f)
            header = next(unpacker)

            # Read cluster centroids, if available
            self.centroids = SparseArray().load(f) if header["centroids"] else None

            # Read cluster ids
            self.ids = dict(next(unpacker))

            # Read cluster data blocks
            self.blocks = {}
            for key in self.ids:
                self.blocks[key] = SparseArray().load(f)

            # Read deletes
            self.deletes = next(unpacker)

    def save(self, path):
        # IVFSparse storage format:
        # - header msgpack
        # - centroids sparse array (optional based on header parameters)
        # - cluster ids msgpack
        # - cluster data blocks list of sparse arrays
        # - deletes msgpack

        # Create message pack serializer
        serializer = SerializeFactory.create("msgpack")

        with open(path, "wb") as f:
            # Write header
            serializer.savestream({"centroids": self.centroids is not None, "count": self.count(), "blocks": len(self.blocks)}, f)

            # Write cluster centroids, if available
            if self.centroids is not None:
                SparseArray().save(f, self.centroids)

            # Write cluster id mapping
            serializer.savestream(list(self.ids.items()), f)

            # Write cluster data blocks
            for block in self.blocks.values():
                SparseArray().save(f, block)

            # Write deletes
            serializer.savestream(self.deletes, f)

    def build(self, train, clusters):
        """
        Builds a k-means cluster to calculate centroid points for aggregating data blocks.

        Args:
            train: training data
            clusters: number of clusters to create

        Returns:
            cluster centroids
        """

        # Select top n most important features that contribute to the L2 vector norm
        indices = np.argsort(-norm(train, axis=0))[: self.setting("nfeatures", 25)]
        data = train[:, indices]

        # Cluster data using k-means
        kmeans = MiniBatchKMeans(n_clusters=clusters, random_state=0, n_init=5).fit(data)

        # Find closest points to each cluster center and use those as centroids
        indices, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, data, metric="l2")

        # Filter out duplicate centroids and return cluster centroids
        return train[np.unique(indices)]

    def aggregate(self, data):
        """
        Aggregates input data array into clusters.
        This method sorts each data element into the cluster with the closest centroid, as measured by L2 distance.

        Args:
            data: input data

        Returns:
            {cluster: ids}
        """

        # Exact search when only a single cluster
        if self.centroids is None:
            return {0: list(range(data.shape[0]))}

        # Map data to closest centroids
        indices, _ = pairwise_distances_argmin_min(data, self.centroids, metric="l2")

        # Sort into clusters
        ids = {}
        for x, cluster in enumerate(indices.tolist()):
            if cluster not in ids:
                ids[cluster] = []

            # Save id
            ids[cluster].append(x)

        return ids

    def topn(self, queries, data, limit, deletes=None):
        """
        Gets the top n most similar data elements for each query.

        Args:
            queries: queries array
            data: data array
            limit: top n
            deletes: optional list of deletes to filter from results

        Returns:
            list of matching (indices, scores)
        """

        # Dot product similarity
        scores = safe_sparse_dot(queries, data.T, dense_output=True)

        # Clear deletes
        if deletes is not None:
            scores[:, deletes] = 0

        # Get top n matching indices and scores
        indices = np.argpartition(-scores, limit if limit < scores.shape[1] else scores.shape[1] - 1)[:, :limit]
        scores = np.take_along_axis(scores, indices, axis=1)

        return indices, scores

    def scan(self, query, limit, blockids):
        """
        Scans a list of blocks for the top n ids that match the query.

        Args:
            query: input query
            limit: top n
            blockids: block ids to scan

        Returns:
            list of (id, score)
        """

        if self.centroids is not None:
            # Stack into single ids list
            ids = np.concatenate([self.ids[x] for x in blockids if x in self.ids])

            # Stack data rows
            data = vstack([self.blocks[x] for x in blockids if x in self.blocks])
        else:
            # Exact search
            ids, data = np.array(self.ids[0]), self.blocks[0]

        # Get deletes
        deletes = np.argwhere(np.isin(ids, self.deletes)).ravel()

        # Calculate similarity
        indices, scores = self.topn(query, data, limit, deletes)
        indices, scores = indices[0], scores[0]

        # Map data ids and return
        return list(zip(ids[indices].tolist(), scores.tolist()))

    def nlist(self, count, train):
        """
        Calculates the number of clusters for this IVFSparse index. Note that the final number of clusters
        could be lower as duplicate cluster centroids are filtered out.

        Args:
            count: initial dataset size
            train: number of rows used to train

        Returns:
            number of clusters
        """

        # Use a single cluster (exact search) for small datasets, otherwise derive the count from the training data size
        default = 1 if count <= 5000 else self.cells(train)

        # Number of clusters to create
        return self.setting("nlist", default)

    def nprobe(self):
        """
        Gets or derives the nprobe search parameter.

        Returns:
            nprobe setting
        """

        # Get size of embeddings index
        size = self.size()

        default = 6 if size <= 5000 else self.cells(size) // 16
        return self.setting("nprobe", default)

    def cells(self, count):
        """
        Calculates the number of IVF cells for an IVFSparse index.

        Args:
            count: number of rows

        Returns:
            number of IVF cells
        """

        # Calculate number of IVF cells as max(min(round(4 * sqrt(count)), count / minpoints), 1)
        return max(min(round(4 * math.sqrt(count)), int(count / self.minpoints())), 1)

    def size(self):
        """
        Gets the total size of this index including deletes.

        Returns:
            size
        """

        return sum(len(x) for x in self.ids.values())

    def minpoints(self):
        """
        Gets the minimum number of points per cluster.

        Returns:
            minimum points per cluster
        """

        # Minimum number of points per cluster
        # Match the Faiss default that requires at least 39 points per cluster
        return self.setting("minpoints", 39)
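
The sketch below is illustrative only and is not part of ivfsparse.py or txtai. It walks through the same IVF flow the module implements (train k-means centroids over sparse vectors, assign each row to its closest centroid, then probe only the best-scoring clusters at query time) using scipy and scikit-learn directly. All names and parameter values here (rows, dims, nclusters, nprobe, density) are made up for the demo, and it keeps the raw k-means centers as centroids rather than the module's refinements (closest training rows, block max vectors).

# example_ivfsparse_flow.py
# Illustrative sketch only - not part of the IVFSparse module above.

import numpy as np

from scipy.sparse import random as sparse_random
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances_argmin_min
from sklearn.utils.extmath import safe_sparse_dot

# Random sparse dataset standing in for sparse embeddings
rows, dims, nclusters, nprobe = 1000, 500, 8, 2
data = sparse_random(rows, dims, density=0.05, format="csr", random_state=0)

# 1. Train k-means centroids, mirroring IVFSparse.build
kmeans = MiniBatchKMeans(n_clusters=nclusters, random_state=0, n_init=5).fit(data)

# 2. Assign each row to its closest centroid, mirroring IVFSparse.aggregate
assignments, _ = pairwise_distances_argmin_min(data, kmeans.cluster_centers_, metric="l2")
blocks = {c: np.flatnonzero(assignments == c) for c in np.unique(assignments)}

# 3. Query time: score centroids first, then only scan rows in the top nprobe clusters
query = sparse_random(1, dims, density=0.05, format="csr", random_state=1)
centroidscores = safe_sparse_dot(query, kmeans.cluster_centers_.T, dense_output=True)[0]
probe = np.argsort(-centroidscores)[:nprobe]

# Dot product similarity over the probed clusters only
ids = np.concatenate([blocks[c] for c in probe if c in blocks])
scores = safe_sparse_dot(query, data[ids].T, dense_output=True)[0]

# Top 3 approximate matches as (row id, score)
top = np.argsort(-scores)[:3]
print(list(zip(ids[top].tolist(), scores[top].tolist())))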