"""
Llama module tests
"""

import unittest

from unittest.mock import patch

from txtai.pipeline import LLM


class TestLlama(unittest.TestCase):
    """
    llama.cpp tests.
    """

    @patch("llama_cpp.Llama")
    def testContext(self, llama):
        """
        Test n_ctx with llama.cpp
        """

        class Llama:
            """
            Mock llama.cpp instance to test invalid context
            """

            def __init__(self, **kwargs):
                # Simulate llama.cpp rejecting a context size of 0 or one that's too large
                if kwargs.get("n_ctx") == 0 or kwargs.get("n_ctx", 0) >= 10000:
                    raise ValueError("Failed to create context")

                # Save parameters
                self.params = kwargs

        # Mock llama.cpp instance
        llama.side_effect = Llama

        # Model to test
        path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"

        # Test omitting n_ctx falls back to default settings
        llm = LLM(path)
        self.assertNotIn("n_ctx", llm.generator.llm.params)

        # Test n_ctx=0 falls back to default settings
        llm = LLM(path, n_ctx=0)
        self.assertNotIn("n_ctx", llm.generator.llm.params)

        # Test n_ctx manually set
        llm = LLM(path, n_ctx=1024)
        self.assertEqual(llm.generator.llm.params["n_ctx"], 1024)

        # Mock a value for n_ctx that's too big
        with self.assertRaises(ValueError):
            llm = LLM(path, n_ctx=10000)

    def testGeneration(self):
        """
        Test generation with llama.cpp
        """

        # Test model generation with llama.cpp
        model = LLM("TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", chat_format="chatml")

        # Test with prompt
        self.assertEqual(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="prompt")[0], "4")

        # Test with list of messages
        messages = [
            {"role": "system", "content": "You are a helpful assistant. You answer math problems."},
            {"role": "user", "content": "2+2?"},
        ]
        self.assertIsNotNone(model(messages, maxlength=10, seed=0, stop=["."]))

        # Test default role
        self.assertIsNotNone(model("2 + 2 = ", maxlength=10, seed=0, stop=["."], defaultrole="user"))

        # Test streaming
        self.assertEqual(" ".join(x for x in model("2 + 2 = ", maxlength=10, stream=True, seed=0, stop=["."], defaultrole="prompt"))[0], "4")