test/python/testdatabase/testencoder.py
testencoder.py
  1  """
  2  Test encoding/decoding database objects
  3  """
  4  
  5  import glob
  6  import os
  7  import unittest
  8  import tempfile
  9  
 10  from unittest.mock import patch
 11  
 12  from io import BytesIO
 13  
 14  from PIL import Image
 15  
 16  from txtai.embeddings import Embeddings
 17  
 18  # pylint: disable=C0411
 19  from utils import Utils
 20  
 21  
 22  class TestEncoder(unittest.TestCase):
 23      """
 24      Encoder tests.
 25      """
 26  
 27      @classmethod
 28      def setUpClass(cls):
 29          """
 30          Initialize test data.
 31          """
 32  
 33          cls.data = []
 34          for path in glob.glob(Utils.PATH + "/*jpg"):
 35              cls.data.append((path, {"object": Image.open(path)}, None))
 36  
 37          # Create embeddings model, backed by sentence-transformers & transformers
 38          cls.embeddings = Embeddings(
 39              {"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32", "content": True, "objects": "image"}
 40          )
 41  
 42      @classmethod
 43      def tearDownClass(cls):
 44          """
 45          Cleanup data.
 46          """
 47  
 48          if cls.embeddings:
 49              cls.embeddings.close()
 50  
 51      def testDefault(self):
 52          """
 53          Test an index with default encoder
 54          """
 55  
 56          try:
 57              # Set default encoder
 58              self.embeddings.config["objects"] = True
 59  
 60              # Test all database providers
 61              for content in ["duckdb", "sqlite"]:
 62                  self.embeddings.config["content"] = content
 63  
 64                  data = [(0, {"object": bytearray([1, 2, 3]), "text": "default test"}, None)]
 65  
 66                  # Create an index
 67                  self.embeddings.index(data)
 68  
 69                  result = self.embeddings.search("select object from txtai limit 1")[0]
 70  
 71                  self.assertEqual(result["object"].getvalue(), bytearray([1, 2, 3]))
 72          finally:
 73              self.embeddings.config["objects"] = "image"
 74              self.embeddings.config["content"] = True
 75  
 76      def testImages(self):
 77          """
 78          Test an index with image encoder
 79          """
 80  
 81          # Create an index for the list of images
 82          self.embeddings.index(self.data)
 83  
 84          result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]
 85  
 86          self.assertTrue(result["id"].endswith("stars.jpg"))
 87          self.assertTrue(isinstance(result["object"], Image.Image))
 88  
 89      @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
 90      def testPickle(self):
 91          """
 92          Test an index with pickle encoder
 93          """
 94  
 95          try:
 96              # Set pickle encoder
 97              self.embeddings.config["objects"] = "pickle"
 98              data = [(0, {"object": [1, 2, 3, 4, 5], "text": "default test"}, None)]
 99  
100              # Create an index
101              self.embeddings.index(data)
102  
103              result = self.embeddings.search("select object from txtai limit 1")[0]
104  
105              self.assertEqual(result["object"], [1, 2, 3, 4, 5])
106          finally:
107              self.embeddings.config["objects"] = "image"
108  
109      def testReindex(self):
110          """
111          Test reindex with objects
112          """
113  
114          # Create an index for the list of images
115          self.embeddings.index(self.data)
116  
117          # Reindex images
118          self.embeddings.reindex({"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32"})
119  
120          result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]
121  
122          self.assertTrue(result["id"].endswith("stars.jpg"))
123          self.assertTrue(isinstance(result["object"], Image.Image))
124  
125      def testReindexFunction(self):
126          """
127          Test reindex with objects and a function
128          """
129  
130          try:
131              # Streaming function that loads images on the fly
132              def prepare(documents):
133                  for uid, data, tags in documents:
134                      yield (uid, Image.open(data), tags)
135  
136              # Create an index for the list of images
137              self.embeddings.index(self.data)
138  
139              # Set default encoder and use function to load images
140              self.embeddings.config["objects"] = True
141  
142              # Save and load index to force default encoder
143              index = os.path.join(tempfile.gettempdir(), "objects")
144              self.embeddings.save(index)
145              self.embeddings.load(index)
146  
147              # Reindex images
148              self.embeddings.reindex({"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32"}, function=prepare)
149  
150              result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]
151  
152              self.assertTrue(result["id"].endswith("stars.jpg"))
153              self.assertTrue(isinstance(result["object"], BytesIO))
154          finally:
155              self.embeddings.config["objects"] = "image"