/ test / python / testvectors / testdense / testm2v.py
testm2v.py
 1  """
 2  Model2Vec module tests
 3  """
 4  
 5  import os
 6  import unittest
 7  
 8  import numpy as np
 9  
10  from txtai.vectors import VectorsFactory
11  
12  
class TestModel2Vec(unittest.TestCase):
    """
    Model2Vec vectors tests
    """

    @classmethod
    def setUpClass(cls):
        """
        Create Model2Vec instance.

        The vectors model is created once and shared by every test in this class.
        NOTE(review): loading "minishlab/potion-base-8M" presumably needs the weights
        to be cached locally or downloadable - confirm in CI.
        """

        cls.model = VectorsFactory.create({"path": "minishlab/potion-base-8M"}, None)

    def testIndex(self):
        """
        Test indexing with Model2Vec vectors.

        Indexes a single document, then checks the returned ids, embedding dimension
        and batch count. It also checks that the serialized embeddings stream exists on
        disk and holds one row per document.
        """

        # Presumably each document is an (id, data, tags) tuple - confirm against Vectors.index
        ids, dimension, batches, stream = self.model.index([(0, "test", None)])

        self.assertEqual(len(ids), 1)
        self.assertEqual(dimension, 256)
        self.assertEqual(batches, 1)

        # os.path.exists returns a bool (never None), so assertIsNotNone could never fail;
        # assertTrue actually verifies that the stream file was written
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings: one row per indexed document
        with open(stream, "rb") as queue:
            self.assertEqual(np.load(queue).shape, (1, 256))