"""
Dense ANN module tests
"""

import os
import platform
import sys
import tempfile
import unittest

from unittest.mock import patch

import numpy as np

from txtai.ann import ANNFactory, ANN
from txtai.serialize import SerializeFactory


# pylint: disable=R0904
class TestDense(unittest.TestCase):
    """
    Dense ANN tests.
    """

    def testAnnoy(self):
        """
        Test Annoy backend
        """

        self.runTests("annoy", None, False)

    def testAnnoyCustom(self):
        """
        Test Annoy backend with custom settings
        """

        # Test with custom settings
        self.runTests("annoy", {"annoy": {"ntrees": 2, "searchk": 1}}, False)

    def testCustomBackend(self):
        """
        Test resolving a custom backend
        """

        self.runTests("txtai.ann.Faiss")

    def testCustomBackendNotFound(self):
        """
        Test resolving an unresolvable backend
        """

        with self.assertRaises(ImportError):
            ANNFactory.create({"backend": "notfound.ann"})

    def testFaiss(self):
        """
        Test Faiss backend
        """

        self.runTests("faiss")

    def testFaissBinary(self):
        """
        Test Faiss backend with a binary hash index
        """

        ann = ANNFactory.create({"backend": "faiss", "quantize": 1, "dimensions": 240 * 8, "faiss": {"components": "BHash32"}})

        # Generate and index dummy data
        data = np.random.rand(100, 240).astype(np.uint8)
        ann.index(data)

        # Generate query vector and test search
        query = np.random.rand(240).astype(np.uint8)
        self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)

    def testFaissCustom(self):
        """
        Test Faiss backend with custom settings
        """

        # Test with custom settings
        self.runTests("faiss", {"faiss": {"nprobe": 2, "components": "PCA16,IDMap,SQ8", "sample": 1.0}}, False)
        self.runTests("faiss", {"faiss": {"components": "IVF,SQ8"}}, False)

    @patch("platform.system")
    def testFaissMacOS(self, system):
        """
        Test Faiss backend with macOS
        """

        # Run test
        system.return_value = "Darwin"

        # Force reload of class
        name = "txtai.ann.dense.faiss"
        module = sys.modules[name]
        del sys.modules[name]

        # Restore the original module even if the reimport or test run fails,
        # so a failure here doesn't corrupt sys.modules for other tests
        try:
            # pylint: disable=C0415, W0611
            import txtai.ann.dense.faiss

            # Run tests
            self.runTests("faiss")
        finally:
            # Restore original module
            sys.modules[name] = module

    @unittest.skipIf(os.name == "nt", "mmap not supported on Windows")
    def testFaissMmap(self):
        """
        Test Faiss backend with mmap enabled
        """

        # Test with mmap enabled
        self.runTests("faiss", {"faiss": {"mmap": True}}, False)

    def testGGML(self):
        """
        Test GGML backend
        """

        self.runTests("ggml")

    def testGGMLQuantization(self):
        """
        Test GGML backend with quantization enabled
        """

        ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "Q4_0"}})

        # Generate and index dummy data
        data = np.random.rand(100, 256).astype(np.float32)
        ann.index(data)

        # Test save and load
        index = os.path.join(tempfile.gettempdir(), "ggml.q4_0.v1")
        ann.save(index)
        ann.load(index)

        # Generate query vector and test search
        query = np.random.rand(256).astype(np.float32)
        self.normalize(query)
        self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)

        # Validate count
        self.assertEqual(ann.count(), 100)

        # Test delete
        ann.delete([0])
        self.assertEqual(ann.count(), 99)

        # Save updated index with deletes and reload
        index = os.path.join(tempfile.gettempdir(), "ggml.q4_0.v2")
        ann.save(index)
        ann.load(index)
        ann.index(data)

    def testGGMLInvalid(self):
        """
        Test invalid GGML configurations
        """

        data = np.random.rand(100, 240).astype(np.float32)

        with self.assertRaises(ValueError):
            ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "NOEXIST", "gpu": False}})
            ann.index(data)

        with self.assertRaises(ValueError):
            ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "Q4_K"}})
            ann.index(data)

    def testHnsw(self):
        """
        Test Hnswlib backend
        """

        self.runTests("hnsw")

    def testHnswCustom(self):
        """
        Test Hnswlib backend with custom settings
        """

        # Test with custom settings
        self.runTests("hnsw", {"hnsw": {"efconstruction": 100, "m": 4, "randomseed": 0, "efsearch": 5}})

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        ann = ANN({})

        self.assertRaises(NotImplementedError, ann.load, None)
        self.assertRaises(NotImplementedError, ann.index, None)
        self.assertRaises(NotImplementedError, ann.append, None)
        self.assertRaises(NotImplementedError, ann.delete, None)
        self.assertRaises(NotImplementedError, ann.search, None, None)
        self.assertRaises(NotImplementedError, ann.count)
        self.assertRaises(NotImplementedError, ann.save, None)

    def testNumPy(self):
        """
        Test NumPy backend
        """

        self.runTests("numpy")

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testNumPyLegacy(self):
        """
        Test NumPy backend with legacy pickled data
        """

        serializer = SerializeFactory.create("pickle", allowpickle=True)

        # Create output directory
        output = os.path.join(tempfile.gettempdir(), "ann.npy")
        path = os.path.join(output, "embeddings")
        os.makedirs(output, exist_ok=True)

        # Generate data and save as pickle
        data = np.random.rand(100, 240).astype(np.float32)
        serializer.save(data, path)

        ann = ANNFactory.create({"backend": "numpy"})
        ann.load(path)

        # Validate count
        self.assertEqual(ann.count(), 100)

    def testNumPySafetensors(self):
        """
        Test NumPy backend with safetensors storage
        """

        ann = ANNFactory.create({"backend": "numpy", "numpy": {"safetensors": True}})

        # Generate and index dummy data
        data = np.random.rand(100, 240).astype(np.float32)
        ann.index(data)

        # Test save and load
        index = os.path.join(tempfile.gettempdir(), "numpy.safetensors")
        ann.save(index)
        ann.load(index)

        # Generate query vector and test search
        query = np.random.rand(240).astype(np.float32)
        self.normalize(query)
        self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)

        # Validate count
        self.assertEqual(ann.count(), 100)

    @patch("sqlalchemy.orm.Query.limit")
    def testPGVector(self, query):
        """
        Test PGVector backend
        """

        # Generate test record
        data = np.random.rand(1, 240).astype(np.float32)

        # Mock database query
        query.return_value = [(x, -1.0) for x in range(data.shape[0])]

        configs = [
            ("full", {"dimensions": 240}, {}, data),
            ("half", {"dimensions": 240}, {"precision": "half"}, data),
            ("binary", {"quantize": 1, "dimensions": 240 * 8}, {}, data.astype(np.uint8)),
        ]

        # Create ANN
        for name, config, pgvector, data in configs:
            path = os.path.join(tempfile.gettempdir(), f"pgvector.{name}.sqlite")
            ann = ANNFactory.create(
                {**{"backend": "pgvector", "pgvector": {**{"url": f"sqlite:///{path}", "schema": "txtai"}, **pgvector}}, **config}
            )

            # Test indexing
            ann.index(data)
            ann.append(data)

            # Validate search results
            self.assertEqual(ann.search(data, 1), [[(0, 1.0)]])

            # Validate save/load/delete
            ann.save(None)
            ann.load(None)

            # Validate count
            self.assertEqual(ann.count(), 2)

            # Test delete
            ann.delete([0])
            self.assertEqual(ann.count(), 1)

            # Close ANN
            ann.close()

    @unittest.skipIf(platform.system() == "Darwin", "SQLite extensions not supported on macOS")
    def testSQLite(self):
        """
        Test SQLite backend
        """

        self.runTests("sqlite")

    @unittest.skipIf(platform.system() == "Darwin", "SQLite extensions not supported on macOS")
    def testSQLiteCustom(self):
        """
        Test SQLite backend with custom settings
        """

        # Test with custom settings
        self.runTests("sqlite", {"sqlite": {"quantize": 1}})
        self.runTests("sqlite", {"sqlite": {"quantize": 8}})

        # Test saving to a new path
        model = self.backend("sqlite")
        expected = model.count() - 1

        # Test save variations
        index = os.path.join(tempfile.gettempdir(), "ann.sqlite")
        new = os.path.join(tempfile.gettempdir(), "ann.sqlite.new")

        # Save new
        model.save(index)

        # Save to same path
        model.save(index)

        # Delete id
        model.delete([0])

        # Save to another path
        model.load(index)
        model.save(new)

        self.assertEqual(model.count(), expected)

    def testTorch(self):
        """
        Test Torch backend
        """

        self.runTests("torch")

    @unittest.skipIf(platform.system() == "Darwin", "Torch quantization not supported on macOS")
    def testTorchQuantization(self):
        """
        Test Torch backend with quantization enabled
        """

        for qtype in ["fp4", "nf4", "int8"]:
            ann = ANNFactory.create({"backend": "torch", "torch": {"quantize": {"type": qtype}}})

            # Generate and index dummy data
            data = np.random.rand(100, 240).astype(np.float32)
            ann.index(data)

            # Test save and load
            index = os.path.join(tempfile.gettempdir(), f"{qtype}.safetensors")
            ann.save(index)
            ann.load(index)

            # Generate query vector and test search
            query = np.random.rand(240).astype(np.float32)
            self.normalize(query)
            self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)

            # Validate count
            self.assertEqual(ann.count(), 100)

            # Test delete
            ann.delete([0])
            self.assertEqual(ann.count(), 99)

    def runTests(self, name, params=None, update=True):
        """
        Runs a series of standard backend tests.

        Args:
            name: backend name
            params: additional config parameters
            update: If append/delete options should be tested
        """

        self.assertEqual(self.backend(name, params).config["backend"], name)
        self.assertEqual(self.save(name, params).count(), 10000)

        if update:
            self.assertEqual(self.append(name, params, 500).count(), 10500)
            self.assertEqual(self.delete(name, params, [0, 1]).count(), 9998)
            self.assertEqual(self.delete(name, params, [100000]).count(), 10000)

        self.assertGreater(self.search(name, params), 0)

    def backend(self, name, params=None, length=10000):
        """
        Test a backend.

        Args:
            name: backend name
            params: additional config parameters
            length: number of rows to generate

        Returns:
            ANN model
        """

        # Generate test data
        data = np.random.rand(length, 240).astype(np.float32)
        self.normalize(data)

        config = {"backend": name, "dimensions": data.shape[1]}
        if params:
            config.update(params)

        model = ANNFactory.create(config)
        model.index(data)

        return model

    def append(self, name, params=None, length=500):
        """
        Appends new data to index.

        Args:
            name: backend name
            params: additional config parameters
            length: number of rows to generate

        Returns:
            ANN model
        """

        # Initial model
        model = self.backend(name, params)

        # Generate test data
        data = np.random.rand(length, 240).astype(np.float32)
        self.normalize(data)

        model.append(data)

        return model

    def delete(self, name, params=None, ids=None):
        """
        Deletes data from index.

        Args:
            name: backend name
            params: additional config parameters
            ids: ids to delete

        Returns:
            ANN model
        """

        # Initial model
        model = self.backend(name, params)
        model.delete(ids)

        return model

    def save(self, name, params=None):
        """
        Test save/load.

        Args:
            name: backend name
            params: additional config parameters

        Returns:
            ANN model
        """

        model = self.backend(name, params)

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "ann")

        # Save and close index
        model.save(index)
        model.close()

        # Reload index
        model.load(index)

        return model

    def search(self, name, params=None):
        """
        Test ANN search.

        Args:
            name: backend name
            params: additional config parameters

        Returns:
            search results
        """

        # Generate ANN index
        model = self.backend(name, params)

        # Generate query vector
        query = np.random.rand(240).astype(np.float32)
        self.normalize(query)

        # Ensure top result has similarity > 0
        return model.search(np.array([query]), 1)[0][0][1]

    def normalize(self, embeddings):
        """
        Normalizes embeddings using L2 normalization. Operation applied directly on array.

        Args:
            embeddings: input embeddings matrix
        """

        # Calculation is different for matrices vs vectors
        if len(embeddings.shape) > 1:
            embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
        else:
            embeddings /= np.linalg.norm(embeddings)