testembeddings.py
1 """ 2 Embeddings module tests 3 """ 4 5 import json 6 import os 7 import tempfile 8 import unittest 9 10 from unittest.mock import patch 11 12 import numpy as np 13 14 from txtai.embeddings import Embeddings, Reducer 15 from txtai.serialize import SerializeFactory 16 17 18 # pylint: disable=R0904 19 class TestEmbeddings(unittest.TestCase): 20 """ 21 Embeddings tests. 22 """ 23 24 @classmethod 25 def setUpClass(cls): 26 """ 27 Initialize test data. 28 """ 29 30 cls.data = [ 31 "US tops 5 million confirmed virus cases", 32 "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg", 33 "Beijing mobilises invasion craft along coast as Taiwan tensions escalate", 34 "The National Park Service warns against sacrificing slower friends in a bear attack", 35 "Maine man wins $1M from $25 lottery ticket", 36 "Make huge profits without work, earn up to $100,000 a day", 37 ] 38 39 # Create embeddings model, backed by sentence-transformers & transformers 40 cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"}) 41 42 @classmethod 43 def tearDownClass(cls): 44 """ 45 Cleanup data. 46 """ 47 48 if cls.embeddings: 49 cls.embeddings.close() 50 51 def testAutoId(self): 52 """ 53 Test auto id generation 54 """ 55 56 # Default sequence id 57 embeddings = Embeddings() 58 embeddings.index(self.data) 59 60 uid = embeddings.search(self.data[4], 1)[0][0] 61 self.assertEqual(uid, 4) 62 63 # UUID 64 embeddings = Embeddings(autoid="uuid4") 65 embeddings.index(self.data) 66 67 uid = embeddings.search(self.data[4], 1)[0][0] 68 self.assertEqual(len(uid), 36) 69 70 def testColumns(self): 71 """ 72 Test custom text/object columns 73 """ 74 75 embeddings = Embeddings({"keyword": True, "columns": {"text": "value"}}) 76 data = [{"value": x} for x in self.data] 77 embeddings.index([(uid, text, None) for uid, text in enumerate(data)]) 78 79 # Run search 80 uid = embeddings.search("lottery", 1)[0][0] 81 self.assertEqual(uid, 4) 82 83 def testContext(self): 84 """ 85 Test embeddings context manager 86 """ 87 88 # Generate temp file path 89 index = os.path.join(tempfile.gettempdir(), "embeddings.context") 90 91 with Embeddings() as embeddings: 92 embeddings.index(self.data) 93 embeddings.save(index) 94 95 with Embeddings().load(index) as embeddings: 96 uid = embeddings.search(self.data[4], 1)[0][0] 97 self.assertEqual(uid, 4) 98 99 def testDefaults(self): 100 """ 101 Test default configuration 102 """ 103 104 # Run index with no config which will fall back to default configuration 105 embeddings = Embeddings() 106 embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)]) 107 108 self.assertEqual(embeddings.count(), 6) 109 110 def testDelete(self): 111 """ 112 Test delete 113 """ 114 115 # Create an index for the list of text 116 self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)]) 117 118 # Delete best match 119 self.embeddings.delete([4]) 120 121 # Search for best match 122 uid = self.embeddings.search("feel good story", 1)[0][0] 123 124 self.assertEqual(self.embeddings.count(), 5) 125 self.assertEqual(uid, 5) 126 127 def testDense(self): 128 """ 129 Test dense alias 130 """ 131 132 # Dense flag is an alias for path 133 embeddings = Embeddings(dense="sentence-transformers/nli-mpnet-base-v2") 134 embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)]) 135 136 self.assertEqual(embeddings.count(), 6) 137 138 def testEmpty(self): 139 """ 140 Test empty index 141 """ 142 143 # Test search against empty 
        embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2"})
        self.assertEqual(embeddings.search("test"), [])

        # Test index with no data
        embeddings.index([])
        self.assertIsNone(embeddings.ann)

        # Test upsert with no data
        embeddings.index([(0, "this is a test", None)])
        embeddings.upsert([])
        self.assertIsNotNone(embeddings.ann)

    def testEmptyString(self):
        """
        Test empty string indexing
        """

        # Test empty string
        self.embeddings.index([(0, "", None)])
        self.assertTrue(self.embeddings.search("test"))

        # Test empty string with dict
        self.embeddings.index([(0, {"text": ""}, None)])
        self.assertTrue(self.embeddings.search("test"))

    def testExternal(self):
        """
        Test embeddings backed by external vectors
        """

        def transform(data):
            embeddings = []
            for text in data:
                # Create dummy embedding using sum and mean of character ordinals
                ordinals = [ord(c) for c in text]
                embeddings.append(np.array([sum(ordinals), np.mean(ordinals)]))

            return embeddings

        # Index data using simple embeddings transform method
        embeddings = Embeddings({"method": "external", "transform": transform})
        embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Run search
        uid = embeddings.search(self.data[4], 1)[0][0]
        self.assertEqual(uid, 4)

    def testExternalPrecomputed(self):
        """
        Test embeddings backed by external pre-computed vectors
        """

        # Test with no transform function
        data = np.random.rand(5, 10).astype(np.float32)

        embeddings = Embeddings({"method": "external"})
        embeddings.index([(uid, row, None) for uid, row in enumerate(data)])

        # Run search
        uid = embeddings.search(data[4], 1)[0][0]
        self.assertEqual(uid, 4)

    def testHybrid(self):
        """
        Test hybrid search
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        # Index data with sparse + dense vectors
        embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "hybrid": True})
        embeddings.index(data)

        # Run search
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.hybrid")

        # Test load/save
        embeddings.save(index)
        embeddings.load(index)

        # Run search
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Index data with sparse + dense vectors and unnormalized scores
        embeddings.config["scoring"]["normalize"] = False
        embeddings.index(data)

        # Run search
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Index data with sparse + dense vectors and bb25 normalization
        embeddings.config["scoring"]["normalize"] = "bb25"
        embeddings.index(data)

        # Run search
        uid = embeddings.search("canada intact iceberg a", 1)[0][0]
        self.assertEqual(uid, 1)

        # Test upsert
        data[0] = (0, "Feel good story: baby panda born", None)
        embeddings.upsert([data[0]])

        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 0)

    def testIds(self):
        """
        Test legacy config ids loading
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.ids")

        # Save index
        self.embeddings.save(index)

        # Set ids on config to simulate legacy ids format
        with open(f"{index}/config.json", "r", encoding="utf-8") as handle:
            config = json.load(handle)
            config["ids"] = list(range(len(self.data)))

        with open(f"{index}/config.json", "w", encoding="utf-8") as handle:
            json.dump(config, handle, default=str, indent=2)

        # Reload index
        self.embeddings.load(index)

        # Run search
        uid = self.embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Check that ids is not in config
        self.assertTrue("ids" not in self.embeddings.config)

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testIdsPickle(self):
        """
        Test legacy pickle ids
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.idspickle")

        # Save index
        self.embeddings.save(index)

        # Create ids as pickle
        path = os.path.join(tempfile.gettempdir(), "embeddings.idspickle", "ids")
        serializer = SerializeFactory.create("pickle", allowpickle=True)
        serializer.save(self.embeddings.ids.ids, path)

        with self.assertWarns(RuntimeWarning):
            self.embeddings.load(index)

        # Run search
        uid = self.embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

    def testIndex(self):
        """
        Test index
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Search for best match
        uid = self.embeddings.search("feel good story", 1)[0][0]

        self.assertEqual(uid, 4)

    def testKeyword(self):
        """
        Test keyword only (sparse) search
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        # Index data with sparse keyword vectors
        embeddings = Embeddings({"keyword": True})
        embeddings.index(data)

        # Run search
        uid = embeddings.search("lottery ticket", 1)[0][0]
        self.assertEqual(uid, 4)

        # Test count method
        self.assertEqual(embeddings.count(), len(data))

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.keyword")

        # Test load/save
        embeddings.save(index)
        embeddings.load(index)

        # Run search
        uid = embeddings.search("lottery ticket", 1)[0][0]
        self.assertEqual(uid, 4)

        # Update data
        data[0] = (0, "Feel good story: baby panda born", None)
        embeddings.upsert([data[0]])

        # Search for best match
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 0)

    def testQuantize(self):
        """
        Test scalar quantization
        """

        for ann in ["faiss", "numpy", "torch"]:
            # Index data with 1-bit scalar quantization
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "quantize": 1, "backend": ann})
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match
            uid = embeddings.search("feel good story", 1)[0][0]
            self.assertEqual(uid, 4)

    def testReducer(self):
        """
        Test reducer model
        """

        # Test model with single PCA component
        data = np.random.rand(5, 5).astype(np.float32)
        reducer = Reducer(data, 1)

        # Generate query and keep original data to ensure it changes
        query = np.random.rand(1, 5).astype(np.float32)
        original = query.copy()

        # Run test
        reducer(query)
        self.assertFalse(np.array_equal(query, original))

        # Test model with multiple PCA components
        reducer = Reducer(data, 3)

        # Generate query and keep original data to ensure it changes
        query = np.random.rand(5).astype(np.float32)
        original = query.copy()

        # Run test
        reducer(query)
        self.assertFalse(np.array_equal(query, original))

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testReducerLegacy(self):
        """
        Test reducer model with legacy model format
        """

        # Test model with single PCA component
        data = np.random.rand(5, 5).astype(np.float32)
        reducer = Reducer(data, 1)

        # Save legacy format
        path = os.path.join(tempfile.gettempdir(), "reducer")
        serializer = SerializeFactory.create("pickle", allowpickle=True)
        serializer.save(reducer.model, path)

        # Load legacy format
        reducer = Reducer()
        reducer.load(path)

        # Generate query and keep original data to ensure it changes
        query = np.random.rand(1, 5).astype(np.float32)
        original = query.copy()

        # Run test
        reducer(query)
        self.assertFalse(np.array_equal(query, original))

    def testSave(self):
        """
        Test save
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.base")

        self.embeddings.save(index)
        self.embeddings.load(index)

        # Search for best match
        uid = self.embeddings.search("feel good story", 1)[0][0]

        self.assertEqual(uid, 4)

        # Test offsets still work after save/load
        self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
        self.assertEqual(self.embeddings.count(), len(self.data))

    def testShortcuts(self):
        """
        Test embeddings creation shortcuts
        """

        tests = [
            ({"keyword": True}, ["scoring"]),
            ({"keyword": "sif"}, ["scoring"]),
            ({"sparse": True}, ["scoring"]),
            ({"dense": True}, ["ann"]),
            ({"hybrid": True}, ["ann", "scoring"]),
            ({"hybrid": "tfidf"}, ["ann", "scoring"]),
            ({"hybrid": "sparse"}, ["ann", "scoring"]),
            ({"graph": True}, ["graph"]),
        ]

        for config, checks in tests:
            embeddings = Embeddings(config)
            embeddings.index(["test"])

            for attr in checks:
                self.assertIsNotNone(getattr(embeddings, attr))

    def testSimilarity(self):
        """
        Test similarity
        """

        # Get best matching id
        uid = self.embeddings.similarity("feel good story", self.data)[0][0]

        self.assertEqual(uid, 4)

    def testSparse(self):
        """
        Test sparse vector search
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        # Index data with sparse vectors
        embeddings = Embeddings({"sparse": "sparse-encoder-testing/splade-bert-tiny-nq"})
        embeddings.index(data)

        # Run search
        uid = embeddings.search("lottery ticket", 1)[0][0]
        self.assertEqual(uid, 4)

        # Test count method
        self.assertEqual(embeddings.count(), len(data))

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.sparse")

        # Test load/save
        embeddings.save(index)
        embeddings.load(index)

        # Run search
        uid = embeddings.search("lottery ticket", 1)[0][0]
        self.assertEqual(uid, 4)

        # Test similarity
        uid = embeddings.similarity("lottery ticket", self.data)[0][0]
        self.assertEqual(uid, 4)

        # Update data
        data[0] = (0, "Feel good story: baby panda born", None)
        embeddings.upsert([data[0]])

        # Search for best match
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 0)

    def testSubindex(self):
        """
        Test subindex
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        # Disable top-level indexing and create subindex
        embeddings = Embeddings({"defaults": False, "indexes": {"index1": {"path": "sentence-transformers/nli-mpnet-base-v2"}}})
        embeddings.index(data)

        # Test transform
        self.assertEqual(embeddings.transform("feel good story").shape, (768,))
        self.assertEqual(embeddings.transform("feel good story", index="index1").shape, (768,))
        with self.assertRaises(KeyError):
            embeddings.transform("feel good story", index="index2")

        # Run search
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), "embeddings.subindex")

        # Test load/save
        embeddings.save(index)
        embeddings.load(index)

        # Run search
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

        # Update data
        data[0] = (0, "Feel good story: baby panda born", None)
        embeddings.upsert([data[0]])

        # Search for best match
        uid = embeddings.search("feel good story", 10)[0][0]
        self.assertEqual(uid, 0)

        # Check missing text is set to id when top-level indexing is disabled
        embeddings.upsert([(embeddings.count(), {"content": "empty text"}, None)])
        uid = embeddings.search(f"{embeddings.count() - 1}", 1)[0][0]
        self.assertEqual(uid, embeddings.count() - 1)

        # Close embeddings
        embeddings.close()

    def testTruncate(self):
        """
        Test dimensionality truncation
        """

        # Truncate vectors to a specified number of dimensions
        embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "dimensionality": 750, "vectors": {"revision": "main"}})
        embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Search for best match
        uid = embeddings.search("feel good story", 1)[0][0]
        self.assertEqual(uid, 4)

    def testUpsert(self):
        """
        Test upsert
        """

        # Build data array
        data = [(uid, text, None) for uid, text in enumerate(self.data)]

        # Reset embeddings for test
        self.embeddings.ann = None
        self.embeddings.ids = None

        # Create an index for the list of text
        self.embeddings.upsert(data)

        # Update data
        data[0] = (0, "Feel good story: baby panda born", None)
        self.embeddings.upsert([data[0]])

        # Search for best match
        uid = self.embeddings.search("feel good story", 1)[0][0]

        self.assertEqual(uid, 0)

    @patch("os.cpu_count")
    def testWords(self, cpucount):
        """
        Test embeddings backed by word vectors
        """

        # Mock CPU count
        cpucount.return_value = 1

        # Create dataset
        data = [(x, row.split(), None) for x, row in enumerate(self.data)]

        # Create embeddings model, backed by word vectors
Embeddings({"path": "neuml/glove-6B-quantized", "scoring": "bm25", "pca": 3, "quantize": True}) 633 634 # Call scoring and index methods 635 embeddings.score(data) 636 embeddings.index(data) 637 638 # Test search 639 self.assertIsNotNone(embeddings.search("win", 1)) 640 641 # Generate temp file path 642 index = os.path.join(tempfile.gettempdir(), "embeddings.wordvectors") 643 644 # Test save/load 645 embeddings.save(index) 646 embeddings.load(index) 647 648 # Test search 649 self.assertIsNotNone(embeddings.search("win", 1)) 650 651 @patch("os.cpu_count") 652 def testWordsUpsert(self, cpucount): 653 """ 654 Test embeddings backed by word vectors with upserts 655 """ 656 657 # Mock CPU count 658 cpucount.return_value = 1 659 660 # Create dataset 661 data = [(x, row.split(), None) for x, row in enumerate(self.data)] 662 663 # Create embeddings model, backed by word vectors 664 embeddings = Embeddings({"path": "neuml/glove-6B/model.sqlite", "scoring": "bm25", "pca": 3}) 665 666 # Call scoring and index methods 667 embeddings.score(data) 668 embeddings.index(data) 669 670 # Now upsert and override record 671 data = [(0, "win win", None)] 672 673 # Update scoring and run upsert 674 embeddings.score(data) 675 embeddings.upsert(data) 676 677 # Test search after upsert 678 uid = embeddings.search("win", 1)[0][0] 679 self.assertEqual(uid, 0)