/ test / python / testpipeline / testllm / testllm.py
testllm.py
  1  """
  2  LLM module tests
  3  """
  4  
  5  import unittest
  6  
  7  import torch
  8  
  9  from transformers import AutoModelForCausalLM, AutoTokenizer
 10  
 11  from txtai.pipeline import LLM, Generation
 12  
 13  # pylint: disable=C0411
 14  from utils import Utils
 15  
 16  
 17  class TestLLM(unittest.TestCase):
 18      """
 19      LLM tests.
 20      """
 21  
 22      def testArguments(self):
 23          """
 24          Test pipeline keyword arguments
 25          """
 26  
 27          start = "Hello, how are"
 28  
 29          # Test that text is generated with custom parameters
 30          model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", dtype="torch.float32")
 31          self.assertIsNotNone(model(start))
 32  
 33          model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", dtype=torch.float32)
 34          self.assertIsNotNone(model(start))
 35  
 36      def testBatchSize(self):
 37          """
 38          Test batch size
 39          """
 40  
 41          model = LLM("sshleifer/tiny-gpt2")
 42          self.assertIsNotNone(model(["Hello, how are"] * 2, batch_size=2))
 43  
 44      def testCustom(self):
 45          """
 46          Test custom LLM framework
 47          """
 48  
 49          model = LLM("hf-internal-testing/tiny-random-gpt2", task="language-generation", method="txtai.pipeline.HFGeneration")
 50          self.assertIsNotNone(model("Hello, how are"))
 51  
 52      def testCustomNotFound(self):
 53          """
 54          Test resolving an unresolvable LLM framework
 55          """
 56  
 57          with self.assertRaises(ImportError):
 58              LLM("hf-internal-testing/tiny-random-gpt2", method="notfound.generation")
 59  
 60      def testDefaultRole(self):
 61          """
 62          Test default role
 63          """
 64  
 65          model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM")
 66          generator = model.generator
 67  
 68          # Validate that the LLM supports chat messages
 69          self.assertEqual(model.ischat(), True)
 70  
 71          messages = [
 72              ("Hello", list),
 73              ("\n<|im_start|>Hello<|im_end|>", str),
 74              ("<|start|>Hello<|end|>", str),
 75              ("<|start_of_role|>system<|end_of_role|>", str),
 76              ("[INST]Hello[/INST]", str),
 77          ]
 78  
 79          for message, expected in messages:
 80              # Test auto detection of formats
 81              self.assertEqual(type(generator.format([message], "auto")[0]), expected)
 82  
 83              # Test always setting user chat messages
 84              self.assertEqual(type(generator.format([message], "user")[0]), list)
 85  
 86              # Test always keeping as prompt text
 87              self.assertEqual(type(generator.format([message], "prompt")[0]), str)
 88  
 89      def testExternal(self):
 90          """
 91          Test externally loaded model
 92          """
 93  
 94          model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 95          tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 96  
 97          model = LLM((model, tokenizer), template="{text}")
 98          start = "Hello, how are"
 99  
100          # Test that text is generated
101          self.assertIsNotNone(model(start))
102  
103      def testMaxLength(self):
104          """
105          Test max length
106          """
107  
108          model = LLM("sshleifer/tiny-gpt2")
109          self.assertIsInstance(model("Hello, how are", maxlength=10), str)
110  
111      def testNotImplemented(self):
112          """
113          Test exceptions for non-implemented methods
114          """
115  
116          generation = Generation()
117          self.assertRaises(NotImplementedError, generation.stream, None, None, None, None)
118  
119      def testStop(self):
120          """
121          Test stop strings
122          """
123  
124          model = LLM("sshleifer/tiny-gpt2")
125          self.assertIsNotNone(model("Hello, how are", stop=["you"]))
126  
127      def testStream(self):
128          """
129          Test streaming generation
130          """
131  
132          model = LLM("sshleifer/tiny-gpt2")
133          self.assertIsInstance(" ".join(x for x in model("Hello, how are", stream=True)), str)
134  
135      def testStripThink(self):
136          """
137          Test stripthink parameter
138          """
139  
140          # pylint: disable=W0613
141          def execute1(*args, **kwargs):
142              return ["<think>test</think>you"]
143  
144          def execute2(*args, **kwargs):
145              return ["<|channel|>final<|message|> you"]
146  
147          model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM")
148  
149          for method in [execute1, execute2]:
150              # Override execute method
151              model.generator.execute = method
152              self.assertEqual(model("Hello, how are", stripthink=True), "you")
153              self.assertEqual(model("Hello, how are", stripthink=False), method()[0])
154  
155      def testStripThinkStream(self):
156          """
157          Test stripthink parameter with streaming output
158          """
159  
160          # pylint: disable=W0613
161          def execute1(*args, **kwargs):
162              yield from "<think>test</think>you"
163  
164          def execute2(*args, **kwargs):
165              yield from "<|channel|>final<|message|>you"
166  
167          model = LLM("hf-internal-testing/tiny-random-LlamaForCausalLM")
168  
169          for method in [execute1, execute2]:
170              # Override execute method
171              model.generator.execute = method
172              self.assertEqual("".join(model("Hello, how are", stripthink=True, stream=True)), "you")
173              self.assertEqual("".join(model("Hello, how are", stripthink=False, stream=True)), "".join(list(method())))
174  
175      def testVision(self):
176          """
177          Test vision LLM
178          """
179  
180          model = LLM("neuml/tiny-random-qwen2vl")
181          result = model(
182              [{"role": "user", "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image", "image": Utils.PATH + "/books.jpg"}]}]
183          )
184  
185          self.assertIsNotNone(result)