/ test / python / testagent.py
testagent.py
  1  """
  2  Agent module tests
  3  """
  4  
  5  import os
  6  import tempfile
  7  import unittest
  8  
  9  from unittest.mock import patch
 10  
 11  from datetime import datetime
 12  
 13  from smolagents import CodeAgent, PythonInterpreterTool
 14  
 15  from txtai.agent import Agent
 16  from txtai.embeddings import Embeddings
 17  
 18  # agents.md content
 19  AGENTS = """
 20  Basic instructions here
 21  """
 22  
 23  # Sample skill.md content
 24  SKILL = """---
 25  name: hello
 26  description: says hello world
 27  ---
 28  
 29  Says hello world
 30  """
 31  
 32  
 33  class TestAgent(unittest.TestCase):
 34      """
 35      Agent tests.
 36      """
 37  
 38      def testExecute(self):
 39          """
 40          Test executing main agent loop
 41          """
 42  
 43          agent = Agent(llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)
 44  
 45          # Patch LLM to generate answer
 46          agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
 47  
 48          self.assertEqual(agent("Hello"), "Hi")
 49  
 50      def testInstructions(self):
 51          """
 52          Test loading an agents.md file
 53          """
 54  
 55          # Test loading instructions from file
 56          agents = os.path.join(tempfile.gettempdir(), "agents.md")
 57          with open(agents, "w", encoding="utf-8") as output:
 58              output.write(AGENTS)
 59  
 60          agent = Agent(instructions=agents, llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
 61          agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
 62          self.assertEqual(agent("Hello"), "Hi")
 63  
 64          # Test loading from memory
 65          agent = Agent(instructions=AGENTS, llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
 66          agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
 67          self.assertEqual(agent("Hello"), "Hi")
 68  
 69      def testMemory(self):
 70          """
 71          Test agent memory
 72          """
 73  
 74          agent = Agent(llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1, memory=5)
 75  
 76          # Patch LLM to generate answer
 77          agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'
 78  
 79          self.assertEqual(agent("Hello"), "Hi")
 80          self.assertEqual(agent("Hello"), "Hi")
 81  
 82          # Test that results are stored in shared memory
 83          self.assertEqual(len(agent.memory.get(None)), 2)
 84  
 85          # Test resetting shared memory
 86          self.assertEqual(agent("Hello", reset=True), "Hi")
 87          self.assertEqual(len(agent.memory.get(None)), 1)
 88  
 89          # Test session memory
 90          self.assertEqual(agent("Hello", session="session-0"), "Hi")
 91          self.assertEqual(len(agent.memory.get("session-0")), 1)
 92  
 93          # Test resetting session memory
 94          self.assertEqual(agent("Hello", session="session-0", reset=True), "Hi")
 95          self.assertEqual(len(agent.memory.get("session-0")), 1)
 96          self.assertEqual(len(agent.memory.get(None)), 1)
 97  
 98      def testMethod(self):
 99          """
100          Test agent process methods
101          """
102  
103          agent = Agent(method="code", llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
104          self.assertIsInstance(agent.process, CodeAgent)
105  
106      def testSkill(self):
107          """
108          Test running a skill from a skill.md file
109          """
110  
111          skill = os.path.join(tempfile.gettempdir(), "skill.md")
112          with open(skill, "w", encoding="utf-8") as output:
113              output.write(SKILL)
114  
115          agent = Agent(tools=[skill], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_iterations=1)
116  
117          self.assertIsInstance(agent.tools["hello"]("say hello"), str)
118  
119      def testToolsBasic(self):
120          """
121          Test adding basic function tools
122          """
123  
124          class DateTime:
125              """
126              Date time instance
127              """
128  
129              def __call__(self, iso):
130                  """
131                  Gets the current date and time
132  
133                  Args:
134                      iso: date will be converted to iso format if True
135  
136                  Returns:
137                      current date and time
138                  """
139  
140                  return datetime.today().isoformat() if iso else datetime.today()
141  
142          today = {"name": "today", "description": "Gets the current date and time", "target": DateTime()}
143  
144          def current(iso: str) -> str:
145              """
146              Gets the current date and time
147  
148              Args:
149                  iso: date will be converted to iso format if True
150  
151              Returns:
152                  current date and time
153              """
154  
155              return datetime.today().isoformat() if iso else datetime.today()
156  
157          agent = Agent(tools=[today, current, "websearch"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)
158  
159          self.assertIsNotNone(agent)
160          self.assertIsInstance(agent.tools["today"](True), str)
161          self.assertIsInstance(agent.tools["current"](True), str)
162  
163      def testToolsDefaults(self):
164          """
165          Test default toolkit tools
166          """
167  
168          agent = Agent(tools=["defaults"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)
169  
170          # Working directory
171          work = tempfile.gettempdir()
172  
173          # Test file
174          path = os.path.join(work, "agent_tools.txt")
175          agent.tools["write"](path, "hello world")
176  
177          # Test default tools
178          self.assertIsNotNone(agent.tools["bash"](["ls", work]))
179          self.assertGreater(len(agent.tools["glob"](work)), 0)
180          self.assertGreater(len(agent.tools["grep"]("world", "*")), 0)
181          self.assertEqual(agent.tools["todowrite"]("plan"), "plan")
182  
183          agent.tools["edit"](path, "hello", "goodbye")
184          self.assertEqual(agent.tools["read"](path), "goodbye world".strip())
185  
186      def testToolsEmbeddings(self):
187          """
188          Test adding Embeddings as a tool
189          """
190  
191          embeddings = Embeddings()
192          embeddings.index(["test"])
193  
194          # Generate temp file path and save
195          index = os.path.join(tempfile.gettempdir(), "embeddings.agent")
196          embeddings.save(index)
197  
198          embeddings1 = {
199              "name": "embeddings1",
200              "description": "Searches a test database",
201              "target": embeddings,
202          }
203  
204          embeddings2 = {"name": "embeddings2", "description": "Searches a test database", "path": index}
205  
206          agent = Agent(tools=[embeddings1, embeddings2], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)
207  
208          self.assertIsNotNone(agent)
209          self.assertIsInstance(agent.tools["embeddings1"]("test"), list)
210  
211      # pylint: disable=C0115,C0116
212      @patch("mcpadapt.core.MCPAdapt")
213      def testToolsMCP(self, mcp):
214          """
215          Test adding a MCP tool collection
216          """
217  
218          class MCPAdapt:
219              def __init__(self, *args):
220                  self.args = args
221  
222              def tools(self):
223                  return [PythonInterpreterTool()]
224  
225          # Patch MCP adapter for testing
226          mcp.side_effect = MCPAdapt
227  
228          agent = Agent(tools=["http://localhost:8000/mcp"], llm="hf-internal-testing/tiny-random-LlamaForCausalLM", max_steps=1)
229          self.assertEqual(len(agent.tools), 2)