/ test / python / testpipeline / testtext / testsimilarity.py
testsimilarity.py
  1  """
  2  Similarity module tests
  3  """
  4  
  5  import unittest
  6  
  7  from txtai.pipeline import Similarity
  8  
  9  
 10  class TestSimilarity(unittest.TestCase):
 11      """
 12      Similarity tests.
 13      """
 14  
 15      @classmethod
 16      def setUpClass(cls):
 17          """
 18          Create single labels instance.
 19          """
 20  
 21          cls.data = [
 22              "US tops 5 million confirmed virus cases",
 23              "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
 24              "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
 25              "The National Park Service warns against sacrificing slower friends in a bear attack",
 26              "Maine man wins $1M from $25 lottery ticket",
 27              "Make huge profits without work, earn up to $100,000 a day",
 28          ]
 29  
 30          cls.similarity = Similarity("prajjwal1/bert-medium-mnli")
 31  
 32      def testCrossEncoder(self):
 33          """
 34          Test cross-encoder similarity model
 35          """
 36  
 37          similarity = Similarity("cross-encoder/ms-marco-MiniLM-L-2-v2", crossencode=True)
 38          uid = similarity("Who won the lottery?", self.data)[0][0]
 39          self.assertEqual(self.data[uid], self.data[4])
 40  
 41      def testCrossEncoderBatch(self):
 42          """
 43          Test cross-encoder similarity model with multiple inputs
 44          """
 45  
 46          similarity = Similarity("cross-encoder/ms-marco-MiniLM-L-2-v2", crossencode=True)
 47          results = [r[0][0] for r in similarity(["Who won the lottery?", "Where did an iceberg collapse?"], self.data)]
 48          self.assertEqual(results, [4, 1])
 49  
 50      def testLateEncoder(self):
 51          """
 52          Test late-encoder similarity model
 53          """
 54  
 55          similarity = Similarity("neuml/pylate-bert-tiny", lateencode=True)
 56          uid = similarity("Who won the lottery?", self.data)[0][0]
 57          self.assertEqual(self.data[uid], self.data[4])
 58  
 59          # Test encode method
 60          # pylint: disable=E1101
 61          self.assertEqual(similarity.encode(["Who won the lottery?"], "data").shape, (1, 8, 128))
 62  
 63      def testLateEncoderBatch(self):
 64          """
 65          Test late-encoder similarity model with multiple inputs
 66          """
 67  
 68          similarity = Similarity("neuml/colbert-bert-tiny", lateencode=True)
 69          results = [r[0][0] for r in similarity(["Who won the lottery?", "Where did an iceberg collapse?"], self.data)]
 70          self.assertEqual(results, [4, 1])
 71  
 72      def testSimilarity(self):
 73          """
 74          Test similarity with single query
 75          """
 76  
 77          uid = self.similarity("feel good story", self.data)[0][0]
 78          self.assertEqual(self.data[uid], self.data[4])
 79  
 80      def testSimilarityBatch(self):
 81          """
 82          Test similarity with multiple queries
 83          """
 84  
 85          results = [r[0][0] for r in self.similarity(["feel good story", "climate change"], self.data)]
 86          self.assertEqual(results, [4, 1])
 87  
 88      def testSimilarityFixed(self):
 89          """
 90          Test similarity with a fixed label text classification model
 91          """
 92  
 93          similarity = Similarity(dynamic=False)
 94  
 95          # Test with query as label text and label id
 96          self.assertLessEqual(similarity("negative", ["This is the best sentence ever"])[0][1], 0.1)
 97          self.assertLessEqual(similarity("0", ["This is the best sentence ever"])[0][1], 0.1)
 98  
 99      def testSimilarityLong(self):
100          """
101          Test similarity with long text
102          """
103  
104          uid = self.similarity("other", ["Very long text " * 1000, "other text"])[0][0]
105          self.assertEqual(uid, 1)