testm2v.py
"""
Model2Vec module tests
"""

import os
import unittest

import numpy as np

from txtai.vectors import VectorsFactory


class TestModel2Vec(unittest.TestCase):
    """
    Model2Vec vectors tests.
    """

    @classmethod
    def setUpClass(cls):
        """
        Create a single Model2Vec vectors instance shared by all tests in this class.
        """

        # NOTE(review): "minishlab/potion-base-8M" presumably must be downloadable
        # or already cached locally for this test to run — confirm in CI setup.
        cls.model = VectorsFactory.create({"path": "minishlab/potion-base-8M"}, None)

    def testIndex(self):
        """
        Test indexing with Model2Vec vectors.

        Indexes a single document and checks the returned ids, embedding dimension,
        batch count and the serialized embeddings file written to `stream`.
        """

        # Input element is presumably an (id, data, tags) tuple — verify against index()
        ids, dimension, batches, stream = self.model.index([(0, "test", None)])

        self.assertEqual(len(ids), 1)
        self.assertEqual(dimension, 256)
        self.assertEqual(batches, 1)

        # os.path.exists returns a bool (never None), so assertIsNotNone could never fail;
        # assertTrue actually checks that the stream file was written.
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings: one row of `dimension` values
        with open(stream, "rb") as queue:
            self.assertEqual(np.load(queue).shape, (1, 256))