testllm.py
1 """ 2 LLM module tests 3 """ 4 5 import unittest 6 7 import torch 8 9 from transformers import AutoModelForCausalLM, AutoTokenizer 10 11 from txtai.pipeline import LLM, Generation 12 13 # pylint: disable=C0411 14 from utils import Utils 15 16 17 class TestLLM(unittest.TestCase): 18 """ 19 LLM tests. 20 """ 21 22 def testArguments(self): 23 """ 24 Test pipeline keyword arguments 25 """ 26 27 start = "Hello, how are" 28 29 # Test that text is generated with custom parameters 30 model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", dtype="torch.float32") 31 self.assertIsNotNone(model(start)) 32 33 model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", dtype=torch.float32) 34 self.assertIsNotNone(model(start)) 35 36 def testBatchSize(self): 37 """ 38 Test batch size 39 """ 40 41 model = LLM("sshleifer/tiny-gpt2") 42 self.assertIsNotNone(model(["Hello, how are"] * 2, batch_size=2)) 43 44 def testCustom(self): 45 """ 46 Test custom LLM framework 47 """ 48 49 model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", method="txtai.pipeline.HFGeneration") 50 self.assertIsNotNone(model("Hello, how are")) 51 52 def testCustomNotFound(self): 53 """ 54 Test resolving an unresolvable LLM framework 55 """ 56 57 with self.assertRaises(ImportError): 58 LLM("hf-internal-testing/tiny-random-gpt2", method="notfound.generation") 59 60 def testDefaultRole(self): 61 """ 62 Test default role 63 """ 64 65 model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM") 66 generator = model.generator 67 68 # Validate that the LLM supports chat messages 69 self.assertEqual(model.ischat(), True) 70 71 messages = [ 72 ("Hello", list), 73 ("\n<|im_start|>Hello<|im_end|>", str), 74 ("<|start|>Hello<|end|>", str), 75 ("<|start_of_role|>system<|end_of_role|>", str), 76 ("[INST]Hello[/INST]", str), 77 ] 78 79 for message, expected in messages: 80 # Test auto detection of formats 81 self.assertEqual(type(generator.format([message], "auto")[0]), expected) 82 83 # Test always setting user chat messages 84 self.assertEqual(type(generator.format([message], "user")[0]), list) 85 86 # Test always keeping as prompt text 87 self.assertEqual(type(generator.format([message], "prompt")[0]), str) 88 89 def testExternal(self): 90 """ 91 Test externally loaded model 92 """ 93 94 model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") 95 tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") 96 97 model = LLM((model, tokenizer), template="{text}") 98 start = "Hello, how are" 99 100 # Test that text is generated 101 self.assertIsNotNone(model(start)) 102 103 def testMaxLength(self): 104 """ 105 Test max length 106 """ 107 108 model = LLM("sshleifer/tiny-gpt2") 109 self.assertIsInstance(model("Hello, how are", maxlength=10), str) 110 111 def testNotImplemented(self): 112 """ 113 Test exceptions for non-implemented methods 114 """ 115 116 generation = Generation() 117 self.assertRaises(NotImplementedError, generation.stream, None, None, None, None) 118 119 def testStop(self): 120 """ 121 Test stop strings 122 """ 123 124 model = LLM("sshleifer/tiny-gpt2") 125 self.assertIsNotNone(model("Hello, how are", stop=["you"])) 126 127 def testStream(self): 128 """ 129 Test streaming generation 130 """ 131 132 model = LLM("sshleifer/tiny-gpt2") 133 self.assertIsInstance(" ".join(x for x in model("Hello, how are", stream=True)), str) 134 135 def testStripThink(self): 136 """ 137 Test stripthink parameter 138 """ 139 140 # 
pylint: disable=W0613 141 def execute1(*args, **kwargs): 142 return ["<think>test</think>you"] 143 144 def execute2(*args, **kwargs): 145 return ["<|channel|>final<|message|> you"] 146 147 model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM") 148 149 for method in [execute1, execute2]: 150 # Override execute method 151 model.generator.execute = method 152 self.assertEqual(model("Hello, how are", stripthink=True), "you") 153 self.assertEqual(model("Hello, how are", stripthink=False), method()[0]) 154 155 def testStripThinkStream(self): 156 """ 157 Test stripthink parameter with streaming output 158 """ 159 160 # pylint: disable=W0613 161 def execute1(*args, **kwargs): 162 yield from "<think>test</think>you" 163 164 def execute2(*args, **kwargs): 165 yield from "<|channel|>final<|message|>you" 166 167 model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM") 168 169 for method in [execute1, execute2]: 170 # Override execute method 171 model.generator.execute = method 172 self.assertEqual("".join(model("Hello, how are", stripthink=True, stream=True)), "you") 173 self.assertEqual("".join(model("Hello, how are", stripthink=False, stream=True)), "".join(list(method()))) 174 175 def testVision(self): 176 """ 177 Test vision LLM 178 """ 179 180 model = LLM("neuml/tiny-random-qwen2vl") 181 result = model( 182 [{"role": "user", "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image", "image": Utils.PATH + "/books.jpg"}]}] 183 ) 184 185 self.assertIsNotNone(result)
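
# Standalone entry point: a minimal sketch assuming this module may also be run
# directly (python testllm.py) rather than only through an external test runner
if __name__ == "__main__":
    unittest.main()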