"""
Common file database module tests
"""

import contextlib
import io
import os
import tempfile
import unittest

from unittest.mock import patch

from txtai.embeddings import Embeddings, IndexNotFoundError
from txtai.database import Embedded, RDBMS, SQLError


class Common:
    """
    Wraps common file database tests to prevent unit test discovery for this class.
    """

    # pylint: disable=R0904
    class TestRDBMS(unittest.TestCase):
        """
        Embeddings with content stored in a file database tests.
        """

        @classmethod
        def setUpClass(cls):
            """
            Initialize test data.
            """

            cls.data = [
                "US tops 5 million confirmed virus cases",
                "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
                "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
                "The National Park Service warns against sacrificing slower friends in a bear attack",
                "Maine man wins $1M from $25 lottery ticket",
                "Make huge profits without work, earn up to $100,000 a day",
            ]

            # Content backend
            cls.backend = None

            # Create embeddings model, backed by sentence-transformers & transformers
            cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend})

        @classmethod
        def tearDownClass(cls):
            """
            Cleanup data.
            """

            if cls.embeddings:
                cls.embeddings.close()

        def testArchive(self):
            """
            Test embeddings index archiving
            """

            for extension in ["tar.bz2", "tar.gz", "tar.xz", "zip"]:
                # Create an index for the list of text
                self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

                # Generate temp file path
                index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.{extension}")

                self.embeddings.save(index)
                self.embeddings.load(index)

                # Search for best match
                result = self.embeddings.search("feel good story", 1)[0]

                self.assertEqual(result["text"], self.data[4])

                # Test offsets still work after save/load
                self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
                self.assertEqual(self.embeddings.count(), len(self.data))

        def testAutoId(self):
            """
            Test auto id generation
            """

            # Default sequence id
            embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2", content=self.backend)
            embeddings.index(self.data)

            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

            # UUID
            embeddings.config["autoid"] = "uuid4"
            embeddings.index(self.data)

            result = embeddings.search(self.data[4], 1)[0]
            self.assertEqual(len(result["id"]), 36)

        def testCheckpoint(self):
            """
            Test embeddings index checkpoints
            """

            # Checkpoint directory
            checkpoint = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.checkpoint")

            # Save embeddings checkpoint
            self.embeddings.index(self.data, checkpoint=checkpoint)

            # Reindex with checkpoint
            self.embeddings.index(self.data, checkpoint=checkpoint)

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

        def testColumns(self):
            """
            Test custom text/object columns
            """

            embeddings = Embeddings({"keyword": True, "content": self.backend, "columns": {"text": "value"}})
            data = [{"value": x} for x in self.data]
            embeddings.index([(uid, text, None) for uid, text in enumerate(data)])

            # Run search
            result = embeddings.search("lottery", 1)[0]
            self.assertEqual(result["text"], self.data[4])

        def testClose(self):
            """
            Test embeddings close
            """

            embeddings = None

            # Create index twice to test open/close and ensure resources are freed
            for _ in range(2):
                embeddings = Embeddings(
                    {"path": "sentence-transformers/nli-mpnet-base-v2", "scoring": {"method": "bm25", "terms": True}, "content": self.backend}
                )

                # Add record to index
                embeddings.index([(0, "Close test", None)])

                # Save index
                index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.close")
                embeddings.save(index)

                # Close index
                embeddings.close()

            # Test embeddings is empty
            self.assertIsNone(embeddings.ann)
            self.assertIsNone(embeddings.database)

        def testData(self):
            """
            Test content storage and retrieval
            """

            data = self.data + [{"date": "2021-01-01", "text": "Baby panda", "flag": 1}]

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(data)])

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[-1]["text"])

        def testDelete(self):
            """
            Test delete
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Delete best match
            self.embeddings.delete([4])

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(self.embeddings.count(), 5)
            self.assertEqual(result["text"], self.data[5])

        def testEmpty(self):
            """
            Test empty index
            """

            # Test search against empty index
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend})
            self.assertEqual(embeddings.search("test"), [])

            # Test index with no data
            embeddings.index([])
            self.assertIsNone(embeddings.ann)

            # Test upsert with no data
            embeddings.index([(0, "this is a test", None)])
            embeddings.upsert([])
            self.assertIsNotNone(embeddings.ann)

        def testEmptyString(self):
            """
            Test empty string indexing
            """

            # Test empty string
            self.embeddings.index([(0, "", None)])
            self.assertTrue(self.embeddings.search("test"))

            # Test empty string with dict
            self.embeddings.index([(0, {"text": ""}, None)])
            self.assertTrue(self.embeddings.search("test"))

        def testExplain(self):
            """
            Test query explain
            """

            # Test explain with similarity
            result = self.embeddings.explain("feel good story", self.data)[0]
            self.assertEqual(result["text"], self.data[4])
            self.assertEqual(len(result.get("tokens")), 8)

        def testExplainBatch(self):
            """
            Test query explain batch
            """

            # Test explain with query
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            result = self.embeddings.batchexplain(["feel good story"], limit=1)[0][0]
            self.assertEqual(result["text"], self.data[4])
            self.assertEqual(len(result.get("tokens")), 8)

        def testExplainEmpty(self):
            """
            Test query explain with no filtering criteria
            """

            self.assertEqual(self.embeddings.explain("select * from txtai limit 1")[0]["id"], "0")

        def testExpressions(self):
            """
            Test expressions
            """

            # Test indexed expressions
            embeddings = Embeddings(
                path="sentence-transformers/nli-mpnet-base-v2",
                content=self.backend,
                expressions=[{"name": "textlength", "expression": "length(text)", "index": True}],
            )
            embeddings.index(self.data)

            result = embeddings.search("SELECT textlength FROM txtai WHERE id = 0", 1)[0]
            self.assertEqual(result["textlength"], len(self.data[0]))

        def testGenerator(self):
            """
            Test index with a generator
            """

            def documents():
                for uid, text in enumerate(self.data):
                    yield (uid, text, None)

            # Create an index for the list of text
            self.embeddings.index(documents())

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testHybrid(self):
            """
            Test hybrid search
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Index data with sparse + dense vectors
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "hybrid": True, "content": self.backend})
            embeddings.index(data)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.hybrid")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Index data with sparse + dense vectors and unnormalized scores
            embeddings.config["scoring"]["normalize"] = False
            embeddings.index(data)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Index data with sparse + dense vectors and bb25 normalized scores
            embeddings.config["scoring"]["normalize"] = "bb25"
            embeddings.index(data)

            # Run search
            result = embeddings.search("canada intact iceberg a", 1)[0]
            self.assertEqual(result["text"], data[1][1])

            # Test upsert
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

        def testIndex(self):
            """
            Test index
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testIndexTokens(self):
            """
            Test index with tokens
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text.split(), None) for uid, text in enumerate(self.data)])

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testInfo(self):
            """
            Test info
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            output = io.StringIO()
            with contextlib.redirect_stdout(output):
                self.embeddings.info()

            self.assertIn("txtai", output.getvalue())

        def testInstructions(self):
            """
            Test indexing with instruction prefixes.
            """

            embeddings = Embeddings(
                {
                    "path": "sentence-transformers/nli-mpnet-base-v2",
                    "content": self.backend,
                    "instructions": {"query": "query: ", "data": "passage: "},
                }
            )

            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testInvalidData(self):
            """
            Test invalid JSON data
            """

            # Test invalid JSON value
            with self.assertRaises(ValueError):
                self.embeddings.index([(0, {"text": "This is a test", "flag": float("NaN")}, None)])

        def testKeyword(self):
            """
            Test keyword only (sparse) search
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Index data with sparse keyword vectors
            embeddings = Embeddings({"keyword": True, "content": self.backend})
            embeddings.index(data)

            # Run search
            result = embeddings.search("lottery ticket", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Test count method
            self.assertEqual(embeddings.count(), len(data))

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.keyword")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("lottery ticket", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Update data
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

        def testMultiData(self):
            """
            Test indexing with multiple data types (text, documents)
            """

            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "batch": len(self.data)})

            # Create an index using mixed data (text and documents)
            data = []
            for uid, text in enumerate(self.data):
                data.append((uid, text, None))
                data.append((uid, {"content": text}, None))

            embeddings.index(data)

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testMultiSave(self):
            """
            Test multiple successive saves
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Save original index
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.insert")
            self.embeddings.save(index)

            # Modify index
            self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])

            # Save to a different location
            indexupdate = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.update")
            self.embeddings.save(indexupdate)

            # Save to same location
            self.embeddings.save(index)

            # Test all indexes match
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

            self.embeddings.load(index)
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

            self.embeddings.load(indexupdate)
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

        def testNoIndex(self):
            """
            Test an embeddings instance with no available indexes
            """

            # Disable top-level indexing
            embeddings = Embeddings(
                {
                    "content": self.backend,
                    "defaults": False,
                }
            )
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            with self.assertRaises(IndexNotFoundError):
                embeddings.search("select id, text, score from txtai where similar('feel good story')")

        def testNotImplemented(self):
            """
            Test exceptions for non-implemented methods
            """

            db = RDBMS({})

            self.assertRaises(NotImplementedError, db.connect, None)
            self.assertRaises(NotImplementedError, db.getcursor)
            self.assertRaises(NotImplementedError, db.jsonprefix)
            self.assertRaises(NotImplementedError, db.jsoncolumn, None)
            self.assertRaises(NotImplementedError, db.rows)
            self.assertRaises(NotImplementedError, db.addfunctions)

            db = Embedded({})
            self.assertRaises(NotImplementedError, db.copy, None)

        def testObject(self):
            """
            Test object field
            """

            # Encode object
            embeddings = Embeddings({"defaults": False, "content": self.backend, "objects": True})
            embeddings.index([{"object": "binary data".encode("utf-8")}])

            # Decode and test extracted object
            obj = embeddings.search("select object from txtai where id = 0")[0]["object"]
            self.assertEqual(str(obj.getvalue(), "utf-8"), "binary data")

        @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
        def testPickle(self):
            """
            Test pickle configuration
            """

            embeddings = Embeddings(
                {
                    "format": "pickle",
                    "path": "sentence-transformers/nli-mpnet-base-v2",
                    "content": self.backend,
                }
            )

            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.pickle")

            embeddings.save(index)

            # Check that config exists
            self.assertTrue(os.path.exists(os.path.join(index, "config")))

            # Check that index can be reloaded
            embeddings.load(index)
            self.assertEqual(embeddings.count(), 6)

        def testQuantize(self):
            """
            Test scalar quantization
            """

            # Index data with 1-bit scalar quantization
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "quantize": 1, "content": self.backend})
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match - query the quantized instance, not the shared class-level index
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

        def testQueryModel(self):
            """
            Test index
            """

            embeddings = Embeddings(
                {"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "query": {"path": "neuml/t5-small-txtsql"}}
            )

            # Create an index for the list of text
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match
            result = embeddings.search("feel good story with win in text", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testReindex(self):
            """
            Test reindex
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Delete records to test indexids still match
            self.embeddings.delete([0, 1])

            # Reindex
            self.embeddings.reindex({"path": "sentence-transformers/nli-mpnet-base-v2"})

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testSave(self):
            """
            Test save
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}")

            self.embeddings.save(index)
            self.embeddings.load(index)

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

            # Test offsets still work after save/load
            self.embeddings.upsert([(0, "Looking out into the dreadful abyss", None)])
            self.assertEqual(self.embeddings.count(), len(self.data))

        def testSettings(self):
            """
            Test custom SQLite settings
            """

            # Index with write-ahead logging enabled
            embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": self.backend, "sqlite": {"wal": True}})

            # Create an index for the list of text
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]

            self.assertEqual(result["text"], self.data[4])

        def testSQL(self):
            """
            Test running a SQL query
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, {"text": text, "length": len(text), "attribute": f"ID{uid}"}, None) for uid, text in enumerate(self.data)])

            # Test similar
            result = self.embeddings.search(
                "select text, score from txtai where similar('feel good story') group by text, score having count(*) > 0 order by score desc", 1
            )[0]
            self.assertEqual(result["text"], self.data[4])

            # Test similar with limits
            result = self.embeddings.search("select * from txtai where similar('feel good story', 1) limit 1")[0]
            self.assertEqual(result["text"], self.data[4])

            # Test similar with offset
            result = self.embeddings.search("select * from txtai where similar('feel good story') offset 1")[0]
            self.assertEqual(result["text"], self.data[5])

            # Test where
            result = self.embeddings.search("select * from txtai where text like '%iceberg%'", 1)[0]
            self.assertEqual(result["text"], self.data[1])

            # Test count
            result = self.embeddings.search("select count(*) from txtai")[0]
            self.assertEqual(list(result.values())[0], len(self.data))

            # Test columns
            result = self.embeddings.search("select id, text, length, data, entry from txtai")[0]
            self.assertEqual(sorted(result.keys()), ["data", "entry", "id", "length", "text"])

            # Test column filtering
            result = self.embeddings.search("select text from txtai where attribute = 'ID4'", 1)[0]
            self.assertEqual(result["text"], self.data[4])

            # Test SQL parse error
            with self.assertRaises(SQLError):
                self.embeddings.search("select * from txtai where bad,query")

        def testSQLBind(self):
            """
            Test SQL statements with bind parameters
            """

            # Create an index for the list of text
            self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Test similar clause bind parameters
            result = self.embeddings.search("select id, text, score from txtai where similar(:x)", parameters={"x": "feel good story"})[0]
            self.assertEqual(result["text"], self.data[4])

            # Test similar clause bind and non-bind parameters
            result = self.embeddings.search("select id, text, score from txtai where similar(:x, 0.5)", parameters={"x": "feel good story"})[0]
            self.assertEqual(result["text"], self.data[4])

            # Test where filtering with bind parameters
            result = self.embeddings.search("select * from txtai where text like :x", parameters={"x": "%iceberg%"})[0]
            self.assertEqual(result["text"], self.data[1])

        def testSparse(self):
            """
            Test sparse vector search
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Index data with sparse vectors
            embeddings = Embeddings({"sparse": "sparse-encoder-testing/splade-bert-tiny-nq", "content": self.backend})
            embeddings.index(data)

            # Run search
            result = embeddings.search("lottery ticket", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Test count method
            self.assertEqual(embeddings.count(), len(data))

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.sparse")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("lottery ticket", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Update data
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

        def testSubindex(self):
            """
            Test subindex
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Disable top-level indexing and create subindex
            embeddings = Embeddings(
                {"content": self.backend, "defaults": False, "indexes": {"index1": {"path": "sentence-transformers/nli-mpnet-base-v2"}}}
            )
            embeddings.index(data)

            # Test transform
            self.assertEqual(embeddings.transform("feel good story").shape, (768,))

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Run SQL search
            result = embeddings.search("select id, text, score from txtai where similar('feel good story', 10, 0.5)")[0]
            self.assertEqual(result["text"], data[4][1])

            # Test missing index
            with self.assertRaises(IndexNotFoundError):
                embeddings.search("select id, text, score from txtai where similar('feel good story', 'notindex')")

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.subindex")

            # Test load/save
            embeddings.save(index)
            embeddings.load(index)

            # Run search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1])

            # Update data
            data[0] = (0, "Feel good story: baby panda born", None)
            embeddings.upsert([data[0]])

            # Search for best match
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

            # Check missing text is set to id when top-level indexing is disabled
            embeddings.upsert([(embeddings.count(), {"content": "empty text"}, None)])
            result = embeddings.search(f"{embeddings.count() - 1}", 1)[0]
            self.assertEqual(result["text"], str(embeddings.count() - 1))

            # Close embeddings
            embeddings.close()

        def testSubindexEmpty(self):
            """
            Test loading an empty subindex
            """

            # Build data array
            data = [(uid, {"column1": text}, None) for uid, text in enumerate(self.data)]

            # Disable top-level indexing and create subindexes
            embeddings = Embeddings(
                {
                    "content": self.backend,
                    "defaults": False,
                    "indexes": {
                        "index1": {"path": "sentence-transformers/nli-mpnet-base-v2", "columns": {"text": "column1"}},
                        "index2": {"path": "sentence-transformers/nli-mpnet-base-v2", "columns": {"text": "column2"}},
                    },
                }
            )
            embeddings.index(data)

            # Generate temp file path
            index = os.path.join(tempfile.gettempdir(), f"embeddings.{self.category()}.subindexempty")

            # Save index
            embeddings.save(index)

            # Test exists
            self.assertTrue(embeddings.exists(index))

            # Load index
            embeddings.load(index)

            # Test search
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[4][1]["text"])

        def testTerms(self):
            """
            Test extracting keyword terms from query
            """

            result = self.embeddings.terms("select * from txtai where similar('keyword terms')")
            self.assertEqual(result, "keyword terms")

        def testTruncate(self):
            """
            Test dimensionality truncation
            """

            # Truncate vectors to a specified number of dimensions
            embeddings = Embeddings(
                {"path": "sentence-transformers/nli-mpnet-base-v2", "dimensionality": 750, "content": self.backend, "vectors": {"revision": "main"}}
            )
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Search for best match - query the truncated instance, not the shared class-level index
            result = embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], self.data[4])

        def testUpsert(self):
            """
            Test upsert
            """

            # Build data array
            data = [(uid, text, None) for uid, text in enumerate(self.data)]

            # Reset embeddings for test
            self.embeddings.ann = None
            self.embeddings.database = None

            # Create an index for the list of text
            self.embeddings.upsert(data)

            # Update data
            data[0] = (0, "Feel good story: baby panda born", None)
            self.embeddings.upsert([data[0]])

            # Search for best match
            result = self.embeddings.search("feel good story", 1)[0]
            self.assertEqual(result["text"], data[0][1])

        def testUpsertBatch(self):
            """
            Test upsert batch
            """

            try:
                # Build data array
                data = [(uid, text, None) for uid, text in enumerate(self.data)]

                # Reset embeddings for test
                self.embeddings.ann = None
                self.embeddings.database = None

                # Create an index for the list of text
                self.embeddings.upsert(data)

                # Set batch size to 1
                self.embeddings.config["batch"] = 1

                # Update data
                data[0] = (0, "Feel good story: baby panda born", None)
                data[1] = (0, "Not good news", None)
                self.embeddings.upsert([data[0], data[1]])

                # Search for best match
                result = self.embeddings.search("feel good story", 1)[0]

                self.assertEqual(result["text"], data[0][1])
            finally:
                del self.embeddings.config["batch"]

        def category(self):
            """
            Content backend category.

            Returns:
                category
            """

            return self.__class__.__name__.lower().replace("test", "")