testencoder.py
"""
Test encoding/decoding database objects
"""

import glob
import os
import unittest
import tempfile

from unittest.mock import patch

from io import BytesIO

from PIL import Image

from txtai.embeddings import Embeddings

# pylint: disable=C0411
from utils import Utils


class TestEncoder(unittest.TestCase):
    """
    Encoder tests.

    A single Embeddings instance is shared by every test in this class. Tests that
    change its configuration restore the original values in a finally block.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data.
        """

        # One (id, document, tags) row per image file, the file path is used as the id
        cls.data = [(path, {"object": Image.open(path)}, None) for path in glob.glob(Utils.PATH + "/*jpg")]

        # Create embeddings model, backed by sentence-transformers & transformers
        cls.embeddings = Embeddings(
            {"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32", "content": True, "objects": "image"}
        )

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data.
        """

        if cls.embeddings:
            cls.embeddings.close()

        # Image.open is lazy and may keep the underlying file open, release those handles
        for _, document, _ in cls.data:
            document["object"].close()

    def testDefault(self):
        """
        Test an index with default encoder
        """

        try:
            # Set default encoder
            self.embeddings.config["objects"] = True

            # Test all database providers
            for content in ["duckdb", "sqlite"]:
                # Report each provider separately, a failure in one does not skip the other
                with self.subTest(content=content):
                    self.embeddings.config["content"] = content

                    data = [(0, {"object": bytearray([1, 2, 3]), "text": "default test"}, None)]

                    # Create an index
                    self.embeddings.index(data)

                    result = self.embeddings.search("select object from txtai limit 1")[0]

                    # Stored object is read back through getvalue() and must match the input bytes
                    self.assertEqual(result["object"].getvalue(), bytearray([1, 2, 3]))
        finally:
            # Restore the shared configuration set in setUpClass
            self.embeddings.config["objects"] = "image"
            self.embeddings.config["content"] = True

    def testImages(self):
        """
        Test an index with image encoder
        """

        # Create an index for the list of images
        self.embeddings.index(self.data)

        result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]

        self.assertTrue(result["id"].endswith("stars.jpg"))
        self.assertIsInstance(result["object"], Image.Image)

    @patch.dict(os.environ, {"ALLOW_PICKLE": "True"})
    def testPickle(self):
        """
        Test an index with pickle encoder
        """

        try:
            # Set pickle encoder
            self.embeddings.config["objects"] = "pickle"
            data = [(0, {"object": [1, 2, 3, 4, 5], "text": "default test"}, None)]

            # Create an index
            self.embeddings.index(data)

            result = self.embeddings.search("select object from txtai limit 1")[0]

            self.assertEqual(result["object"], [1, 2, 3, 4, 5])
        finally:
            # Restore the shared configuration set in setUpClass
            self.embeddings.config["objects"] = "image"

    def testReindex(self):
        """
        Test reindex with objects
        """

        # Create an index for the list of images
        self.embeddings.index(self.data)

        # Reindex images
        self.embeddings.reindex({"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32"})

        result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]

        self.assertTrue(result["id"].endswith("stars.jpg"))
        self.assertIsInstance(result["object"], Image.Image)

    def testReindexFunction(self):
        """
        Test reindex with objects and a function
        """

        try:
            # Streaming function that loads images on the fly
            def prepare(documents):
                for uid, data, tags in documents:
                    yield (uid, Image.open(data), tags)

            # Create an index for the list of images
            self.embeddings.index(self.data)

            # Set default encoder and use function to load images
            self.embeddings.config["objects"] = True

            # Save and load index to force default encoder
            index = os.path.join(tempfile.gettempdir(), "objects")
            self.embeddings.save(index)
            self.embeddings.load(index)

            # Reindex images
            self.embeddings.reindex({"method": "sentence-transformers", "path": "sentence-transformers/clip-ViT-B-32"}, function=prepare)

            result = self.embeddings.search("select id, object from txtai where similar('universe') limit 1")[0]

            self.assertTrue(result["id"].endswith("stars.jpg"))
            self.assertIsInstance(result["object"], BytesIO)
        finally:
            # Restore the shared configuration set in setUpClass
            self.embeddings.config["objects"] = "image"