/ test / python / testpipeline / testllm / testllama.py
testllama.py
 1  """
 2  Llama module tests
 3  """
 4  
 5  import unittest
 6  
 7  from unittest.mock import patch
 8  
 9  from txtai.pipeline import LLM
10  
11  
class TestLlama(unittest.TestCase):
    """
    llama.cpp tests.
    """

    @patch("llama_cpp.Llama")
    def testContext(self, llama):
        """
        Test n_ctx with llama.cpp
        """

        class Llama:
            """
            Mock llama.cpp instance to test invalid context
            """

            def __init__(self, **kwargs):
                # Reject a zero context and any context at or above the mock limit,
                # mirroring llama.cpp failing to allocate an oversized context
                context = kwargs.get("n_ctx")
                if context == 0 or (context is not None and context >= 10000):
                    raise ValueError("Failed to create context")

                # Keep constructor arguments for inspection by the assertions below
                self.params = kwargs

        # Route llama_cpp.Llama construction through the mock class above
        llama.side_effect = Llama

        # Quantized GGUF model used by every case
        path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"

        # Omitting n_ctx falls back to default settings
        llm = LLM(path)
        self.assertNotIn("n_ctx", llm.generator.llm.params)

        # n_ctx=0 also falls back to default settings
        llm = LLM(path, n_ctx=0)
        self.assertNotIn("n_ctx", llm.generator.llm.params)

        # An explicit, valid n_ctx is passed through unchanged
        llm = LLM(path, n_ctx=1024)
        self.assertEqual(llm.generator.llm.params["n_ctx"], 1024)

        # A context that is too large raises
        with self.assertRaises(ValueError):
            llm = LLM(path, n_ctx=10000)

    def testGeneration(self):
        """
        Test generation with llama.cpp
        """

        # Load model with the chatml chat format
        model = LLM("TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", chat_format="chatml")

        # Plain prompt input - first character of the completion
        self.assertEqual(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="prompt")[0], "4")

        # Chat-style message list input
        messages = [
            {"role": "system", "content": "You are a helpful assistant. You answer math problems."},
            {"role": "user", "content": "2+2?"},
        ]
        self.assertIsNotNone(model(messages, maxlength=10, seed=0, stop=["."]))

        # Prompt routed through the default user role
        self.assertIsNotNone(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="user"))

        # Streaming - join the streamed chunks, then check the first character
        self.assertEqual(" ".join(model("2 + 2 = ", maxlength=10, stream=True, seed=0, stop=["."], defaultrole="prompt"))[0], "4")