testkeyword.py
"""
Keyword scoring tests
"""

import os
import tempfile
import unittest

from unittest.mock import patch

from txtai.scoring import Normalize, ScoringFactory, Scoring


# pylint: disable=R0904
class TestKeyword(unittest.TestCase):
    """
    Sparse keyword scoring tests.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "wins wins wins",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Convert to (uid, text, tags) rows, the shape passed to scoring.index below
        cls.data = [(uid, x, None) for uid, x in enumerate(cls.data)]

    def testBM25(self):
        """
        Test bm25
        """

        self.runTests("bm25")

    def testCustom(self):
        """
        Test custom method
        """

        self.runTests("txtai.scoring.BM25")

    def testCustomNotFound(self):
        """
        Test unresolvable custom method
        """

        with self.assertRaises(ImportError):
            ScoringFactory.create("notfound.scoring")

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        scoring = Scoring()

        self.assertRaises(NotImplementedError, scoring.insert, None, None)
        self.assertRaises(NotImplementedError, scoring.delete, None)
        self.assertRaises(NotImplementedError, scoring.weights, None)
        self.assertRaises(NotImplementedError, scoring.search, None, None)
        self.assertRaises(NotImplementedError, scoring.batchsearch, None, None, None)
        self.assertRaises(NotImplementedError, scoring.count)
        self.assertRaises(NotImplementedError, scoring.load, None)
        self.assertRaises(NotImplementedError, scoring.save, None)
        self.assertRaises(NotImplementedError, scoring.close)
        self.assertRaises(NotImplementedError, scoring.issparse)
        self.assertRaises(NotImplementedError, scoring.isnormalized)
        self.assertRaises(NotImplementedError, scoring.isbayes)

    @patch("sqlalchemy.orm.Query.params")
    def testPGText(self, query):
        """
        Test PGText

        Args:
            query: mock replacing sqlalchemy Query.params
        """

        # Mock database query
        query.return_value = [(3, 1.0)]

        # Create scoring
        path = os.path.join(tempfile.gettempdir(), "pgtext.sqlite")
        scoring = ScoringFactory.create({"method": "pgtext", "url": f"sqlite:///{path}", "schema": "txtai"})
        scoring.index((uid, {"text": text}, tags) for uid, text, tags in self.data)

        # Run search and validate correct result returned
        index, _ = scoring.search("bear", 1)[0]
        self.assertEqual(index, 3)

        # Run batch search
        index, _ = scoring.batchsearch(["bear"], 1)[0][0]
        self.assertEqual(index, 3)

        # Validate save/load
        scoring.save(None)
        scoring.load(None)

        # Validate count
        self.assertEqual(scoring.count(), len(self.data))

        # Test delete
        scoring.delete([0])
        self.assertEqual(scoring.count(), len(self.data) - 1)

        # PGText is a normalized sparse index
        self.assertTrue(scoring.issparse() and scoring.isnormalized() and not scoring.isbayes())
        self.assertIsNone(scoring.weights("This is a test".split()))

        # Close scoring
        scoring.close()

    def testSIF(self):
        """
        Test sif
        """

        self.runTests("sif")

    def testTFIDF(self):
        """
        Test tfidf
        """

        self.runTests("tfidf")

    def runTests(self, method):
        """
        Runs a series of tests for a scoring method.

        Args:
            method: scoring method
        """

        config = {"method": method}

        self.index(config)
        self.upsert(config)
        self.weights(config)
        self.search(config)
        self.delete(config)
        self.normalize(config)
        self.content(config)
        self.empty(config)
        self.copy(config)
        self.settings(config)
        self.tokenization(config)

    def index(self, config, data=None):
        """
        Test scoring index method.

        Args:
            config: scoring config
            data: data to index with scoring method, defaults to the class test data

        Returns:
            scoring
        """

        # Derive input data
        data = data if data else self.data

        scoring = ScoringFactory.create(config)
        scoring.index(data)

        # Terms ordered by ascending idf
        keys = [k for k, _ in sorted(scoring.idf.items(), key=lambda x: x[1])]

        # Test count
        self.assertEqual(scoring.count(), len(data))

        # "wins" should have the lowest idf score
        self.assertEqual(keys[0], "wins")

        # Test save/load
        self.assertIsNotNone(self.save(scoring, config, f"scoring.{config['method']}.index"))

        # Test search returns none when terms disabled (default)
        self.assertIsNone(scoring.search("query"))

        return scoring

    def upsert(self, config):
        """
        Test scoring upsert method.

        Args:
            config: scoring config
        """

        scoring = ScoringFactory.create({**config, **{"tokenizer": {"alphanum": True, "stopwords": True}}})
        scoring.upsert(self.data)

        # Test count
        self.assertEqual(scoring.count(), len(self.data))

        # Test stop word is removed
        self.assertFalse("and" in scoring.idf)

    def save(self, scoring, config, name):
        """
        Test scoring index save/load.

        Args:
            scoring: scoring index
            config: scoring config
            name: output file name

        Returns:
            scoring instance reloaded from the saved file
        """

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "scoring")
        os.makedirs(index, exist_ok=True)

        # Save scoring instance
        scoring.save(f"{index}/{name}")

        # Reload scoring instance
        scoring = ScoringFactory.create(config)
        scoring.load(f"{index}/{name}")

        return scoring

    def weights(self, config):
        """
        Test standard and tag weighted scores.

        Args:
            config: scoring config
        """

        document = (1, ["bear", "wins"], None)

        scoring = self.index(config)
        weights = scoring.weights(document[1])

        # Default weights
        self.assertNotEqual(weights[0], weights[1])

        data = self.data[:]

        # Tag the "bear" document with "wins"
        uid, text, _ = data[3]
        data[3] = (uid, text, "wins")

        scoring = self.index(config, data)
        weights = scoring.weights(document[1])

        # Modified weights
        self.assertEqual(weights[0], weights[1])

    def search(self, config):
        """
        Test scoring search.

        Args:
            config: scoring config
        """

        # Create combined config
        config = {**config, **{"terms": True}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index(self.data)

        # Run search and validate correct result returned
        index, _ = scoring.search("bear", 1)[0]
        self.assertEqual(index, 3)

        # Run batch search
        index, _ = scoring.batchsearch(["bear"], 1)[0][0]
        self.assertEqual(index, 3)

        # Run wildcard search
        index, _ = scoring.search("bea*", 1)[0]
        self.assertEqual(index, 3)

        # Test save/reload
        self.save(scoring, config, f"scoring.{config['method']}.search")

        # Re-run search on the original instance after saving
        index, _ = scoring.search("bear", 1)[0]
        self.assertEqual(index, 3)

    def delete(self, config):
        """
        Test delete.

        Args:
            config: scoring config
        """

        # Create combined config
        config = {**config, **{"terms": True, "content": True}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index(self.data)

        # Run search and validate correct result returned
        index = scoring.search("bear", 1)[0]["id"]

        # Delete result and validate the query no longer returns results
        scoring.delete([index])
        self.assertFalse(scoring.search("bear", 1))

        # Save and validate count
        self.save(scoring, config, f"scoring.{config['method']}.delete")
        self.assertEqual(scoring.count(), len(self.data) - 1)

    def normalize(self, config):
        """
        Test scoring search with normalized scores.

        Args:
            config: scoring config
        """

        # Default normalization
        scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": True}})
        scoring.index(self.data)

        # Run search and validate correct result returned
        index, score = scoring.search(self.data[3][1], 1)[0]
        self.assertEqual(index, 3)
        self.assertEqual(score, 1.0)

        # Bayesian normalization with default dynamic alpha/beta settings
        baseline = ScoringFactory.create({**config, **{"terms": True}})
        baseline.index(self.data)

        scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": "bayes"}})
        scoring.index(self.data)

        query = "wins"
        base = baseline.search(query, 3)
        bayes = scoring.search(query, 3)

        # Bayesian normalization should preserve ranking order while mapping scores to [0, 1]
        self.assertEqual([uid for uid, _ in base], [uid for uid, _ in bayes])
        self.assertTrue(all(0.0 <= score <= 1.0 for _, score in bayes))

        # BB25 alias should resolve to Bayesian normalization
        scoring = ScoringFactory.create({**config, **{"terms": True, "normalize": "bb25"}})
        scoring.index(self.data)
        bb25 = scoring.search(query, 3)
        self.assertEqual([uid for uid, _ in base], [uid for uid, _ in bb25])
        self.assertTrue(all(0.0 <= score <= 1.0 for _, score in bb25))

        # BB25 candidate-set behavior: zero scores remain 0, positive scores are transformed
        normalizer = Normalize("bb25")
        scores = normalizer([(0, 0.0), (1, 1.0), (2, 2.0)], scoring.avgscore)
        self.assertEqual(scores[0][1], 0.0)
        self.assertGreater(scores[1][1], 0.0)
        self.assertGreater(scores[2][1], scores[1][1])

        # Test negative scores
        scores = normalizer([(0, -100.0)], scoring.avgscore)
        self.assertEqual(scores[0][1], 0.0)

        # Bayesian normalization with custom parameters
        config = {**config, **{"terms": True, "normalize": {"method": "bayes", "alpha": 2.0}}}
        scoring = ScoringFactory.create(config)
        scoring.index(self.data)

        custom = scoring.search(query, 3)
        self.assertEqual([uid for uid, _ in base], [uid for uid, _ in custom])
        self.assertTrue(all(0.0 <= score <= 1.0 for _, score in custom))

    def content(self, config):
        """
        Test scoring search with content.

        Args:
            config: scoring config
        """

        scoring = ScoringFactory.create({**config, **{"terms": True, "content": True}})
        scoring.index(self.data)

        # Test text with content
        text = "Great news today"
        scoring.index([(scoring.total, text, None)])

        # Run search and validate correct result returned
        result = scoring.search("great news", 1)[0]["text"]
        self.assertEqual(result, text)

        # Test reading text from dictionary
        text = "Feel good story: baby panda born"
        scoring.index([(scoring.total, {"text": text}, None)])

        # Run search and validate correct result returned
        result = scoring.search("feel good story", 1)[0]["text"]
        self.assertEqual(result, text)

    def empty(self, config):
        """
        Test scoring index properly handles an index call when no data present.

        Args:
            config: scoring config
        """

        # Create scoring index with no data
        scoring = ScoringFactory.create(config)
        scoring.index([])

        # Assert index call returns and index has a count of 0
        self.assertEqual(scoring.total, 0)

    def copy(self, config):
        """
        Test scoring index copy method.

        Saves a terms-enabled index to a path where a stale terms file already
        exists and checks the save succeeds.

        Args:
            config: scoring config
        """

        # Create scoring instance
        scoring = ScoringFactory.create({**config, **{"terms": True}})
        scoring.index(self.data)

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "scoring")
        os.makedirs(index, exist_ok=True)

        # Create file to test replacing existing file. The stale file must sit next to
        # the save path; writing it anywhere else leaves the replace path untested.
        # NOTE(review): assumes terms are stored at "<path>.terms" - confirm against Scoring.save
        path = f"{index}/scoring.{config['method']}.copy"
        with open(f"{path}.terms", "w", encoding="utf-8") as f:
            f.write("TEST")

        # Save scoring instance
        scoring.save(path)
        self.assertTrue(os.path.exists(path))

    @patch("sys.byteorder", "big")
    def settings(self, config):
        """
        Test various settings.

        Runs with sys.byteorder patched to "big".

        Args:
            config: scoring config
        """

        # Create combined config
        config = {**config, **{"terms": {"cachelimit": 0, "cutoff": 0.25, "wal": True}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index(self.data)

        # Save/load index
        self.save(scoring, config, f"scoring.{config['method']}.settings")

        index, _ = scoring.search("bear bear bear wins", 1)[0]
        self.assertEqual(index, 3)

        # Save to same path
        self.save(scoring, config, f"scoring.{config['method']}.settings")

        # Save to different path
        self.save(scoring, config, f"scoring.{config['method']}.move")

        # Validate counts
        self.assertEqual(scoring.count(), len(self.data))

    def tokenization(self, config):
        """
        Test tokenization methods.

        Args:
            config: scoring config
        """

        # Test whitespace tokenization
        config = {**config, **{"terms": True, "tokenizer": {"whitespace": True}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "abc-def-123", None)])

        self.assertEqual(scoring.search("abc-def-123")[0][0], 0)

        # Test regular expression tokenization: only words of 5+ characters are tokens,
        # so "test" is not indexed
        config = {**config, **{"tokenizer": {"regexp": r"\w{5,}"}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "hello test", None)])

        self.assertEqual(scoring.search("hello")[0][0], 0)
        self.assertFalse(scoring.search("test"))

        # Test ngram tokenization
        ngrams = {"ngrams": 3, "lpad": " ", "rpad": " ", "unique": True}
        config = {**config, **{"tokenizer": {"ngrams": ngrams}}}

        # Create scoring instance
        scoring = ScoringFactory.create(config)
        scoring.index([(0, "hello test", None)])

        self.assertEqual(scoring.search("hello")[0][0], 0)