testsparse.py
"""
Sparse module tests
"""

import os
import platform
import tempfile
import unittest

from unittest.mock import patch

from txtai.scoring import ScoringFactory


# pylint: disable=R0904
class TestSparse(unittest.TestCase):
    """
    Sparse vector scoring tests.

    Every test builds a "sparse" scoring instance through ScoringFactory using the
    sparse-encoder-testing/splade-bert-tiny-nq model path and closes each instance it creates.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.

        Stores the sample documents on the class as (uid, text, tags) tuples, where uid is the
        position of the text in the list and tags is always None.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        cls.data = [(uid, x, None) for uid, x in enumerate(cls.data)]

    def testGeneral(self):
        """
        Test general sparse vector operations

        Covers index, search, batchsearch, count, delete, the scoring type flags, weights and
        re-creating an instance with a shared models cache.
        """

        # Models cache
        models = {}

        # Test sparse scoring
        scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"}, models=models)
        scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)

        # Run search and validate correct result returned
        index, _ = scoring.search("lottery ticket", 1)[0]
        self.assertEqual(index, 4)

        # Run batch search
        index, _ = scoring.batchsearch(["lottery ticket"], 1)[0][0]
        self.assertEqual(index, 4)

        # Validate count
        self.assertEqual(scoring.count(), len(self.data))

        # Test delete
        scoring.delete([4])
        self.assertEqual(scoring.count(), len(self.data) - 1)

        # Run search after delete, uid 4 must no longer be the top result
        index, _ = scoring.search("lottery ticket", 1)[0]
        self.assertEqual(index, 5)

        # Sparse scoring reports itself as sparse and normalized but not Bayesian, and returns no term weights
        self.assertTrue(scoring.issparse() and scoring.isnormalized() and not scoring.isbayes())
        self.assertIsNone(scoring.weights("This is a test".split()))

        # Close scoring
        scoring.close()

        # Test model caching: a second instance created with the same models dict must still have a model
        scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"}, models=models)
        self.assertIsNotNone(scoring.model)
        scoring.close()

    def testEmpty(self):
        """
        Test upsert into an empty sparse index

        The instance is never indexed first, so upsert is the call that adds the initial data.
        """

        scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq"})
        scoring.upsert((uid, {"text": text}, tags) for uid, text, tags in self.data)
        self.assertEqual(scoring.count(), len(self.data))

        # Close scoring, consistent with the other tests in this class
        scoring.close()

    @unittest.skipIf(platform.system() == "Darwin", "Torch memory sharing not supported on macOS")
    @patch("torch.cuda.device_count")
    def testGPU(self, count):
        """
        Test sparse vectors with GPU encoding

        Args:
            count: mock injected by patch that replaces torch.cuda.device_count
        """

        # Mock accelerator count
        count.return_value = 2

        # Test multiple gpus
        scoring = ScoringFactory.create({"method": "sparse", "path": "sparse-encoder-testing/splade-bert-tiny-nq", "gpu": "all"})
        self.assertIsNotNone(scoring)
        scoring.close()

    def testBayes(self):
        """
        Test BB25 Bayesian normalization for sparse scoring
        """

        config = {
            "method": "sparse",
            "path": "sparse-encoder-testing/splade-bert-tiny-nq",
            "normalize": "bb25",
        }
        scoring = ScoringFactory.create(config)
        scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)

        # Verify Bayesian mode flags
        self.assertTrue(scoring.isbayes())
        self.assertTrue(scoring.isnormalized())

        # Search and validate scores are in [0, 1]
        results = scoring.search("lottery ticket", 3)
        self.assertGreater(len(results), 0)
        for _, score in results:
            self.assertGreaterEqual(score, 0.0)
            self.assertLessEqual(score, 1.0)

        # Batch search, one result list per query with the same score range
        results = scoring.batchsearch(["lottery ticket", "ice shelf"], 3)
        self.assertEqual(len(results), 2)
        for query_results in results:
            for _, score in query_results:
                self.assertGreaterEqual(score, 0.0)
                self.assertLessEqual(score, 1.0)

        scoring.close()

    def testBayesDict(self):
        """
        Test BB25 normalization with dict config
        """

        config = {
            "method": "sparse",
            "path": "sparse-encoder-testing/splade-bert-tiny-nq",
            "normalize": {"method": "bb25", "alpha": 2.0},
        }
        scoring = ScoringFactory.create(config)
        scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)

        self.assertTrue(scoring.isbayes())

        # Scores must stay in [0, 1] with the dict form of the normalize setting
        results = scoring.search("lottery ticket", 3)
        self.assertGreater(len(results), 0)
        for _, score in results:
            self.assertGreaterEqual(score, 0.0)
            self.assertLessEqual(score, 1.0)

        scoring.close()

    def testBayesNonBayes(self):
        """
        Test that non-Bayesian string normalize values do not activate Bayesian mode
        """

        config = {
            "method": "sparse",
            "path": "sparse-encoder-testing/splade-bert-tiny-nq",
            "normalize": "default",
        }
        scoring = ScoringFactory.create(config)
        self.assertFalse(scoring.isbayes())
        scoring.close()

    def testIVFFlat(self):
        """
        Test sparse vectors with IVFFlat clustering

        Indexes an expanded dataset, saves it, reloads it into a new instance and searches it.
        """

        # Expand dataset
        data = self.data * 1000

        # Test higher volume IVFFlat index with clustering
        config = {
            "method": "sparse",
            "vectormethod": "sentence-transformers",
            "path": "sparse-encoder-testing/splade-bert-tiny-nq",
            "ivfsparse": {"sample": 1.0},
        }
        scoring = ScoringFactory.create(config)
        scoring.index((uid, {"text": text}, tags) for uid, text, tags in data)

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "scoring")
        os.makedirs(index, exist_ok=True)
        path = os.path.join(index, "scoring.sparse.index")

        # Save scoring instance, then close it before the name is rebound to the reloaded instance
        scoring.save(path)
        scoring.close()

        # Reload scoring instance
        scoring = ScoringFactory.create(config)
        scoring.load(path)

        # Run search and validate results are returned
        results = scoring.search("lottery ticket", 1)
        self.assertGreater(len(results), 0)
        scoring.close()