/ test / python / testdatabase / testsqlite.py
testsqlite.py
 1  """
 2  SQLite module tests
 3  """
 4  
 5  from txtai.embeddings import Embeddings
 6  
 7  from .testrdbms import Common
 8  
 9  
10  # pylint: disable=R0904
class TestSQLite(Common.TestRDBMS):
    """
    Embeddings with content stored in SQLite tests.

    Runs the shared RDBMS test suite (Common.TestRDBMS) against the "sqlite"
    content backend, plus SQLite-specific tests such as custom SQL functions.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.

        Sets the shared fixtures used by the inherited tests: the list of input
        texts (cls.data), the content backend name (cls.backend) and a shared
        embeddings instance (cls.embeddings).
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Content backend
        cls.backend = "sqlite"

        # Create embeddings model, backed by sentence-transformers & transformers
        cls.embeddings = Embeddings({"path": "sentence-transformers/nli-mpnet-base-v2", "content": cls.backend})

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data.

        Closes the shared embeddings instance created in setUpClass.
        """

        if cls.embeddings:
            cls.embeddings.close()

    def testFunction(self):
        """
        Test custom functions.

        Registers the module-level `length` function as the SQL function
        `textlength`. It then checks that querying it for row id 0 returns the
        character length of that row's text.
        """

        # Functions are supplied through the config, so this test needs its own
        # instance rather than the shared cls.embeddings
        embeddings = Embeddings(
            {
                "path": "sentence-transformers/nli-mpnet-base-v2",
                "content": self.backend,
                "functions": [{"name": "textlength", "function": "testdatabase.testsqlite.length"}],
            }
        )

        try:
            # Create an index for the list of text; ids are the list positions
            embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

            # Run the custom function against the row with id 0
            result = embeddings.search("select textlength(text) length from txtai where id = 0", 1)[0]

            # id 0 maps to self.data[0], so derive the expected value instead of hard-coding it
            self.assertEqual(int(result["length"]), len(self.data[0]))
        finally:
            # Always release this test-local instance, even when an assertion fails
            embeddings.close()
65          self.assertEqual(int(result["length"]), 39)
66  
67  
def length(text):
    """
    Custom SQL function.

    Registered as the SQL function `textlength` by TestSQLite.testFunction.

    Args:
        text: column value passed in by the SQL engine; None when the value is NULL

    Returns:
        number of characters in text, or None for NULL input
    """

    # Propagate SQL NULL like SQLite's built-in length() instead of raising TypeError
    if text is None:
        return None

    return len(text)