# test/python/testvectors/testdense/testwordvectors.py
  1  """
  2  WordVectors module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  import numpy as np
 12  
 13  from huggingface_hub.errors import HFValidationError
 14  from txtai.vectors import VectorsFactory
 15  from txtai.vectors.dense.words import create, transform
 16  
 17  
 18  class TestWordVectors(unittest.TestCase):
 19      """
 20      Vectors tests.
 21      """
 22  
 23      @classmethod
 24      def setUpClass(cls):
 25          """
 26          Sets the pretrained model to use
 27          """
 28  
 29          # Test with pretrained glove quantized vectors
 30          cls.path = "neuml/glove-6B-quantized"
 31  
 32      @patch("os.cpu_count")
 33      def testIndex(self, cpucount):
 34          """
 35          Test word vectors indexing
 36          """
 37  
 38          # Mock CPU count
 39          cpucount.return_value = 1
 40  
 41          # Generate data
 42          documents = [(x, "This is a test", None) for x in range(1000)]
 43  
 44          model = VectorsFactory.create({"path": self.path, "parallel": True}, None)
 45  
 46          ids, dimension, batches, stream = model.index(documents, 1)
 47  
 48          self.assertEqual(len(ids), 1000)
 49          self.assertEqual(dimension, 300)
 50          self.assertEqual(batches, 1000)
 51          self.assertIsNotNone(os.path.exists(stream))
 52  
 53          # Test shape of serialized embeddings
 54          with open(stream, "rb") as queue:
 55              self.assertEqual(np.load(queue).shape, (1, 300))
 56  
 57      @patch("os.cpu_count")
 58      def testIndexBatch(self, cpucount):
 59          """
 60          Test word vectors indexing with batch size set
 61          """
 62  
 63          # Mock CPU count
 64          cpucount.return_value = 1
 65  
 66          # Generate data
 67          documents = [(x, "This is a test", None) for x in range(1000)]
 68  
 69          model = VectorsFactory.create({"path": self.path, "parallel": True}, None)
 70  
 71          ids, dimension, batches, stream = model.index(documents, 512)
 72  
 73          self.assertEqual(len(ids), 1000)
 74          self.assertEqual(dimension, 300)
 75          self.assertEqual(batches, 2)
 76          self.assertIsNotNone(os.path.exists(stream))
 77  
 78          # Test shape of serialized embeddings
 79          with open(stream, "rb") as queue:
 80              self.assertEqual(np.load(queue).shape, (512, 300))
 81              self.assertEqual(np.load(queue).shape, (488, 300))
 82  
 83      def testIndexSerial(self):
 84          """
 85          Test word vector indexing in single process mode
 86          """
 87  
 88          # Generate data
 89          documents = [(x, "This is a test", None) for x in range(1000)]
 90  
 91          model = VectorsFactory.create({"path": self.path, "parallel": False}, None)
 92  
 93          ids, dimension, batches, stream = model.index(documents, 1)
 94  
 95          self.assertEqual(len(ids), 1000)
 96          self.assertEqual(dimension, 300)
 97          self.assertEqual(batches, 1000)
 98          self.assertIsNotNone(os.path.exists(stream))
 99  
100          # Test shape of serialized embeddings
101          with open(stream, "rb") as queue:
102              self.assertEqual(np.load(queue).shape, (1, 300))
103  
104      def testIndexSerialBatch(self):
105          """
106          Test word vector indexing in single process mode with batch size set
107          """
108  
109          # Generate data
110          documents = [(x, "This is a test", None) for x in range(1000)]
111  
112          model = VectorsFactory.create({"path": self.path, "parallel": False}, None)
113  
114          ids, dimension, batches, stream = model.index(documents, 512)
115  
116          self.assertEqual(len(ids), 1000)
117          self.assertEqual(dimension, 300)
118          self.assertEqual(batches, 2)
119          self.assertIsNotNone(os.path.exists(stream))
120  
121          # Test shape of serialized embeddings
122          with open(stream, "rb") as queue:
123              self.assertEqual(np.load(queue).shape, (512, 300))
124              self.assertEqual(np.load(queue).shape, (488, 300))
125  
126      def testLookup(self):
127          """
128          Test word vector lookup
129          """
130  
131          model = VectorsFactory.create({"path": self.path}, None)
132          self.assertEqual(model.lookup(["txtai", "embeddings", "sentence"]).shape, (3, 300))
133  
134      def testMultiprocess(self):
135          """
136          Test multiprocess helper methods
137          """
138  
139          create({"path": self.path}, None)
140  
141          uid, vector = transform((0, "test", None))
142          self.assertEqual(uid, 0)
143          self.assertEqual(vector.shape, (300,))
144  
145      def testNoExist(self):
146          """
147          Test loading model that doesn't exist
148          """
149  
150          # Test non-existent path raises an exception
151          with self.assertRaises((IOError, HFValidationError)):
152              VectorsFactory.create({"method": "words", "path": os.path.join(tempfile.gettempdir(), "noexist")}, None)
153  
154      def testTransform(self):
155          """
156          Test word vector transform
157          """
158  
159          model = VectorsFactory.create({"path": self.path}, None)
160          self.assertEqual(len(model.transform((None, ["txtai"], None))), 300)