/ test / python / testscoring / testsparse.py
testsparse.py
  1  """
  2  Sparse module tests
  3  """
  4  
  5  import os
  6  import platform
  7  import tempfile
  8  import unittest
  9  
 10  from unittest.mock import patch
 11  
 12  from txtai.scoring import ScoringFactory
 13  
 14  
 15  # pylint: disable=R0904
 16  class TestSparse(unittest.TestCase):
 17      """
 18      Sparse vector scoring tests.
 19      """
 20  
 21      @classmethod
 22      def setUpClass(cls):
 23          """
 24          Initialize test data.
 25          """
 26  
 27          cls.data = [
 28              "US tops 5 million confirmed virus cases",
 29              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 30              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 31              "The National Park Service warns against sacrificing slower friends in a bear attack",
 32              "Maine man wins $1M from $25 lottery ticket",
 33              "Make huge profits without work, earn up to $100,000 a day",
 34          ]
 35  
 36          cls.data = [(uid, x, None) for uid, x in enumerate(cls.data)]
 37  
 38      def testGeneral(self):
 39          """
 40          Test general sparse vector operations
 41          """
 42  
 43          # Models cache
 44          models = {}
 45  
 46          # Test sparse scoring
 47          scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"}, models=models)
 48          scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)
 49  
 50          # Run search and validate correct result returned
 51          index, _ = scoring.search("lottery ticket", 1)[0]
 52          self.assertEqual(index, 4)
 53  
 54          # Run batch search
 55          index, _ = scoring.batchsearch(["lottery ticket"], 1)[0][0]
 56          self.assertEqual(index, 4)
 57  
 58          # Validate count
 59          self.assertEqual(scoring.count(), len(self.data))
 60  
 61          # Test delete
 62          scoring.delete([4])
 63          self.assertEqual(scoring.count(), len(self.data) - 1)
 64  
 65          # Run search after delete
 66          index, _ = scoring.search("lottery ticket", 1)[0]
 67          self.assertEqual(index, 5)
 68  
 69          # Sparse vectors is a normalized sparse index
 70          self.assertTrue(scoring.issparse() and scoring.isnormalized() and not scoring.isbayes())
 71          self.assertIsNone(scoring.weights("This is a test".split()))
 72  
 73          # Close scoring
 74          scoring.close()
 75  
 76          # Test model caching
 77          scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"}, models=models)
 78          self.assertIsNotNone(scoring.model)
 79          scoring.close()
 80  
 81      def testEmpty(self):
 82          """
 83          Test empty sparse vectors
 84          """
 85  
 86          scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"})
 87          scoring.upsert((uid, {"text": text}, tags) for uid, text, tags in self.data)
 88          self.assertEqual(scoring.count(), len(self.data))
 89  
 90      @unittest.skipIf(platform.system() == "Darwin", "Torch memory sharing not supported on macOS")
 91      @patch("torch.cuda.device_count")
 92      def testGPU(self, count):
 93          """
 94          Test sparse vectors with GPU encoding
 95          """
 96  
 97          # Mock accelerator count
 98          count.return_value = 2
 99  
100          # Test multiple gpus
101          scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq", "gpu": "all"})
102          self.assertIsNotNone(scoring)
103          scoring.close()
104  
105      def testBayes(self):
106          """
107          Test BB25 Bayesian normalization for sparse scoring
108          """
109  
110          config = {
111              "method": "sparse",
112              "path": "sparse-encoder-testing/splade-bert-tiny-nq",
113              "normalize": "bb25",
114          }
115          scoring = ScoringFactory.create(config)
116          scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)
117  
118          # Verify Bayesian mode flags
119          self.assertTrue(scoring.isbayes())
120          self.assertTrue(scoring.isnormalized())
121  
122          # Search and validate scores are calibrated probabilities in [0, 1]
123          results = scoring.search("lottery ticket", 3)
124          self.assertGreater(len(results), 0)
125          for _, score in results:
126              self.assertGreaterEqual(score, 0.0)
127              self.assertLessEqual(score, 1.0)
128  
129          # Batch search
130          results = scoring.batchsearch(["lottery ticket", "ice shelf"], 3)
131          self.assertEqual(len(results), 2)
132          for query_results in results:
133              for _, score in query_results:
134                  self.assertGreaterEqual(score, 0.0)
135                  self.assertLessEqual(score, 1.0)
136  
137          scoring.close()
138  
139      def testBayesDict(self):
140          """
141          Test BB25 normalization with dict config
142          """
143  
144          config = {
145              "method": "sparse",
146              "path": "sparse-encoder-testing/splade-bert-tiny-nq",
147              "normalize": {"method": "bb25", "alpha": 2.0},
148          }
149          scoring = ScoringFactory.create(config)
150          scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)
151  
152          self.assertTrue(scoring.isbayes())
153  
154          results = scoring.search("lottery ticket", 3)
155          self.assertGreater(len(results), 0)
156          for _, score in results:
157              self.assertGreaterEqual(score, 0.0)
158              self.assertLessEqual(score, 1.0)
159  
160          scoring.close()
161  
162      def testBayesNonBayes(self):
163          """
164          Test that non-Bayesian string normalize values do not activate Bayesian mode
165          """
166  
167          config = {
168              "method": "sparse",
169              "path": "sparse-encoder-testing/splade-bert-tiny-nq",
170              "normalize": "default",
171          }
172          scoring = ScoringFactory.create(config)
173          self.assertFalse(scoring.isbayes())
174          scoring.close()
175  
176      def testIVFFlat(self):
177          """
178          Test sparse vectors with IVFFlat clustering
179          """
180  
181          # Expand dataset
182          data = self.data * 1000
183  
184          # Test higher volume IVFFlat index with clustering
185          config = {
186              "method": "sparse",
187              "vectormethod": "sentence-transformers",
188              "path": "sparse-encoder-testing/splade-bert-tiny-nq",
189              "ivfsparse": {"sample": 1.0},
190          }
191          scoring = ScoringFactory.create(config)
192          scoring.index((uid, {"text": text}, tags) for uid, text, tags in data)
193  
194          # Generate temp file path
195          index = os.path.join(tempfile.gettempdir(), "scoring")
196          os.makedirs(index, exist_ok=True)
197  
198          # Save scoring instance
199          scoring.save(f"{index}/scoring.sparse.index")
200  
201          # Reload scoring instance
202          scoring = ScoringFactory.create(config)
203          scoring.load(f"{index}/scoring.sparse.index")
204  
205          # Run search and validate correct result returned
206          results = scoring.search("lottery ticket", 1)
207          self.assertGreater(len(results), 0)
208          scoring.close()