testsimilarity.py
"""
Similarity module tests
"""

import unittest

from txtai.pipeline import Similarity


class TestSimilarity(unittest.TestCase):
    """
    Test cases for the Similarity pipeline.
    """

    @classmethod
    def setUpClass(cls):
        """
        Build the shared candidate texts and a single Similarity instance reused across tests.
        """

        cls.data = [
            "US tops 5 million confirmed virus cases",
            "Canada's last fully intact ice shelf has suddenly collapsed, forming a Manhattan-sized iceberg",
            "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
            "The National Park Service warns against sacrificing slower friends in a bear attack",
            "Maine man wins $1M from $25 lottery ticket",
            "Make huge profits without work, earn up to $100,000 a day",
        ]

        cls.similarity = Similarity("prajjwal1/bert-medium-mnli")

    def testCrossEncoder(self):
        """
        Cross-encoder model: a single query should rank the lottery text first.
        """

        model = Similarity("cross-encoder/ms-marco-MiniLM-L-2-v2", crossencode=True)
        ranking = model("Who won the lottery?", self.data)

        # First element of the top-ranked result is the index into self.data
        best = ranking[0][0]
        self.assertEqual(self.data[best], self.data[4])

    def testCrossEncoderBatch(self):
        """
        Cross-encoder model: a list of queries yields one ranking per query.
        """

        model = Similarity("cross-encoder/ms-marco-MiniLM-L-2-v2", crossencode=True)
        queries = ["Who won the lottery?", "Where did an iceberg collapse?"]

        rankings = model(queries, self.data)
        best = [ranking[0][0] for ranking in rankings]
        self.assertEqual(best, [4, 1])

    def testLateEncoder(self):
        """
        Late-encoder model: check the top result for a single query and the encode output shape.
        """

        model = Similarity("neuml/pylate-bert-tiny", lateencode=True)
        ranking = model("Who won the lottery?", self.data)

        best = ranking[0][0]
        self.assertEqual(self.data[best], self.data[4])

        # Test encode method
        # pylint: disable=E1101
        encoded = model.encode(["Who won the lottery?"], "data")
        self.assertEqual(encoded.shape, (1, 8, 128))

    def testLateEncoderBatch(self):
        """
        Late-encoder model: a list of queries yields one ranking per query.
        """

        model = Similarity("neuml/colbert-bert-tiny", lateencode=True)
        queries = ["Who won the lottery?", "Where did an iceberg collapse?"]

        rankings = model(queries, self.data)
        best = [ranking[0][0] for ranking in rankings]
        self.assertEqual(best, [4, 1])

    def testSimilarity(self):
        """
        Shared model: a single query should rank the lottery text first.
        """

        ranking = self.similarity("feel good story", self.data)

        best = ranking[0][0]
        self.assertEqual(self.data[best], self.data[4])

    def testSimilarityBatch(self):
        """
        Shared model: a list of queries yields one ranking per query.
        """

        queries = ["feel good story", "climate change"]

        rankings = self.similarity(queries, self.data)
        best = [ranking[0][0] for ranking in rankings]
        self.assertEqual(best, [4, 1])

    def testSimilarityFixed(self):
        """
        Fixed label text classification model (dynamic=False): query by label.
        """

        model = Similarity(dynamic=False)

        # Query first with the label text, then with the label id
        for label in ("negative", "0"):
            score = model(label, ["This is the best sentence ever"])[0][1]
            self.assertLessEqual(score, 0.1)

    def testSimilarityLong(self):
        """
        Shared model: a very long candidate text is handled alongside a short one.
        """

        candidates = ["Very long text " * 1000, "other text"]

        best = self.similarity("other", candidates)[0][0]
        self.assertEqual(best, 1)