testsqlite.py
"""
SQLite module tests
"""

from txtai.embeddings import Embeddings

from .testrdbms import Common


# pylint: disable=R0904
class TestSQLite(Common.TestRDBMS):
    """
    Embeddings with content stored in SQLite tests.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Content backend
        cls.backend = "sqlite"

        # Create embeddings model, backed by sentence-transformers & transformers
        cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend})

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data.
        """

        if cls.embeddings:
            cls.embeddings.close()

    def testFunction(self):
        """
        Test custom functions.

        Registers the module-level `length` function as the SQL function
        `textlength`, then checks that it is applied to the indexed text.
        """

        embeddings = Embeddings(
            {
                "path": "sentence-transformers/nli-mpnet-base-v2",
                "content": self.backend,
                # NOTE(review): the dotted path assumes this module is importable as
                # testdatabase.testsqlite — confirm if the test package is moved/renamed
                "functions": [{"name": "textlength", "function": "testdatabase.testsqlite.length"}],
            }
        )

        try:
            # Create an index for the list of text
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Apply the custom function to the row with id 0 via an SQL query
            result = embeddings.search("select textlength(text) length from txtai where id = 0", 1)[0]

            # len("US tops 5 million confirmed virus cases") == 39
            self.assertEqual(int(result["length"]), 39)
        finally:
            # This instance is separate from the class-level one that tearDownClass
            # closes, so release it here even when the assertion fails
            embeddings.close()


def length(text):
    """
    Custom SQL function.

    Registered as `textlength` in TestSQLite.testFunction.

    Args:
        text: input text

    Returns:
        number of characters in text
    """

    return len(text)