testpooling.py
"""
Pooling module tests
"""

import unittest

from txtai.models import Models, ClsPooling, LastPooling, MeanPooling, PoolingFactory


class TestPooling(unittest.TestCase):
    """
    Pooling tests.

    Each test builds pooling instances through PoolingFactory.create and checks the
    resulting pooling class, configured maximum length, prompt handling or encoded output shape.
    """

    @classmethod
    def setUpClass(cls):
        """
        Initialize device
        """

        # Device id, resolved once and shared by every test in this class
        cls.device = Models.deviceid(True)

    def testCLS(self):
        """
        Test CLS pooling
        """

        # Test CLS pooling is selected when no method is specified for this model
        pooling = PoolingFactory.create({"path": "flax-sentence-embeddings/multi-qa_v1-MiniLM-L6-cls_dot", "device": self.device})
        self.assertEqual(type(pooling), ClsPooling)

        # Test CLS pooling can be explicitly requested (testMean expects mean pooling for this model when no method is set)
        pooling = PoolingFactory.create({"method": "clspooling", "path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
        self.assertEqual(type(pooling), ClsPooling)

        # Test CLS pooling encoding
        self.assertEqual(pooling.encode(["test"])[0].shape, (768,))

    def testLast(self):
        """
        Test last pooling
        """

        # Test last pooling is selected when no method is specified for this model
        pooling = PoolingFactory.create({"path": "neuml/bert-tiny-sts-last-pooling", "device": self.device})
        self.assertEqual(type(pooling), LastPooling)

        # Test last pooling can be explicitly requested (testMean expects mean pooling for this model when no method is set)
        pooling = PoolingFactory.create({"method": "lastpooling", "path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
        self.assertEqual(type(pooling), LastPooling)

        # Test last pooling encoding
        self.assertEqual(pooling.encode(["test"])[0].shape, (768,))

    def testLength(self):
        """
        Test pooling with max_seq_length
        """

        # Test reading max_seq_length parameter
        pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device, "maxlength": True})
        self.assertEqual(pooling.maxlength, 75)

        # Test specified maxlength
        pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device, "maxlength": 256})
        self.assertEqual(pooling.maxlength, 256)

        # Test max_seq_length is ignored when parameter is omitted
        pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
        self.assertEqual(pooling.maxlength, 512)

        # Test maxlength when max_seq_length not present
        pooling = PoolingFactory.create({"path": "hf-internal-testing/tiny-random-gpt2", "device": self.device, "maxlength": True})
        self.assertEqual(pooling.maxlength, 1024)

    def testMean(self):
        """
        Test mean pooling
        """

        # Test mean pooling is selected when no method is specified for this model
        pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
        self.assertEqual(type(pooling), MeanPooling)

        # Test mean pooling can be explicitly requested (testCLS expects CLS pooling for this model when no method is set)
        pooling = PoolingFactory.create(
            {"method": "meanpooling", "path": "flax-sentence-embeddings/multi-qa_v1-MiniLM-L6-cls_dot", "device": self.device}
        )
        self.assertEqual(type(pooling), MeanPooling)

    def testMuvera(self):
        """
        Test late pooling with MUVERA fixed dimensional encoding
        """

        # Test MUVERA encoding
        for model in ["neuml/colbert-bert-tiny", "neuml/pylate-bert-tiny"]:
            # Run each model as a subtest so a failure reports which model path triggered it
            with self.subTest(model=model):
                # Test defaults
                pooling = PoolingFactory.create({"path": model, "device": self.device})
                self.assertEqual(pooling.encode(["test"], category="query").shape, (1, 10240))

                # Test custom settings
                # NOTE(review): 160 matches repetitions * 2**hashes * projection (5 * 4 * 8) - confirm against the MUVERA implementation
                pooling = PoolingFactory.create(
                    {"path": model, "device": self.device, "modelargs": {"muvera": {"repetitions": 5, "hashes": 2, "projection": 8}}}
                )
                self.assertEqual(pooling.encode(["test"], category="data").shape, (1, 160))

    def testPrompts(self):
        """
        Test instruction prompts
        """

        # Load model with prompts
        pooling = PoolingFactory.create({"path": "neuml/bert-tiny-prompts", "device": self.device, "loadprompts": True})

        # Test prompts are prepended
        self.assertEqual(pooling.preencode(["abc"], "query")[0], "query: abc")
        self.assertEqual(pooling.preencode(["text"], "data")[0], "document: text")

        # Load model with prompts disabled (default)
        pooling = PoolingFactory.create({"path": "neuml/bert-tiny-prompts", "device": self.device})

        # Test that prompts are not prepended
        self.assertEqual(pooling.preencode(["abc"], "query")[0], "abc")
        self.assertEqual(pooling.preencode(["text"], "data")[0], "text")