"""
WordVectors module tests
"""

import os
import tempfile
import unittest

from unittest.mock import patch

import numpy as np

from huggingface_hub.errors import HFValidationError
from txtai.vectors import VectorsFactory
from txtai.vectors.dense.words import create, transform


class TestWordVectors(unittest.TestCase):
    """
    Vectors tests.
    """

    @classmethod
    def setUpClass(cls):
        """
        Sets the pretrained model to use
        """

        # Test with pretrained glove quantized vectors
        cls.path = "neuml/glove-6B-quantized"

    def runIndex(self, parallel, batchsize, batches, shapes):
        """
        Runs a word vectors indexing test and validates the results.

        Args:
            parallel: enables multiprocessing when True
            batchsize: indexing batch size
            batches: expected number of serialized batches
            shapes: expected shapes of the leading serialized embeddings arrays
        """

        # Generate data
        documents = [(x, "This is a test", None) for x in range(1000)]

        model = VectorsFactory.create({"path": self.path, "parallel": parallel}, None)

        ids, dimension, count, stream = model.index(documents, batchsize)

        self.assertEqual(len(ids), 1000)
        self.assertEqual(dimension, 300)
        self.assertEqual(count, batches)

        # Fix: os.path.exists returns a bool, never None - assertIsNotNone always passed
        self.assertTrue(os.path.exists(stream))

        # Test shape of serialized embeddings
        with open(stream, "rb") as queue:
            for shape in shapes:
                self.assertEqual(np.load(queue).shape, shape)

    @patch("os.cpu_count")
    def testIndex(self, cpucount):
        """
        Test word vectors indexing
        """

        # Mock CPU count
        cpucount.return_value = 1

        self.runIndex(True, 1, 1000, [(1, 300)])

    @patch("os.cpu_count")
    def testIndexBatch(self, cpucount):
        """
        Test word vectors indexing with batch size set
        """

        # Mock CPU count
        cpucount.return_value = 1

        self.runIndex(True, 512, 2, [(512, 300), (488, 300)])

    def testIndexSerial(self):
        """
        Test word vector indexing in single process mode
        """

        self.runIndex(False, 1, 1000, [(1, 300)])

    def testIndexSerialBatch(self):
        """
        Test word vector indexing in single process mode with batch size set
        """

        self.runIndex(False, 512, 2, [(512, 300), (488, 300)])

    def testLookup(self):
        """
        Test word vector lookup
        """

        model = VectorsFactory.create({"path": self.path}, None)
        self.assertEqual(model.lookup(["txtai", "embeddings", "sentence"]).shape, (3, 300))

    def testMultiprocess(self):
        """
        Test multiprocess helper methods
        """

        # Initialize the per-process model used by transform
        create({"path": self.path}, None)

        uid, vector = transform((0, "test", None))
        self.assertEqual(uid, 0)
        self.assertEqual(vector.shape, (300,))

    def testNoExist(self):
        """
        Test loading model that doesn't exist
        """

        # Test non-existent path raises an exception
        with self.assertRaises((IOError, HFValidationError)):
            VectorsFactory.create({"method": "words", "path": os.path.join(tempfile.gettempdir(), "noexist")}, None)

    def testTransform(self):
        """
        Test word vector transform
        """

        model = VectorsFactory.create({"path": self.path}, None)
        self.assertEqual(len(model.transform((None, ["txtai"], None))), 300)