testcloud.py
"""
Cloud module tests
"""

import os
import tempfile
import time
import unittest

from unittest.mock import patch

from txtai.cloud import Cloud
from txtai.embeddings import Embeddings


class TestCloud(unittest.TestCase):
    """
    Cloud tests.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize test data and the shared embeddings instance used by every test.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        # Create embeddings model, backed by sentence-transformers & transformers
        cls.embeddings = Embeddings({"format": "json", "path": "sentence-transformers/nli-mpnet-base-v2", "content": True})

    @classmethod
    def tearDownClass(cls):
        """
        Cleanup data.
        """

        if cls.embeddings:
            cls.embeddings.close()

    def testCustom(self):
        """
        Test custom provider, referenced by its fully qualified class path.
        """

        # runHub's remaining arguments are injected by its @patch decorators
        # pylint: disable=E1120
        self.runHub("txtai.cloud.HuggingFaceHub")

    def testHub(self):
        """
        Test huggingface-hub integration
        """

        # runHub's remaining arguments are injected by its @patch decorators
        # pylint: disable=E1120
        self.runHub("huggingface-hub")

    def testInvalidProvider(self):
        """
        Test invalid provider identifier
        """

        # Build the instance outside the assertion context so that only the
        # load call (with an unresolvable provider) can satisfy assertRaises
        embeddings = Embeddings()

        # Test invalid external provider
        with self.assertRaises(ImportError):
            embeddings.load(provider="ProviderNoExist", container="Invalid")

    def testNotImplemented(self):
        """
        Test exceptions for non-implemented methods
        """

        cloud = Cloud({})

        self.assertRaises(NotImplementedError, cloud.exists, None)
        self.assertRaises(NotImplementedError, cloud.load, None)
        self.assertRaises(NotImplementedError, cloud.save, None)

    def testObjectStorage(self):
        """
        Test object storage integration
        """

        # Run tests with uncompressed and compressed index. Each iteration uses a
        # fresh, timestamped container name.
        for path in ["cloud.object", "cloud.object.tar.gz"]:
            self.runTests(path, {"provider": "local", "container": f"cloud.{time.time()}", "key": tempfile.gettempdir()})

    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.get_hf_file_metadata")
    @patch("huggingface_hub.upload_file")
    @patch("huggingface_hub.create_repo")
    def runHub(self, provider, create, upload, metadata, download):
        """
        Run huggingface-hub tests. This method mocks write operations since a token won't be available.

        Args:
            provider: cloud provider identifier passed through to the cloud config
            create: mock for huggingface_hub.create_repo (innermost decorator, so first mock argument)
            upload: mock for huggingface_hub.upload_file
            metadata: mock for huggingface_hub.get_hf_file_metadata
            download: mock for huggingface_hub.hf_hub_download (outermost decorator, so last mock argument)
        """

        def filemeta(url, token):
            # Report metadata (i.e. "exists") for every url except the invalid test container
            return (url, token) if "Invalid" not in url else None

        def filedownload(**kwargs):
            # Simulate a missing repository for the invalid test container
            if "Invalid" in kwargs["repo_id"]:
                raise FileNotFoundError

            # Return either .gitattributes file or index
            return attributes if kwargs["filename"] == ".gitattributes" else index

        # Patch write methods since token will not be available
        create.return_value = None
        upload.return_value = None
        metadata.side_effect = filemeta
        download.side_effect = filedownload

        # Create dummy index
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), f"cloud.{provider}.tar.gz")
        self.embeddings.save(index)

        # Initialize attributes file
        # pylint: disable=R1732
        with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
            tmp.write("*.bin filter=lfs diff=lfs merge=lfs -text\n")
            attributes = tmp.name

        try:
            # Run tests with uncompressed and compressed index
            for path in [f"cloud.{provider}", f"cloud.{provider}.tar.gz"]:
                self.runTests(path, {"provider": provider, "container": "neuml/txtai-intro"})
        finally:
            # delete=False leaves the file on disk after the with block; remove it
            # once the mocked downloads no longer need it
            if os.path.exists(attributes):
                os.remove(attributes)

    def runTests(self, path, cloud):
        """
        Runs a series of cloud sync tests.

        Args:
            path: index file name, joined onto the system temp directory
            cloud: cloud configuration dict passed to exists/load/save
        """

        # Create an index for the list of text
        self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])

        # Generate temp file path
        index = os.path.join(tempfile.gettempdir(), path)

        # Test exists handles missing cloud storage object
        invalid = cloud.copy()
        invalid["container"] = "InvalidPathToTest"
        self.assertFalse(self.embeddings.exists(index, invalid))

        # Test exception raised when trying to load index and doesn't exist in cloud storage
        # pylint: disable=W0719
        with self.assertRaises(Exception):
            self.embeddings.load(index, invalid)

        # Save index
        self.embeddings.save(index, cloud)

        # Test object exists in cloud storage
        self.assertTrue(self.embeddings.exists(index, cloud))

        # Test object exists locally
        self.assertTrue(self.embeddings.exists(index))

        # Test index can be reloaded
        self.embeddings.load(index, cloud)

        # Search for best match
        result = self.embeddings.search("feel good story", 1)[0]
        self.assertEqual(result["text"], self.data[4])