/ tests / test_agent2_tool_adapter.py
test_agent2_tool_adapter.py
  1  """Tests for restai.agent2.tool_adapter — schema building and AdaptedTool invocation."""
  2  import asyncio
  3  from typing import Optional
  4  
  5  from restai.agent2.tool_adapter import (
  6      AdaptedTool,
  7      _python_type_to_json_type,
  8      build_json_schema,
  9  )
 10  
 11  
 12  # ---------- build_json_schema ----------
 13  
 14  
 15  def test_build_json_schema_simple_function():
 16      def foo(x: str, y: int = 5) -> str:
 17          return x * y
 18  
 19      schema = build_json_schema(foo)
 20      assert schema["type"] == "object"
 21      assert "x" in schema["properties"]
 22      assert "y" in schema["properties"]
 23      assert schema["properties"]["x"]["type"] == "string"
 24      assert schema["properties"]["y"]["type"] == "integer"
 25      assert "x" in schema["required"]
 26      assert "y" not in schema["required"]
 27  
 28  
 29  def test_build_json_schema_optional_params():
 30      def bar(a: str, b: Optional[int] = None) -> str:
 31          return a
 32  
 33      schema = build_json_schema(bar)
 34      assert "a" in schema["properties"]
 35      assert "b" in schema["properties"]
 36      assert "a" in schema["required"]
 37      assert "b" not in schema["required"]
 38      # Optional[int] unwraps to integer
 39      assert schema["properties"]["b"]["type"] == "integer"
 40  
 41  
 42  def test_build_json_schema_no_annotations():
 43      def baz(x, y=10):
 44          return x
 45  
 46      schema = build_json_schema(baz)
 47      assert "x" in schema["properties"]
 48      assert "y" in schema["properties"]
 49      assert "x" in schema["required"]
 50      assert "y" not in schema["required"]
 51  
 52  
 53  # ---------- _python_type_to_json_type ----------
 54  
 55  
 56  def test_type_mapping_str():
 57      assert _python_type_to_json_type(str) == {"type": "string"}
 58  
 59  
 60  def test_type_mapping_int():
 61      assert _python_type_to_json_type(int) == {"type": "integer"}
 62  
 63  
 64  def test_type_mapping_float():
 65      assert _python_type_to_json_type(float) == {"type": "number"}
 66  
 67  
 68  def test_type_mapping_bool():
 69      assert _python_type_to_json_type(bool) == {"type": "boolean"}
 70  
 71  
 72  def test_type_mapping_list():
 73      assert _python_type_to_json_type(list)["type"] == "array"
 74  
 75  
 76  def test_type_mapping_dict():
 77      assert _python_type_to_json_type(dict) == {"type": "object"}
 78  
 79  
 80  # ---------- AdaptedTool.call ----------
 81  
 82  
 83  def test_adapted_tool_call_sync():
 84      def add(a: int, b: int) -> int:
 85          return a + b
 86  
 87      tool = AdaptedTool(
 88          name="add",
 89          description="Add two numbers",
 90          input_schema=build_json_schema(add),
 91          fn=add,
 92          is_async=False,
 93      )
 94      result = asyncio.run(tool.call({"a": 3, "b": 4}))
 95      assert result == "7"
 96  
 97  
 98  def test_adapted_tool_call_async():
 99      async def greet(name: str) -> str:
100          return f"Hello, {name}!"
101  
102      tool = AdaptedTool(
103          name="greet",
104          description="Greet someone",
105          input_schema=build_json_schema(greet),
106          fn=greet,
107          is_async=True,
108      )
109      result = asyncio.run(tool.call({"name": "World"}))
110      assert result == "Hello, World!"
111  
112  
def test_adapted_tool_call_returns_empty_for_none():
    """A tool whose callable returns ``None`` yields ``""`` rather than ``"None"``.

    The schema is built with ``build_json_schema``, as in the other
    AdaptedTool tests in this module, rather than a hand-written ``{}``.
    This keeps tool construction consistent and also exercises schema
    building for a function with no parameters.
    """
    def noop() -> None:
        return None

    tool = AdaptedTool(
        name="noop",
        description="No-op",
        input_schema=build_json_schema(noop),
        fn=noop,
        is_async=False,
    )
    result = asyncio.run(tool.call({}))
    assert result == ""
126  
127  
def test_adapted_tool_call_bad_args_returns_error():
    """Calling with an unknown argument returns an error string instead of raising.

    ``{"wrong_param": 1}`` both omits the required ``x`` and supplies an
    unexpected keyword. The test only checks that the returned text
    contains the error marker.
    """
    def strict(x: int) -> int:
        return x + 1

    schema = build_json_schema(strict)
    adapted = AdaptedTool(
        name="strict",
        description="Strict",
        input_schema=schema,
        fn=strict,
        is_async=False,
    )
    bad_args = {"wrong_param": 1}
    output = asyncio.run(adapted.call(bad_args))
    assert "Error calling tool" in output