/ test / python / testcloud.py
testcloud.py
  1  """
  2  Cloud module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import time
  8  import unittest
  9  
 10  from unittest.mock import patch
 11  
 12  from txtai.cloud import Cloud
 13  from txtai.embeddings import Embeddings
 14  
 15  
 16  class TestCloud(unittest.TestCase):
 17      """
 18      Cloud tests.
 19      """
 20  
 21      @classmethod
 22      def setUpClass(cls):
 23          """
 24          Initialize test data.
 25          """
 26  
 27          cls.data = [
 28              "US tops 5 million confirmed virus cases",
 29              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 30              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 31              "The National Park Service warns against sacrificing slower friends in a bear attack",
 32              "Maine man wins $1M from $25 lottery ticket",
 33              "Make huge profits without work, earn up to $100,000 a day",
 34          ]
 35  
 36          # Create embeddings model, backed by sentence-transformers & transformers
 37          cls.embeddings = Embeddings({"format": "json", "path": "sentence-transformers/nli-mpnet-base-v2", "content": True})
 38  
 39      @classmethod
 40      def tearDownClass(cls):
 41          """
 42          Cleanup data.
 43          """
 44  
 45          if cls.embeddings:
 46              cls.embeddings.close()
 47  
 48      def testCustom(self):
 49          """
 50          Test custom provider
 51          """
 52  
 53          # pylint: disable=E1120
 54          self.runHub("txtai.cloud.HuggingFaceHub")
 55  
 56      def testHub(self):
 57          """
 58          Test huggingface-hub integration
 59          """
 60  
 61          # pylint: disable=E1120
 62          self.runHub("huggingface-hub")
 63  
 64      def testInvalidProvider(self):
 65          """
 66          Test invalid provider identifier
 67          """
 68  
 69          # Test invalid external provider
 70          with self.assertRaises(ImportError):
 71              embeddings = Embeddings()
 72              embeddings.load(provider="ProviderNoExist", container="Invalid")
 73  
 74      def testNotImplemented(self):
 75          """
 76          Test exceptions for non-implemented methods
 77          """
 78  
 79          cloud = Cloud({})
 80  
 81          self.assertRaises(NotImplementedError, cloud.exists, None)
 82          self.assertRaises(NotImplementedError, cloud.load, None)
 83          self.assertRaises(NotImplementedError, cloud.save, None)
 84  
 85      def testObjectStorage(self):
 86          """
 87          Test object storage integration
 88          """
 89  
 90          # Run tests with uncompressed and compressed index
 91          for path in ["cloud.object", "cloud.object.tar.gz"]:
 92              self.runTests(path, {"provider": "local", "container": f"cloud.{time.time()}", "key": tempfile.gettempdir()})
 93  
 94      @patch("huggingface_hub.hf_hub_download")
 95      @patch("huggingface_hub.get_hf_file_metadata")
 96      @patch("huggingface_hub.upload_file")
 97      @patch("huggingface_hub.create_repo")
 98      def runHub(self, provider, create, upload, metadata, download):
 99          """
100          Run huggingface-hub tests. This method mocks write operations since a token won't be available.
101          """
102  
103          def filemeta(url, token):
104              return (url, token) if "Invalid" not in url else None
105  
106          def filedownload(**kwargs):
107              if "Invalid" in kwargs["repo_id"]:
108                  raise FileNotFoundError
109  
110              # Return either .gitattributes file or index
111              return attributes if kwargs["filename"] == ".gitattributes" else index
112  
113          # Patch write methods since token will not be available
114          create.return_value = None
115          upload.return_value = None
116          metadata.side_effect = filemeta
117          download.side_effect = filedownload
118  
119          # Create dummy index
120          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
121  
122          # Generate temp file path
123          index = os.path.join(tempfile.gettempdir(), f"cloud.{provider}.tar.gz")
124          self.embeddings.save(index)
125  
126          # Initialize attributes file
127          # pylint: disable=R1732
128          with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp:
129              tmp.write("*.bin filter=lfs diff=lfs merge=lfs -text\n")
130              attributes = tmp.name
131  
132          # Run tests with uncompressed and compressed index
133          for path in [f"cloud.{provider}", f"cloud.{provider}.tar.gz"]:
134              self.runTests(path, {"provider": provider, "container": "neuml/txtai-intro"})
135  
136      def runTests(self, path, cloud):
137          """
138          Runs a series of cloud sync tests.
139          """
140  
141          # Create an index for the list of text
142          self.embeddings.index([(uid, text, None) for uid, text in enumerate(self.data)])
143  
144          # Generate temp file path
145          index = os.path.join(tempfile.gettempdir(), path)
146  
147          # Test exists handles missing cloud storage object
148          invalid = cloud.copy()
149          invalid["container"] = "InvalidPathToTest"
150          self.assertFalse(self.embeddings.exists(index, invalid))
151  
152          # Test exception raised when trying to load index and doesn't exist in cloud storage
153          # pylint: disable=W0719
154          with self.assertRaises(Exception):
155              self.embeddings.load(index, invalid)
156  
157          # Save index
158          self.embeddings.save(index, cloud)
159  
160          # Test object exists in cloud storage
161          self.assertTrue(self.embeddings.exists(index, cloud))
162  
163          # Test object exists locally
164          self.assertTrue(self.embeddings.exists(index))
165  
166          # Test index can be reloaded
167          self.embeddings.load(index, cloud)
168  
169          # Search for best match
170          result = self.embeddings.search("feel good story", 1)[0]
171          self.assertEqual(result["text"], self.data[4])