/ test / python / testapi / testapiagent.py
testapiagent.py
 1  """
 2  Agent API module tests
 3  """
 4  
 5  import os
 6  import tempfile
 7  import unittest
 8  
 9  from unittest.mock import patch
10  
11  from fastapi.testclient import TestClient
12  
13  from txtai.api import API, application
14  
# Configuration for agents.
#
# Defines a single agent named "test" that is limited to one iteration and has
# one tool, "testtool", whose target is the TestTool class defined in this
# module (referenced by its dotted import path).
#
# The top-level "llm" section configures an LLM from
# hf-internal-testing/tiny-random-LlamaForCausalLM. The agent section sets no
# llm of its own, so presumably the agent uses this one (confirm). Either way,
# TestAgent.start replaces the agent's LLM call with a fixed stub, so the
# model's real output never reaches the assertions.
#
# NOTE: this text is parsed as YAML, so its exact indentation is significant.
AGENTS = """
agent:
    test:
        max_iterations: 1
        tools:
            - name: testtool
              description: Test tool
              target: testapi.testapiagent.TestTool

llm:
    path: hf-internal-testing/tiny-random-LlamaForCausalLM
"""
28  
29  
# pylint: disable=R0904
class TestAgent(unittest.TestCase):
    """
    API tests for agents.
    """

    # Temporary YAML configuration file read by the API on startup. Defined once
    # so that the CONFIG environment variable patched in start() and the file
    # start() writes always refer to the same path.
    CONFIG = os.path.join(tempfile.gettempdir(), "testapi.yml")

    @staticmethod
    @patch.dict(os.environ, {"CONFIG": CONFIG, "API_CLASS": "txtai.api.API"})
    def start():
        """
        Starts a mock FastAPI client.

        Writes the AGENTS configuration to TestAgent.CONFIG, then creates and starts a
        new application while the CONFIG and API_CLASS environment variables are
        patched. Finally it stubs the "test" agent's LLM so it always answers "Hi".

        Returns:
            TestClient bound to the newly created application
        """

        with open(TestAgent.CONFIG, "w", encoding="utf-8") as output:
            output.write(AGENTS)

        # Create new application and set on client
        application.app = application.create()
        client = TestClient(application.app)
        application.start()

        # Replace the agent's LLM with a stub that always emits a final_answer action,
        # so the result does not depend on the configured model's output
        agent = application.get().agents["test"]
        agent.process.model.llm = lambda *args, **kwargs: 'Action:\n{"name": "final_answer", "arguments": "Hi"}'

        return client

    @classmethod
    def setUpClass(cls):
        """
        Create API client on creation of class.
        """

        cls.client = TestAgent.start()

    @classmethod
    def tearDownClass(cls):
        """
        Remove the temporary configuration file written by start().
        """

        try:
            os.remove(cls.CONFIG)
        except FileNotFoundError:
            # Best-effort cleanup: nothing to remove if the file was never written
            pass

    def testAgent(self):
        """
        Test agent via API.

        Posts to the "agent" endpoint for the "test" agent and expects the stubbed
        final answer installed by start().
        """

        results = self.client.post("agent", json={"name": "test", "text": "Hello"}).json()
        self.assertEqual(results, "Hi")

    def testEmpty(self):
        """
        Test empty API configuration.

        An API built from an empty configuration returns None when asked to run an agent.
        """

        api = API({})

        self.assertIsNone(api.agent("junk", "test"))
83  
84  
class TestTool:
    """
    No-op callable used as the target of the "testtool" agent tool in the test
    configuration; it exists only so the tool's dotted path resolves.
    """

    def __call__(self):
        # Deliberately does nothing and yields no value
        return None