/ test / python / testann / testdense.py
testdense.py
  1  """
  2  Dense ANN module tests
  3  """
  4  
  5  import os
  6  import platform
  7  import sys
  8  import tempfile
  9  import unittest
 10  
 11  from unittest.mock import patch
 12  
 13  import numpy as np
 14  
 15  from txtai.ann import ANNFactory, ANN
 16  from txtai.serialize import SerializeFactory
 17  
 18  
 19  # pylint: disable=R0904
 20  class TestDense(unittest.TestCase):
 21      """
 22      Dense ANN tests.
 23      """
 24  
 25      def testAnnoy(self):
 26          """
 27          Test Annoy backend
 28          """
 29  
 30          self.runTests("annoy", None, False)
 31  
 32      def testAnnoyCustom(self):
 33          """
 34          Test Annoy backend with custom settings
 35          """
 36  
 37          # Test with custom settings
 38          self.runTests("annoy", {"annoy": {"ntrees": 2, "searchk": 1}}, False)
 39  
 40      def testCustomBackend(self):
 41          """
 42          Test resolving a custom backend
 43          """
 44  
 45          self.runTests("txtai.ann.Faiss")
 46  
 47      def testCustomBackendNotFound(self):
 48          """
 49          Test resolving an unresolvable backend
 50          """
 51  
 52          with self.assertRaises(ImportError):
 53              ANNFactory.create({"backend": "notfound.ann"})
 54  
 55      def testFaiss(self):
 56          """
 57          Test Faiss backend
 58          """
 59  
 60          self.runTests("faiss")
 61  
    def testFaissBinary(self):
        """
        Test Faiss backend with a binary hash index

        Builds a Faiss index with binary quantization and a BHash32 component
        string, indexes packed byte vectors and checks the top search score is
        greater than zero.
        """

        # quantize=1 enables binary quantization; dimensions is the unpacked
        # bit count (240 bytes * 8 bits per byte)
        ann = ANNFactory.create({"backend": "faiss", "quantize": 1, "dimensions": 240 * 8, "faiss": {"components": "BHash32"}})

        # Generate and index dummy data
        # NOTE(review): np.random.rand yields floats in [0, 1), so astype(np.uint8)
        # truncates every value to 0 - this indexes all-zero byte vectors.
        # Confirm whether random bytes (np.random.randint(0, 256, ..., dtype=np.uint8))
        # were intended here.
        data = np.random.rand(100, 240).astype(np.uint8)
        ann.index(data)

        # Generate query vector and test search
        query = np.random.rand(240).astype(np.uint8)
        self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)
 76  
 77      def testFaissCustom(self):
 78          """
 79          Test Faiss backend with custom settings
 80          """
 81  
 82          # Test with custom settings
 83          self.runTests("faiss", {"faiss": {"nprobe": 2, "components": "PCA16,IDMap,SQ8", "sample": 1.0}}, False)
 84          self.runTests("faiss", {"faiss": {"components": "IVF,SQ8"}}, False)
 85  
 86      @patch("platform.system")
 87      def testFaissMacOS(self, system):
 88          """
 89          Test Faiss backend with macOS
 90          """
 91  
 92          # Run test
 93          system.return_value = "Darwin"
 94  
 95          # pylint: disable=C0415, W0611
 96          # Force reload of class
 97          name = "txtai.ann.dense.faiss"
 98          module = sys.modules[name]
 99          del sys.modules[name]
100          import txtai.ann.dense.faiss
101  
102          # Run tests
103          self.runTests("faiss")
104  
105          # Restore original module
106          sys.modules[name] = module
107  
108      @unittest.skipIf(os.name == "nt", "mmap not supported on Windows")
109      def testFaissMmap(self):
110          """
111          Test Faiss backend with mmap enabled
112          """
113  
114          # Test to with mmap enabled
115          self.runTests("faiss", {"faiss": {"mmap": True}}, False)
116  
117      def testGGML(self):
118          """
119          Test GGML backend
120          """
121  
122          self.runTests("ggml")
123  
124      def testGGMLQuantization(self):
125          """
126          Test GGML backend with quantization enabled
127          """
128  
129          ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "Q4_0"}})
130  
131          # Generate and index dummy data
132          data = np.random.rand(100, 256).astype(np.float32)
133          ann.index(data)
134  
135          # Test save and load
136          index = os.path.join(tempfile.gettempdir(), "ggml.q4_0.v1")
137          ann.save(index)
138          ann.load(index)
139  
140          # Generate query vector and test search
141          query = np.random.rand(256).astype(np.float32)
142          self.normalize(query)
143          self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)
144  
145          # Validate count
146          self.assertEqual(ann.count(), 100)
147  
148          # Test delete
149          ann.delete([0])
150          self.assertEqual(ann.count(), 99)
151  
152          # Save updated index with deletes and reload
153          index = os.path.join(tempfile.gettempdir(), "ggml.q4_0.v2")
154          ann.save(index)
155          ann.load(index)
156          ann.index(data)
157  
158      def testGGMLInvalid(self):
159          """
160          Test invalid GGML configurations
161          """
162  
163          data = np.random.rand(100, 240).astype(np.float32)
164  
165          with self.assertRaises(ValueError):
166              ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "NOEXIST", "gpu": False}})
167              ann.index(data)
168  
169          with self.assertRaises(ValueError):
170              ann = ANNFactory.create({"backend": "ggml", "ggml": {"quantize": "Q4_K"}})
171              ann.index(data)
172  
173      def testHnsw(self):
174          """
175          Test Hnswlib backend
176          """
177  
178          self.runTests("hnsw")
179  
180      def testHnswCustom(self):
181          """
182          Test Hnswlib backend with custom settings
183          """
184  
185          # Test with custom settings
186          self.runTests("hnsw", {"hnsw": {"efconstruction": 100, "m": 4, "randomseed": 0, "efsearch": 5}})
187  
188      def testNotImplemented(self):
189          """
190          Test exceptions for non-implemented methods
191          """
192  
193          ann = ANN({})
194  
195          self.assertRaises(NotImplementedError, ann.load, None)
196          self.assertRaises(NotImplementedError, ann.index, None)
197          self.assertRaises(NotImplementedError, ann.append, None)
198          self.assertRaises(NotImplementedError, ann.delete, None)
199          self.assertRaises(NotImplementedError, ann.search, None, None)
200          self.assertRaises(NotImplementedError, ann.count)
201          self.assertRaises(NotImplementedError, ann.save, None)
202  
203      def testNumPy(self):
204          """
205          Test NumPy backend
206          """
207  
208          self.runTests("numpy")
209  
210      @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
211      def testNumPyLegacy(self):
212          """
213          Test NumPy backend with legacy pickled data
214          """
215  
216          serializer = SerializeFactory.create("pickle", allowpickle=True)
217  
218          # Create output directory
219          output = os.path.join(tempfile.gettempdir(), "ann.npy")
220          path = os.path.join(output, "embeddings")
221          os.makedirs(output, exist_ok=True)
222  
223          # Generate data and save as pickle
224          data = np.random.rand(100, 240).astype(np.float32)
225          serializer.save(data, path)
226  
227          ann = ANNFactory.create({"backend": "numpy"})
228          ann.load(path)
229  
230          # Validate count
231          self.assertEqual(ann.count(), 100)
232  
233      def testNumPySafetensors(self):
234          """
235          Test NumPy backend with safetensors storage
236          """
237  
238          ann = ANNFactory.create({"backend": "numpy", "numpy": {"safetensors": True}})
239  
240          # Generate and index dummy data
241          data = np.random.rand(100, 240).astype(np.float32)
242          ann.index(data)
243  
244          # Test save and load
245          index = os.path.join(tempfile.gettempdir(), "numpy.safetensors")
246          ann.save(index)
247          ann.load(index)
248  
249          # Generate query vector and test search
250          query = np.random.rand(240).astype(np.float32)
251          self.normalize(query)
252          self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)
253  
254          # Validate count
255          self.assertEqual(ann.count(), 100)
256  
    @patch("sqlalchemy.orm.Query.limit")
    def testPGVector(self, query):
        """
        Test PGVector backend

        Runs the pgvector backend against a SQLite database via SQLAlchemy for
        full, half and binary precision configurations. Search results are
        mocked at the SQLAlchemy query level, so search scoring is not exercised
        against a real database - only indexing, count and delete are.

        Args:
            query: mocked sqlalchemy.orm.Query.limit
        """

        # Generate test record
        data = np.random.rand(1, 240).astype(np.float32)

        # Mock database query
        # Returns one (id, value) row per record; -1.0 presumably maps to a
        # final score of 1.0 in the assertion below - verify against the
        # backend's score calculation
        query.return_value = [(x, -1.0) for x in range(data.shape[0])]

        # (name, base config, pgvector-specific config, data) per precision mode
        configs = [
            ("full", {"dimensions": 240}, {}, data),
            ("half", {"dimensions": 240}, {"precision": "half"}, data),
            # NOTE(review): astype(np.uint8) truncates [0, 1) floats to all
            # zeros - confirm whether random bytes were intended for binary mode
            ("binary", {"quantize": 1, "dimensions": 240 * 8}, {}, data.astype(np.uint8)),
        ]

        # Create ANN
        # Note: the loop deliberately rebinds data to the per-config array
        for name, config, pgvector, data in configs:
            path = os.path.join(tempfile.gettempdir(), f"pgvector.{name}.sqlite")
            ann = ANNFactory.create(
                {**{"backend": "pgvector", "pgvector": {**{"url": f"sqlite:///{path}", "schema": "txtai"}, **pgvector}}, **config}
            )

            # Test indexing
            ann.index(data)
            ann.append(data)

            # Validate search results
            self.assertEqual(ann.search(data, 1), [[(0, 1.0)]])

            # Validate save/load/delete
            ann.save(None)
            ann.load(None)

            # Validate count (index + append = 2 rows)
            self.assertEqual(ann.count(), 2)

            # Test delete
            ann.delete([0])
            self.assertEqual(ann.count(), 1)

            # Close ANN
            ann.close()
302  
303      @unittest.skipIf(platform.system() == "Darwin", "SQLite extensions not supported on macOS")
304      def testSQLite(self):
305          """
306          Test SQLite backend
307          """
308  
309          self.runTests("sqlite")
310  
    @unittest.skipIf(platform.system() == "Darwin", "SQLite extensions not supported on macOS")
    def testSQLiteCustom(self):
        """
        Test SQLite backend with custom settings

        Exercises quantization settings, then validates save behavior across
        repeated saves to the same path, a delete and a save to a new path.
        """

        # Test with custom settings
        self.runTests("sqlite", {"sqlite": {"quantize": 1}})
        self.runTests("sqlite", {"sqlite": {"quantize": 8}})

        # Test saving to a new path
        model = self.backend("sqlite")

        # Expected count after the single-id delete below
        expected = model.count() - 1

        # Test save variations
        index = os.path.join(tempfile.gettempdir(), "ann.sqlite")
        new = os.path.join(tempfile.gettempdir(), "ann.sqlite.new")

        # Save new
        model.save(index)

        # Save to same path
        model.save(index)

        # Delete id
        # NOTE(review): the reload below still expects the reduced count, which
        # implies this delete persists to the saved database file - confirm the
        # sqlite backend operates directly on the file at `index`
        model.delete([0])

        # Save to another path
        model.load(index)
        model.save(new)

        self.assertEqual(model.count(), expected)
343  
344      def testTorch(self):
345          """
346          Test Torch backend
347          """
348  
349          self.runTests("torch")
350  
351      @unittest.skipIf(platform.system() == "Darwin", "Torch quantization not supported on macOS")
352      def testTorchQuantization(self):
353          """
354          Test Torch backend with quantization enabled
355          """
356  
357          for qtype in ["fp4", "nf4", "int8"]:
358              ann = ANNFactory.create({"backend": "torch", "torch": {"quantize": {"type": qtype}}})
359  
360              # Generate and index dummy data
361              data = np.random.rand(100, 240).astype(np.float32)
362              ann.index(data)
363  
364              # Test save and load
365              index = os.path.join(tempfile.gettempdir(), f"{qtype}.safetensors")
366              ann.save(index)
367              ann.load(index)
368  
369              # Generate query vector and test search
370              query = np.random.rand(240).astype(np.float32)
371              self.normalize(query)
372              self.assertGreater(ann.search(np.array([query]), 1)[0][0][1], 0)
373  
374              # Validate count
375              self.assertEqual(ann.count(), 100)
376  
377              # Test delete
378              ann.delete([0])
379              self.assertEqual(ann.count(), 99)
380  
381      def runTests(self, name, params=None, update=True):
382          """
383          Runs a series of standard backend tests.
384  
385          Args:
386              name: backend name
387              params: additional config parameters
388              update: If append/delete options should be tested
389          """
390  
391          self.assertEqual(self.backend(name, params).config["backend"], name)
392          self.assertEqual(self.save(name, params).count(), 10000)
393  
394          if update:
395              self.assertEqual(self.append(name, params, 500).count(), 10500)
396              self.assertEqual(self.delete(name, params, [0, 1]).count(), 9998)
397              self.assertEqual(self.delete(name, params, [100000]).count(), 10000)
398  
399          self.assertGreater(self.search(name, params), 0)
400  
401      def backend(self, name, params=None, length=10000):
402          """
403          Test a backend.
404  
405          Args:
406              name: backend name
407              params: additional config parameters
408              length: number of rows to generate
409  
410          Returns:
411              ANN model
412          """
413  
414          # Generate test data
415          data = np.random.rand(length, 240).astype(np.float32)
416          self.normalize(data)
417  
418          config = {"backend": name, "dimensions": data.shape[1]}
419          if params:
420              config.update(params)
421  
422          model = ANNFactory.create(config)
423          model.index(data)
424  
425          return model
426  
427      def append(self, name, params=None, length=500):
428          """
429          Appends new data to index.
430  
431          Args:
432              name: backend name
433              params: additional config parameters
434              length: number of rows to generate
435  
436          Returns:
437              ANN model
438          """
439  
440          # Initial model
441          model = self.backend(name, params)
442  
443          # Generate test data
444          data = np.random.rand(length, 240).astype(np.float32)
445          self.normalize(data)
446  
447          model.append(data)
448  
449          return model
450  
451      def delete(self, name, params=None, ids=None):
452          """
453          Deletes data from index.
454  
455          Args:
456              name: backend name
457              params: additional config parameters
458              ids: ids to delete
459  
460          Returns:
461              ANN model
462          """
463  
464          # Initial model
465          model = self.backend(name, params)
466          model.delete(ids)
467  
468          return model
469  
470      def save(self, name, params=None):
471          """
472          Test save/load.
473  
474          Args:
475              name: backend name
476              params: additional config parameters
477  
478          Returns:
479              ANN model
480          """
481  
482          model = self.backend(name, params)
483  
484          # Generate temp file path
485          index = os.path.join(tempfile.gettempdir(), "ann")
486  
487          # Save and close index
488          model.save(index)
489          model.close()
490  
491          # Reload index
492          model.load(index)
493  
494          return model
495  
496      def search(self, name, params=None):
497          """
498          Test ANN search.
499  
500          Args:
501              name: backend name
502              params: additional config parameters
503  
504          Returns:
505              search results
506          """
507  
508          # Generate ANN index
509          model = self.backend(name, params)
510  
511          # Generate query vector
512          query = np.random.rand(240).astype(np.float32)
513          self.normalize(query)
514  
515          # Ensure top result has similarity > 0
516          return model.search(np.array([query]), 1)[0][0][1]
517  
518      def normalize(self, embeddings):
519          """
520          Normalizes embeddings using L2 normalization. Operation applied directly on array.
521  
522          Args:
523              embeddings: input embeddings matrix
524          """
525  
526          # Calculation is different for matrices vs vectors
527          if len(embeddings.shape) > 1:
528              embeddings /= np.linalg.norm(embeddings, axis=1)[:, np.newaxis]
529          else:
530              embeddings /= np.linalg.norm(embeddings)