/ test / python / testann / testsparse.py
testsparse.py
  1  """
  2  Sparse ANN module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  from scipy.sparse import random
 12  from sklearn.preprocessing import normalize
 13  
 14  from txtai.ann import SparseANNFactory
 15  
 16  
 17  class TestSparse(unittest.TestCase):
 18      """
 19      Sparse ANN tests.
 20      """
 21  
 22      def testCustomBackend(self):
 23          """
 24          Test resolving a custom backend
 25          """
 26  
 27          self.assertIsNotNone(SparseANNFactory.create({"backend": "txtai.ann.IVFSparse"}))
 28  
 29      def testCustomBackendNotFound(self):
 30          """
 31          Test resolving an unresolvable backend
 32          """
 33  
 34          with self.assertRaises(ImportError):
 35              SparseANNFactory.create({"backend": "notfound.ann"})
 36  
 37      def testIVFSparse(self):
 38          """
 39          Test IVFSparse backend
 40          """
 41  
 42          # Generate test record
 43          insert = self.generate(500, 30522)
 44          append = self.generate(500, 30522)
 45  
 46          # Count of records
 47          count = insert.shape[0] + append.shape[0]
 48  
 49          # Create ANN
 50          path = os.path.join(tempfile.gettempdir(), "ivfsparse")
 51          ann = SparseANNFactory.create({"backend": "ivfsparse", "ivfsparse": {"nlist": 2, "nprobe": 2, "sample": 1.0}})
 52  
 53          # Test indexing
 54          ann.index(insert)
 55          ann.append(append)
 56  
 57          # Validate search results
 58          results = [x[0] for x in ann.search(insert[5], 10)[0]]
 59          self.assertIn(5, results)
 60  
 61          # Validate save/load/delete
 62          ann.save(path)
 63          ann.load(path)
 64  
 65          # Validate count
 66          self.assertEqual(ann.count(), count)
 67  
 68          # Test delete
 69          ann.delete([0])
 70          self.assertEqual(ann.count(), count - 1)
 71  
 72          # Re-validate search results
 73          results = [x[0] for x in ann.search(append[0], 10)[0]]
 74          self.assertIn(insert.shape[0], results)
 75  
 76          # Close ANN
 77          ann.close()
 78  
 79          # Test cluster pruning
 80          ann = SparseANNFactory.create({"backend": "ivfsparse", "ivfsparse": {"nlist": 15, "nprobe": 1, "sample": 1.0}})
 81          ann.index(insert)
 82          self.assertLess(len(ann.blocks), 15)
 83          ann.close()
 84  
 85      def testIVFSparseTopnOverLimit(self):
 86          """
 87          Test IVFSparse topn when limit exceeds the number of indexed documents
 88          """
 89  
 90          # Generate a small dataset (5 documents)
 91          data = self.generate(5, 30522)
 92  
 93          ann = SparseANNFactory.create({"backend": "ivfsparse"})
 94          ann.index(data)
 95  
 96          # Search with limit (10) greater than document count (5)
 97          results = ann.search(data[0], 10)
 98          self.assertGreater(len(results[0]), 0)
 99  
100          # Batch search with multiple queries exceeding document count
101          results = ann.search(data, 10)
102          self.assertEqual(len(results), data.shape[0])
103          for result in results:
104              self.assertGreater(len(result), 0)
105  
106          ann.close()
107  
108      @patch("sqlalchemy.orm.Query.limit")
109      def testPGSparse(self, query):
110          """
111          Test Sparse Postgres backend
112          """
113  
114          # Generate test record
115          data = self.generate(1, 30522)
116  
117          # Mock database query
118          query.return_value = [(x, -1.0) for x in range(data.shape[0])]
119  
120          # Create ANN
121          path = os.path.join(tempfile.gettempdir(), "pgsparse.sqlite")
122          ann = SparseANNFactory.create({"backend": "pgsparse", "dimensions": 30522, "pgsparse": {"url": f"sqlite:///{path}", "schema": "txtai"}})
123  
124          # Test indexing
125          ann.index(data)
126          ann.append(data)
127  
128          # Validate search results
129          self.assertEqual(ann.search(data, 1), [[(0, 1.0)]])
130  
131          # Validate save/load/delete
132          ann.save(None)
133          ann.load(None)
134  
135          # Validate count
136          self.assertEqual(ann.count(), 2)
137  
138          # Test delete
139          ann.delete([0])
140          self.assertEqual(ann.count(), 1)
141  
142          # Test > 1000 dimensions
143          data = random(1, 30522, format="csr", density=0.1)
144          ann.index(data)
145          self.assertEqual(ann.count(), 1)
146  
147          # Close ANN
148          ann.close()
149  
150      def generate(self, m, n):
151          """
152          Generates random normalized sparse data.
153  
154          Args:
155              m, n: shape of the matrix
156  
157          Returns:
158              csr matrix
159          """
160  
161          # Generate random csr matrix
162          data = random(m, n, format="csr")
163  
164          # Normalize and return
165          return normalize(data)