/ test / python / testmodels / testpooling.py
testpooling.py
  1  """
  2  Pooling module tests
  3  """
  4  
  5  import unittest
  6  
  7  from txtai.models import Models, ClsPooling, LastPooling, MeanPooling, PoolingFactory
  8  
  9  
 10  class TestPooling(unittest.TestCase):
 11      """
 12      Pooling tests.
 13      """
 14  
 15      @classmethod
 16      def setUpClass(cls):
 17          """
 18          Initialize device
 19          """
 20  
 21          # Device id
 22          cls.device = Models.deviceid(True)
 23  
 24      def testCLS(self):
 25          """
 26          Test CLS pooling
 27          """
 28  
 29          # Test CLS pooling
 30          pooling = PoolingFactory.create({"path": "flax-sentence-embeddings/multi-qa_v1-MiniLM-L6-cls_dot", "device": self.device})
 31          self.assertEqual(type(pooling), ClsPooling)
 32  
 33          pooling = PoolingFactory.create({"method": "clspooling", "path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
 34          self.assertEqual(type(pooling), ClsPooling)
 35  
 36          # Test CLS pooling encoding
 37          self.assertEqual(pooling.encode(["test"])[0].shape, (768,))
 38  
 39      def testLast(self):
 40          """
 41          Test last pooling
 42          """
 43  
 44          # Test last pooling
 45          pooling = PoolingFactory.create({"path": "neuml/bert-tiny-sts-last-pooling", "device": self.device})
 46          self.assertEqual(type(pooling), LastPooling)
 47  
 48          pooling = PoolingFactory.create({"method": "lastpooling", "path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
 49          self.assertEqual(type(pooling), LastPooling)
 50  
 51          # Test last pooling encoding
 52          self.assertEqual(pooling.encode(["test"])[0].shape, (768,))
 53  
 54      def testLength(self):
 55          """
 56          Test pooling with max_seq_length
 57          """
 58  
 59          # Test reading max_seq_length parmaeter
 60          pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device, "maxlength": True})
 61          self.assertEqual(pooling.maxlength, 75)
 62  
 63          # Test specified maxlength
 64          pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device, "maxlength": 256})
 65          self.assertEqual(pooling.maxlength, 256)
 66  
 67          # Test max_seq_length is ignored when parameter is omitted
 68          pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
 69          self.assertEqual(pooling.maxlength, 512)
 70  
 71          # Test maxlength when max_seq_length not present
 72          pooling = PoolingFactory.create({"path": "hf-internal-testing/tiny-random-gpt2", "device": self.device, "maxlength": True})
 73          self.assertEqual(pooling.maxlength, 1024)
 74  
 75      def testMean(self):
 76          """
 77          Test mean pooling
 78          """
 79  
 80          # Test mean pooling
 81          pooling = PoolingFactory.create({"path": "sentence-transformers/nli-mpnet-base-v2", "device": self.device})
 82          self.assertEqual(type(pooling), MeanPooling)
 83  
 84          pooling = PoolingFactory.create(
 85              {"method": "meanpooling", "path": "flax-sentence-embeddings/multi-qa_v1-MiniLM-L6-cls_dot", "device": self.device}
 86          )
 87          self.assertEqual(type(pooling), MeanPooling)
 88  
 89      def testMuvera(self):
 90          """
 91          Test late pooling with MUVERA fixed dimensional encoding
 92          """
 93  
 94          # Test MUVERA encoding
 95          for model in ["neuml/colbert-bert-tiny", "neuml/pylate-bert-tiny"]:
 96              # Test defaults
 97              pooling = PoolingFactory.create({"path": model, "device": self.device})
 98              self.assertEqual(pooling.encode(["test"], category="query").shape, (1, 10240))
 99  
100              # Test custom settings
101              pooling = PoolingFactory.create(
102                  {"path": model, "device": self.device, "modelargs": {"muvera": {"repetitions": 5, "hashes": 2, "projection": 8}}}
103              )
104              self.assertEqual(pooling.encode(["test"], category="data").shape, (1, 160))
105  
106      def testPrompts(self):
107          """
108          Test instruction prompts
109          """
110  
111          # Load model with prompts
112          pooling = PoolingFactory.create({"path": "neuml/bert-tiny-prompts", "device": self.device, "loadprompts": True})
113  
114          # Test prompts are prepended
115          self.assertEqual(pooling.preencode(["abc"], "query")[0], "query: abc")
116          self.assertEqual(pooling.preencode(["text"], "data")[0], "document: text")
117  
118          # Load model with prompts disabled (default)
119          pooling = PoolingFactory.create({"path": "neuml/bert-tiny-prompts", "device": self.device})
120  
121          # Test that prompts are not prepended
122          self.assertEqual(pooling.preencode(["abc"], "query")[0], "abc")
123          self.assertEqual(pooling.preencode(["text"], "data")[0], "text")