testduckdb.py
"""
DuckDB module tests
"""

import os
import unittest

from txtai.embeddings import Embeddings

from .testrdbms import Common


# pylint: disable=R0904
class TestDuckDB(Common.TestRDBMS):
    """
    Embeddings with content stored in DuckDB.

    Inherits the shared tests from Common.TestRDBMS and sets the content backend to "duckdb".
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data, the content backend name and the shared embeddings instance.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Content backend
        cls.backend = "duckdb"

        # Create embeddings model, backed by sentence-transformers & transformers
        cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend})

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data by closing the shared embeddings instance created in setUpClass.
        """

        if cls.embeddings:
            cls.embeddings.close()

    @unittest.skipIf(os.name == "nt", "testArchive skipped on Windows")
    def testArchive(self):
        """
        Test embeddings index archiving. Delegates to the inherited test; skipped on Windows.
        """

        super().testArchive()

    def testFunction(self):
        """
        Test custom functions.

        Builds a separate embeddings instance that registers the module-level `length` function
        as the SQL function "textlength". Runs it against the row with id 0 and checks the result.
        """

        embeddings = Embeddings(
            {
                "path": "sentence-transformers/nli-mpnet-base-v2",
                "content": self.backend,
                "functions": [{"name": "textlength", "function": "testdatabase.testduckdb.length"}],
            }
        )

        # This instance is local to the test, so tearDownClass does not close it.
        # Register the close now so it runs even if indexing, search or the assertion fails.
        self.addCleanup(embeddings.close)

        # Create an index for the list of text
        embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Run SQL that calls the custom function on the row with id 0
        result = embeddings.search("select textlength(text) length from txtai where id = 0", 1)[0]

        # Expected value is the character count of the text indexed with id 0
        self.assertEqual(int(result["length"]), len(self.data[0]))


def length(text):
    """
    Custom SQL function.

    Returns the number of characters in text. testFunction registers it as "textlength".
    """

    return len(text)