testagent.py
"""
Agent module tests
"""

import os
import tempfile
import unittest

from unittest.mock import patch

from datetime import datetime

from smolagents import CodeAgent, PythonInterpreterTool

from txtai.agent import Agent
from txtai.embeddings import Embeddings

# agents.md content
AGENTS = """
Basic instructions here
"""

# Sample skill.md content
SKILL = """---
name: hello
description: says hello world
---

Says hello world
"""


class TestAgent(unittest.TestCase):
    """
    Agent tests.
    """

    def testExecute(self):
        """
        Test executing main agent loop
        """

        agent = Agent(llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)

        # Patch LLM to generate answer
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'

        self.assertEqual(agent("Hello"), "Hi")

    def testInstructions(self):
        """
        Test loading an agents.md file
        """

        # Test loading instructions from file
        agents = os.path.join(tempfile.gettempdir(), "agents.md")
        with open(agents, "w", encoding="utf-8") as output:
            output.write(AGENTS)

        # Fixed file name in the shared temp dir - remove it once the test finishes
        self.addCleanup(os.remove, agents)

        # NOTE(review): max_iterations is used here while other tests use max_steps;
        # presumably this deliberately exercises a legacy alias - confirm against Agent
        agent = Agent(instructions=agents, llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
        self.assertEqual(agent("Hello"), "Hi")

        # Test loading from memory
        agent = Agent(instructions=AGENTS, llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
        self.assertEqual(agent("Hello"), "Hi")

    def testMemory(self):
        """
        Test agent memory
        """

        agent = Agent(llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1, memory=5)

        # Patch LLM to generate answer
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'

        self.assertEqual(agent("Hello"), "Hi")
        self.assertEqual(agent("Hello"), "Hi")

        # Test that results are stored in shared memory
        self.assertEqual(len(agent.memory.get(None)), 2)

        # Test resetting shared memory
        self.assertEqual(agent("Hello", reset=True), "Hi")
        self.assertEqual(len(agent.memory.get(None)), 1)

        # Test session memory
        self.assertEqual(agent("Hello", session="session-0"), "Hi")
        self.assertEqual(len(agent.memory.get("session-0")), 1)

        # Test resetting session memory
        self.assertEqual(agent("Hello", session="session-0", reset=True), "Hi")
        self.assertEqual(len(agent.memory.get("session-0")), 1)

        # Resetting a session must leave the shared (None) memory untouched
        self.assertEqual(len(agent.memory.get(None)), 1)

    def testMethod(self):
        """
        Test agent process methods
        """

        agent = Agent(method="code", llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
        self.assertIsInstance(agent.process, CodeAgent)

    def testSkill(self):
        """
        Test running a skill from a skill.md file
        """

        skill = os.path.join(tempfile.gettempdir(), "skill.md")
        with open(skill, "w", encoding="utf-8") as output:
            output.write(SKILL)

        # Fixed file name in the shared temp dir - remove it once the test finishes
        self.addCleanup(os.remove, skill)

        agent = Agent(tools=[skill], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)

        # The tool is registered under the skill's front-matter name ("hello")
        self.assertIsInstance(agent.tools["hello"]("say hello"), str)

    def testToolsBasic(self):
        """
        Test adding basic function tools
        """

        class DateTime:
            """
            Date time instance
            """

            def __call__(self, iso):
                """
                Gets the current date and time

                Args:
                    iso: date will be converted to iso format if True

                Returns:
                    current date and time
                """

                return datetime.today().isoformat() if iso else datetime.today()

        # Callable object registered via an explicit name/description/target dict
        today = {"name": "today", "description": "Gets the current date and time", "target": DateTime()}

        def current(iso: str) -> str:
            """
            Gets the current date and time

            Args:
                iso: date will be converted to iso format if True

            Returns:
                current date and time
            """

            return datetime.today().isoformat() if iso else datetime.today()

        # Mix of dict-configured tool, plain function tool and a named built-in tool
        agent = Agent(tools=[today, current, "websearch"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)

        self.assertIsNotNone(agent)
        self.assertIsInstance(agent.tools["today"](True), str)
        self.assertIsInstance(agent.tools["current"](True), str)

    def testToolsDefaults(self):
        """
        Test default toolkit tools
        """

        agent = Agent(tools=["defaults"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)

        # Working directory
        work = tempfile.gettempdir()

        # Test file
        path = os.path.join(work, "agent_tools.txt")
        agent.tools["write"](path, "hello world")

        # Fixed file name in the shared temp dir - remove it once the test finishes
        self.addCleanup(os.remove, path)

        # Test default tools
        self.assertIsNotNone(agent.tools["bash"](["ls", work]))
        self.assertGreater(len(agent.tools["glob"](work)), 0)

        # NOTE(review): "*" is not anchored to `work`, so this presumably searches relative
        # to the current working directory - confirm that is the intent
        self.assertGreater(len(agent.tools["grep"]("world", "*")), 0)
        self.assertEqual(agent.tools["todowrite"]("plan"), "plan")

        agent.tools["edit"](path, "hello", "goodbye")

        # Normalise the tool's output (stripping the expected literal was a no-op)
        self.assertEqual(agent.tools["read"](path).strip(), "goodbye world")

    def testToolsEmbeddings(self):
        """
        Test adding Embeddings as a tool
        """

        embeddings = Embeddings()
        embeddings.index(["test"])

        # Generate temp file path and save
        index = os.path.join(tempfile.gettempdir(), "embeddings.agent")
        embeddings.save(index)

        # Tool backed by an in-memory Embeddings instance
        embeddings1 = {
            "name": "embeddings1",
            "description": "Searches a test database",
            "target": embeddings,
        }

        # Tool backed by the index saved to disk above
        embeddings2 = {"name": "embeddings2", "description": "Searches a test database", "path": index}

        agent = Agent(tools=[embeddings1, embeddings2], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)

        self.assertIsNotNone(agent)
        self.assertIsInstance(agent.tools["embeddings1"]("test"), list)

    # pylint: disable=C0115,C0116
    @patch("mcpadapt.core.MCPAdapt")
    def testToolsMCP(self, mcp):
        """
        Test adding a MCP tool collection
        """

        class MCPAdapt:
            def __init__(self, *args):
                self.args = args

            def tools(self):
                return [PythonInterpreterTool()]

        # Patch MCP adapter for testing
        mcp.side_effect = MCPAdapt

        agent = Agent(tools=["http://localhost:8000/mcp"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)

        # NOTE(review): the stub adapter yields one tool; the expected count of 2 presumably
        # includes a tool the agent adds itself (e.g. final_answer) - confirm
        self.assertEqual(len(agent.tools), 2)