/ test / python / testvectors / testdense / testsbert.py
testsbert.py
 1  """
 2  Sentence Transformers module tests
 3  """
 4  
 5  import os
 6  import platform
 7  import unittest
 8  
 9  from unittest.mock import patch
10  
11  import numpy as np
12  
13  from txtai.vectors import VectorsFactory
14  
15  
class TestSTVectors(unittest.TestCase):
    """
    STVectors tests
    """

    def testIndex(self):
        """
        Test indexing with sentence-transformers vectors
        """

        model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2"}, None)
        self._verifyIndex(model)

    @unittest.skipIf(platform.system() == "Darwin", "Torch memory sharing not supported on macOS")
    @patch("torch.cuda.device_count")
    def testMultiGPU(self, count):
        """
        Test multiple gpu encoding

        Args:
            count: mock replacing torch.cuda.device_count for the duration of the test
        """

        # Mock accelerator count (2 devices) so "gpu": "all" exercises multi-GPU encoding
        count.return_value = 2

        model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2", "gpu": "all"}, None)

        try:
            self._verifyIndex(model)
        finally:
            # Close the multiprocessing pool even when an assertion fails, so no
            # worker processes leak. This stays inside the patched scope, as before.
            model.close()

    def _verifyIndex(self, model):
        """
        Index a single document with model and check the returned metadata and stream.

        Both tests load paraphrase-MiniLM-L3-v2, for which a 384-dimension
        output is expected.

        Args:
            model: vectors instance created by VectorsFactory
        """

        ids, dimension, batches, stream = model.index([(0, "test", None)])

        self.assertEqual(len(ids), 1)
        self.assertEqual(dimension, 384)
        self.assertEqual(batches, 1)

        # os.path.exists returns a bool (never None), so it must be checked for
        # truth. assertIsNotNone would always pass.
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings. This reads the first array in the
        # stream; with batches == 1 it is presumably the only one.
        with open(stream, "rb") as queue:
            self.assertEqual(np.load(queue).shape, (1, 384))