# testsbert.py
1 """ 2 Sentence Transformers module tests 3 """ 4 5 import os 6 import platform 7 import unittest 8 9 from unittest.mock import patch 10 11 import numpy as np 12 13 from txtai.vectors import VectorsFactory 14 15 16 class TestSTVectors(unittest.TestCase): 17 """ 18 STVectors tests 19 """ 20 21 def testIndex(self): 22 """ 23 Test indexing with sentence-transformers vectors 24 """ 25 26 model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2"}, None) 27 ids, dimension, batches, stream = model.index([(0, "test", None)]) 28 29 self.assertEqual(len(ids), 1) 30 self.assertEqual(dimension, 384) 31 self.assertEqual(batches, 1) 32 self.assertIsNotNone(os.path.exists(stream)) 33 34 # Test shape of serialized embeddings 35 with open(stream, "rb") as queue: 36 self.assertEqual(np.load(queue).shape, (1, 384)) 37 38 @unittest.skipIf(platform.system() == "Darwin", "Torch memory sharing not supported on macOS") 39 @patch("torch.cuda.device_count") 40 def testMultiGPU(self, count): 41 """ 42 Test multiple gpu encoding 43 """ 44 45 # Mock accelerator count 46 count.return_value = 2 47 48 model = VectorsFactory.create({"method": "sentence-transformers", "path": "paraphrase-MiniLM-L3-v2", "gpu": "all"}, None) 49 ids, dimension, batches, stream = model.index([(0, "test", None)]) 50 51 self.assertEqual(len(ids), 1) 52 self.assertEqual(dimension, 384) 53 self.assertEqual(batches, 1) 54 self.assertIsNotNone(os.path.exists(stream)) 55 56 # Test shape of serialized embeddings 57 with open(stream, "rb") as queue: 58 self.assertEqual(np.load(queue).shape, (1, 384)) 59 60 # Close the multiprocessing pool 61 model.close()