/ test / python / testscoring / testkeyword.py
testkeyword.py
  1  """
  2  Keyword scoring tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  from txtai.scoring import Normalize, ScoringFactory, Scoring
 12  
 13  
 14  # pylint: disable=R0904
 15  class TestKeyword(unittest.TestCase):
 16      """
 17      Sparse keyword scoring tests.
 18      """
 19  
 20      @classmethod
 21      def setUpClass(cls):
 22          """
 23          Initialize test data.
 24          """
 25  
 26          cls.data = [
 27              "US tops 5 million confirmed virus cases",
 28              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 29              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 30              "The National Park Service warns against sacrificing slower friends in a bear attack",
 31              "Maine man wins $1M from $25 lottery ticket",
 32              "wins wins wins",
 33              "Make huge profits without work, earn up to $100,000 a day",
 34          ]
 35  
 36          cls.data = [(uid, x, None) for uid, x in enumerate(cls.data)]
 37  
 38      def testBM25(self):
 39          """
 40          Test bm25
 41          """
 42  
 43          self.runTests("bm25")
 44  
 45      def testCustom(self):
 46          """
 47          Test custom method
 48          """
 49  
 50          self.runTests("txtai.scoring.BM25")
 51  
 52      def testCustomNotFound(self):
 53          """
 54          Test unresolvable custom method
 55          """
 56  
 57          with self.assertRaises(ImportError):
 58              ScoringFactory.create("notfound.scoring")
 59  
 60      def testNotImplemented(self):
 61          """
 62          Test exceptions for non-implemented methods
 63          """
 64  
 65          scoring = Scoring()
 66  
 67          self.assertRaises(NotImplementedError, scoring.insert, None, None)
 68          self.assertRaises(NotImplementedError, scoring.delete, None)
 69          self.assertRaises(NotImplementedError, scoring.weights, None)
 70          self.assertRaises(NotImplementedError, scoring.search, None, None)
 71          self.assertRaises(NotImplementedError, scoring.batchsearch, None, None, None)
 72          self.assertRaises(NotImplementedError, scoring.count)
 73          self.assertRaises(NotImplementedError, scoring.load, None)
 74          self.assertRaises(NotImplementedError, scoring.save, None)
 75          self.assertRaises(NotImplementedError, scoring.close)
 76          self.assertRaises(NotImplementedError, scoring.issparse)
 77          self.assertRaises(NotImplementedError, scoring.isnormalized)
 78          self.assertRaises(NotImplementedError, scoring.isbayes)
 79  
 80      @patch("sqlalchemy.orm.Query.params")
 81      def testPGText(self, query):
 82          """
 83          Test PGText
 84          """
 85  
 86          # Mock database query
 87          query.return_value = [(3, 1.0)]
 88  
 89          # Create scoring
 90          path = os.path.join(tempfile.gettempdir(), "pgtext.sqlite")
 91          scoring = ScoringFactory.create({"method": "pgtext", "url": f"sqlite:///{path}", "schema": "txtai"})
 92          scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)
 93  
 94          # Run search and validate correct result returned
 95          index, _ = scoring.search("bear", 1)[0]
 96          self.assertEqual(index, 3)
 97  
 98          # Run batch search
 99          index, _ = scoring.batchsearch(["bear"], 1)[0][0]
100          self.assertEqual(index, 3)
101  
102          # Validate save/load/delete
103          scoring.save(None)
104          scoring.load(None)
105  
106          # Validate count
107          self.assertEqual(scoring.count(), len(self.data))
108  
109          # Test delete
110          scoring.delete([0])
111          self.assertEqual(scoring.count(), len(self.data) - 1)
112  
113          # PGText is a normalized sparse index
114          self.assertTrue(scoring.issparse() and scoring.isnormalized() and not scoring.isbayes())
115          self.assertIsNone(scoring.weights("This is a test".split()))
116  
117          # Close scoring
118          scoring.close()
119  
120      def testSIF(self):
121          """
122          Test sif
123          """
124  
125          self.runTests("sif")
126  
127      def testTFIDF(self):
128          """
129          Test tfidf
130          """
131  
132          self.runTests("tfidf")
133  
134      def runTests(self, method):
135          """
136          Runs a series of tests for a scoring method.
137  
138          Args:
139              method: scoring method
140          """
141  
142          config = {"method": method}
143  
144          self.index(config)
145          self.upsert(config)
146          self.weights(config)
147          self.search(config)
148          self.delete(config)
149          self.normalize(config)
150          self.content(config)
151          self.empty(config)
152          self.copy(config)
153          self.settings(config)
154          self.tokenization(config)
155  
156      def index(self, config, data=None):
157          """
158          Test scoring index method.
159  
160          Args:
161              config: scoring config
162              data: data to index with scoring method
163  
164          Returns:
165              scoring
166          """
167  
168          # Derive input data
169          data = data if data else self.data
170  
171          scoring = ScoringFactory.create(config)
172          scoring.index(data)
173  
174          keys = [k for k, v in sorted(scoring.idf.items(), key=lambda x: x[1])]
175  
176          # Test count
177          self.assertEqual(scoring.count(), len(data))
178  
179          # Win should be lowest score
180          self.assertEqual(keys[0], "wins")
181  
182          # Test save/load
183          self.assertIsNotNone(self.save(scoring, config, f"scoring.{config['method']}.index"))
184  
185          # Test search returns none when terms disabled (default)
186          self.assertIsNone(scoring.search("query"))
187  
188          return scoring
189  
190      def upsert(self, config):
191          """
192          Test scoring upsert method
193          """
194  
195          scoring = ScoringFactory.create({**config, **{"tokenizer": {"alphanum": True, "stopwords": True}}})
196          scoring.upsert(self.data)
197  
198          # Test count
199          self.assertEqual(scoring.count(), len(self.data))
200  
201          # Test stop word is removed
202          self.assertFalse("and" in scoring.idf)
203  
204      def save(self, scoring, config, name):
205          """
206          Test scoring index save/load.
207  
208          Args:
209              scoring: scoring index
210              config: scoring config
211              name: output file name
212  
213          Returns:
214              scoring
215          """
216  
217          # Generate temp file path
218          index = os.path.join(tempfile.gettempdir(), "scoring")
219          os.makedirs(index, exist_ok=True)
220  
221          # Save scoring instance
222          scoring.save(f"{index}/{name}")
223  
224          # Reload scoring instance
225          scoring = ScoringFactory.create(config)
226          scoring.load(f"{index}/{name}")
227  
228          return scoring
229  
230      def weights(self, config):
231          """
232          Test standard and tag weighted scores.
233  
234          Args:
235              config: scoring config
236          """
237  
238          document = (1, ["bear", "wins"], None)
239  
240          scoring = self.index(config)
241          weights = scoring.weights(document[1])
242  
243          # Default weights
244          self.assertNotEqual(weights[0], weights[1])
245  
246          data = self.data[:]
247  
248          uid, text, _ = data[3]
249          data[3] = (uid, text, "wins")
250  
251          scoring = self.index(config, data)
252          weights = scoring.weights(document[1])
253  
254          # Modified weights
255          self.assertEqual(weights[0], weights[1])
256  
257      def search(self, config):
258          """
259          Test scoring search.
260  
261          Args:
262              config: scoring config
263          """
264  
265          # Create combined config
266          config = {**config, **{"terms": True}}
267  
268          # Create scoring instance
269          scoring = ScoringFactory.create(config)
270          scoring.index(self.data)
271  
272          # Run search and validate correct result returned
273          index, _ = scoring.search("bear", 1)[0]
274          self.assertEqual(index, 3)
275  
276          # Run batch search
277          index, _ = scoring.batchsearch(["bear"], 1)[0][0]
278          self.assertEqual(index, 3)
279  
280          # Run wildcard search
281          index, _ = scoring.search("bea*", 1)[0]
282          self.assertEqual(index, 3)
283  
284          # Test save/reload
285          self.save(scoring, config, f"scoring.{config['method']}.search")
286  
287          # Re-run search and validate correct result returned
288          index, _ = scoring.search("bear", 1)[0]
289          self.assertEqual(index, 3)
290  
291      def delete(self, config):
292          """
293          Test delete.
294          """
295  
296          # Create combined config
297          config = {**config, **{"terms": True, "content": True}}
298  
299          # Create scoring instance
300          scoring = ScoringFactory.create(config)
301          scoring.index(self.data)
302  
303          # Run search and validate correct result returned
304          index = scoring.search("bear", 1)[0]["id"]
305  
306          # Delete result and validate the query no longer returns results
307          scoring.delete([index])
308          self.assertFalse(scoring.search("bear", 1))
309  
310          # Save and validate count
311          self.save(scoring, config, f"scoring.{config['method']}.delete")
312          self.assertEqual(scoring.count(), len(self.data) - 1)
313  
314      def normalize(self, config):
315          """
316          Test scoring search with normalized scores.
317  
318          Args:
319              method: scoring method
320          """
321  
322          # Default normalization
323          scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": True}})
324          scoring.index(self.data)
325  
326          # Run search and validate correct result returned
327          index, score = scoring.search(self.data[3][1], 1)[0]
328          self.assertEqual(index, 3)
329          self.assertEqual(score, 1.0)
330  
331          # Bayesian normalization with default dynamic alpha/beta settings
332          baseline = ScoringFactory.create({**config, **{"terms": True}})
333          baseline.index(self.data)
334  
335          scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": "bayes"}})
336          scoring.index(self.data)
337  
338          query = "wins"
339          base = baseline.search(query, 3)
340          bayes = scoring.search(query, 3)
341  
342          # Bayesian normalization should preserve ranking order while mapping scores to [0, 1]
343          self.assertEqual([uid for uid, _ in base], [uid for uid, _ in bayes])
344          self.assertTrue(all(0.0 <= score <= 1.0 for _, score in bayes))
345  
346          # BB25 alias should resolve to Bayesian normalization
347          scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": "bb25"}})
348          scoring.index(self.data)
349          bb25 = scoring.search(query, 3)
350          self.assertEqual([uid for uid, _ in base], [uid for uid, _ in bb25])
351          self.assertTrue(all(0.0 <= score <= 1.0 for _, score in bb25))
352  
353          # BB25 candidate-set behavior: zero scores remain 0, positive scores are transformed
354          normalizer = Normalize("bb25")
355          scores = normalizer([(0, 0.0), (1, 1.0), (2, 2.0)], scoring.avgscore)
356          self.assertEqual(scores[0][1], 0.0)
357          self.assertGreater(scores[1][1], 0.0)
358          self.assertGreater(scores[2][1], scores[1][1])
359  
360          # Test negative scores
361          scores = normalizer([(0, -100.0)], scoring.avgscore)
362          self.assertEqual(scores[0][1], 0.0)
363  
364          # Bayesian normalization with custom parameters
365          config = {**config, **{"terms": True, "normalize": {"method": "bayes", "alpha": 2.0}}}
366          scoring = ScoringFactory.create(config)
367          scoring.index(self.data)
368  
369          custom = scoring.search(query, 3)
370          self.assertEqual([uid for uid, _ in base], [uid for uid, _ in custom])
371          self.assertTrue(all(0.0 <= score <= 1.0 for _, score in custom))
372  
373      def content(self, config):
374          """
375          Test scoring search with content.
376  
377          Args:
378              config: scoring config
379          """
380  
381          scoring = ScoringFactory.create({**config, **{"terms": True, "content": True}})
382          scoring.index(self.data)
383  
384          # Test text with content
385          text = "Great news today"
386          scoring.index([(scoring.total, text, None)])
387  
388          # Run search and validate correct result returned
389          result = scoring.search("great news", 1)[0]["text"]
390          self.assertEqual(result, text)
391  
392          # Test reading text from dictionary
393          text = "Feel good story: baby panda born"
394          scoring.index([(scoring.total, {"text": text}, None)])
395  
396          # Run search and validate correct result returned
397          result = scoring.search("feel good story", 1)[0]["text"]
398          self.assertEqual(result, text)
399  
400      def empty(self, config):
401          """
402          Test scoring index properly handles an index call when no data present.
403  
404          Args:
405              config: scoring config
406          """
407  
408          # Create scoring index with no data
409          scoring = ScoringFactory.create(config)
410          scoring.index([])
411  
412          # Assert index call returns and index has a count of 0
413          self.assertEqual(scoring.total, 0)
414  
415      def copy(self, config):
416          """
417          Test scoring index copy method.
418          """
419  
420          # Create scoring instance
421          scoring = ScoringFactory.create({**config, **{"terms": True}})
422          scoring.index(self.data)
423  
424          # Generate temp file path
425          index = os.path.join(tempfile.gettempdir(), "scoring")
426          os.makedirs(index, exist_ok=True)
427  
428          # Create file to test replacing existing file
429          path = f"{index}/scoring.{config['method']}.copy"
430          with open(f"{index}.terms", "w", encoding="utf-8") as f:
431              f.write("TEST")
432  
433          # Save scoring instance
434          scoring.save(path)
435          self.assertTrue(os.path.exists(path))
436  
437      @patch("sys.byteorder", "big")
438      def settings(self, config):
439          """
440          Test various settings.
441  
442          Args:
443              config: scoring config
444          """
445  
446          # Create combined config
447          config = {**config, **{"terms": {"cachelimit": 0, "cutoff": 0.25, "wal": True}}}
448  
449          # Create scoring instance
450          scoring = ScoringFactory.create(config)
451          scoring.index(self.data)
452  
453          # Save/load index
454          self.save(scoring, config, f"scoring.{config['method']}.settings")
455  
456          index, _ = scoring.search("bear bear bear wins", 1)[0]
457          self.assertEqual(index, 3)
458  
459          # Save to same path
460          self.save(scoring, config, f"scoring.{config['method']}.settings")
461  
462          # Save to different path
463          self.save(scoring, config, f"scoring.{config['method']}.move")
464  
465          # Validate counts
466          self.assertEqual(scoring.count(), len(self.data))
467  
    def tokenization(self, config):
        """
        Test tokenization methods.

        Each section builds a new config from the previous one and indexes a single
        document with id 0 under a different tokenizer configuration.

        Args:
            config: scoring config
        """

        # Test whitespace tokenization
        # The test expects a search for the full string "abc-def-123" to match document 0 under this tokenizer
        config = {**config, **{"terms": True, "tokenizer": {"whitespace": True}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "abc-def-123", None)])

        self.assertEqual(scoring.search("abc-def-123")[0][0], 0)

        # Test regular expression tokenization
        # The top-level dict merge replaces the whole "tokenizer" entry, so whitespace mode is dropped
        # here while "terms" is kept. The pattern matches runs of 5+ word characters: "hello" is
        # searchable, while the 4-character "test" yields a falsy search result.
        config = {**config, **{"tokenizer": {"regexp": r"\w{5,}"}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "hello test", None)])

        self.assertEqual(scoring.search("hello")[0][0], 0)
        self.assertFalse(scoring.search("test"))

        # Test ngram tokenization
        # NOTE(review): presumably 3-character grams with the given left/right padding and duplicate
        # grams removed ("unique") - confirm against the tokenizer implementation
        ngrams = {"ngrams": 3, "lpad": "  ", "rpad": " ", "unique": True}
        config = {**config, **{"tokenizer": {"ngrams": ngrams}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "hello test", None)])

        self.assertEqual(scoring.search("hello")[0][0], 0)