# testsparse.py
"""
Sparse ANN module tests
"""

import os
import tempfile
import unittest

from unittest.mock import patch

from scipy.sparse import random
from sklearn.preprocessing import normalize

from txtai.ann import SparseANNFactory


class TestSparse(unittest.TestCase):
    """
    Sparse ANN tests.
    """

    def testCustomBackend(self):
        """
        Test resolving a custom backend
        """

        # A fully-qualified class path should resolve to a usable backend instance
        backend = SparseANNFactory.create({"backend": "txtai.ann.IVFSparse"})
        self.assertIsNotNone(backend)

    def testCustomBackendNotFound(self):
        """
        Test resolving an unresolvable backend
        """

        # An unknown module path must raise an ImportError
        with self.assertRaises(ImportError):
            SparseANNFactory.create({"backend": "notfound.ann"})

    def testIVFSparse(self):
        """
        Test IVFSparse backend
        """

        # Two batches of random normalized sparse vectors
        base = self.generate(500, 30522)
        extra = self.generate(500, 30522)

        # Expected number of indexed records after both batches
        total = base.shape[0] + extra.shape[0]

        # Create ANN
        path = os.path.join(tempfile.gettempdir(), "ivfsparse")
        ann = SparseANNFactory.create({"backend": "ivfsparse", "ivfsparse": {"nlist": 2, "nprobe": 2, "sample": 1.0}})

        # Index the first batch, then append the second
        ann.index(base)
        ann.append(extra)

        # Searching with an indexed vector should find its own id
        matches = [result[0] for result in ann.search(base[5], 10)[0]]
        self.assertIn(5, matches)

        # Round-trip the index through disk
        ann.save(path)
        ann.load(path)

        # Count should reflect both batches
        self.assertEqual(ann.count(), total)

        # Deleting an id reduces the count by one
        ann.delete([0])
        self.assertEqual(ann.count(), total - 1)

        # Appended records are offset by the size of the first batch
        matches = [result[0] for result in ann.search(extra[0], 10)[0]]
        self.assertIn(base.shape[0], matches)

        # Close ANN
        ann.close()

        # Requesting more clusters than the data supports prunes down the block list
        ann = SparseANNFactory.create({"backend": "ivfsparse", "ivfsparse": {"nlist": 15, "nprobe": 1, "sample": 1.0}})
        ann.index(base)
        self.assertLess(len(ann.blocks), 15)
        ann.close()

    def testIVFSparseTopnOverLimit(self):
        """
        Test IVFSparse topn when limit exceeds the number of indexed documents
        """

        # Tiny dataset: 5 documents
        vectors = self.generate(5, 30522)

        ann = SparseANNFactory.create({"backend": "ivfsparse"})
        ann.index(vectors)

        # Single query with limit (10) above document count (5) still returns hits
        hits = ann.search(vectors[0], 10)
        self.assertGreater(len(hits[0]), 0)

        # Batch query: one non-empty result list per query vector
        batch = ann.search(vectors, 10)
        self.assertEqual(len(batch), vectors.shape[0])
        for hits in batch:
            self.assertGreater(len(hits), 0)

        ann.close()

    @patch("sqlalchemy.orm.Query.limit")
    def testPGSparse(self, query):
        """
        Test Sparse Postgres backend
        """

        # Single random normalized sparse vector
        vector = self.generate(1, 30522)

        # Mock the database query to return fixed (id, score) rows
        query.return_value = [(uid, -1.0) for uid in range(vector.shape[0])]

        # Create ANN backed by a local SQLite database
        path = os.path.join(tempfile.gettempdir(), "pgsparse.sqlite")
        ann = SparseANNFactory.create({"backend": "pgsparse", "dimensions": 30522, "pgsparse": {"url": f"sqlite:///{path}", "schema": "txtai"}})

        # Index the vector, then append it a second time
        ann.index(vector)
        ann.append(vector)

        # Mocked query yields a single exact match
        self.assertEqual(ann.search(vector, 1), [[(0, 1.0)]])

        # Save/load with no path (database-backed index)
        ann.save(None)
        ann.load(None)

        # Both copies should be counted
        self.assertEqual(ann.count(), 2)

        # Deleting one id leaves a single record
        ann.delete([0])
        self.assertEqual(ann.count(), 1)

        # Test > 1000 dimensions
        dense = random(1, 30522, format="csr", density=0.1)
        ann.index(dense)
        self.assertEqual(ann.count(), 1)

        # Close ANN
        ann.close()

    def generate(self, m, n):
        """
        Generates random normalized sparse data.

        Args:
            m, n: shape of the matrix

        Returns:
            csr matrix
        """

        # Random CSR matrix, normalized before returning
        return normalize(random(m, n, format="csr"))