/ test / tools / test_from_function.py
test_from_function.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  from collections.abc import Callable
  6  from typing import Annotated, Literal
  7  
  8  import pytest
  9  
 10  from haystack.components.agents.state import State
 11  from haystack.tools.errors import SchemaGenerationError
 12  from haystack.tools.from_function import _remove_title_from_schema, create_tool_from_function, tool
 13  from haystack.tools.tool import Tool
 14  
 15  
 16  def function_with_docstring(city: str) -> str:
 17      """Get weather report for a city."""
 18      return f"Weather report for {city}: 20°C, sunny"
 19  
 20  
 21  def test_from_function_description_from_docstring():
 22      tool = create_tool_from_function(function=function_with_docstring)
 23  
 24      assert tool.name == "function_with_docstring"
 25      assert tool.description == "Get weather report for a city."
 26      assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 27      assert tool.function == function_with_docstring
 28  
 29  
 30  def test_from_function_with_empty_description():
 31      tool = create_tool_from_function(function=function_with_docstring, description="")
 32  
 33      assert tool.name == "function_with_docstring"
 34      assert tool.description == ""
 35      assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 36      assert tool.function == function_with_docstring
 37  
 38  
 39  def test_from_function_with_custom_description():
 40      tool = create_tool_from_function(function=function_with_docstring, description="custom description")
 41  
 42      assert tool.name == "function_with_docstring"
 43      assert tool.description == "custom description"
 44      assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 45      assert tool.function == function_with_docstring
 46  
 47  
 48  def test_from_function_with_custom_name():
 49      tool = create_tool_from_function(function=function_with_docstring, name="custom_name")
 50  
 51      assert tool.name == "custom_name"
 52      assert tool.description == "Get weather report for a city."
 53      assert tool.parameters == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 54      assert tool.function == function_with_docstring
 55  
 56  
 57  def test_from_function_annotated():
 58      def function_with_annotations(
 59          city: Annotated[str, "the city for which to get the weather"] = "Munich",
 60          unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius",
 61          nullable_param: Annotated[str | None, "a nullable parameter"] = None,
 62      ) -> str:
 63          """A simple function to get the current weather for a location."""
 64          return f"Weather report for {city}: 20 {unit}, sunny"
 65  
 66      tool = create_tool_from_function(function=function_with_annotations)
 67  
 68      assert tool.name == "function_with_annotations"
 69      assert tool.description == "A simple function to get the current weather for a location."
 70      assert tool.parameters == {
 71          "type": "object",
 72          "properties": {
 73              "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"},
 74              "unit": {
 75                  "type": "string",
 76                  "enum": ["Celsius", "Fahrenheit"],
 77                  "description": "the unit for the temperature",
 78                  "default": "Celsius",
 79              },
 80              "nullable_param": {
 81                  "anyOf": [{"type": "string"}, {"type": "null"}],
 82                  "description": "a nullable parameter",
 83                  "default": None,
 84              },
 85          },
 86      }
 87  
 88  
 89  def test_from_function_missing_type_hint():
 90      def function_missing_type_hint(city) -> str:
 91          return f"Weather report for {city}: 20°C, sunny"
 92  
 93      with pytest.raises(ValueError):
 94          create_tool_from_function(function=function_missing_type_hint)
 95  
 96  
 97  def test_from_function_schema_generation_error():
 98      def function_with_invalid_type_hint(city: "invalid") -> str:  # noqa: F821
 99          return f"Weather report for {city}: 20°C, sunny"
100  
101      with pytest.raises(SchemaGenerationError):
102          create_tool_from_function(function=function_with_invalid_type_hint)
103  
104  
def test_from_function_with_callable_params_skipped():
    """Parameters typed as callables are left out of the generated parameter schema."""

    def function_with_callback(query: str, callback: Callable[[str], None] | None = None) -> str:
        """A function with a callable parameter."""
        return query

    callback_tool = create_tool_from_function(function=function_with_callback)
    assert callback_tool.name == "function_with_callback"

    # Membership on the properties dict checks its keys directly; no need to materialize a list.
    properties = callback_tool.parameters.get("properties", {})
    assert "callback" not in properties
    assert "query" in properties
116  
117  
def test_from_function_state_param_excluded_from_schema():
    """A parameter annotated with `State` is excluded from the generated parameter schema."""

    def function_with_state(city: str, state: State) -> str:
        """Get weather for a city, with access to agent state."""
        return f"Weather in {city}: sunny"

    state_tool = create_tool_from_function(function=function_with_state)
    schema = state_tool.parameters
    properties = schema.get("properties", {})

    assert state_tool.name == "function_with_state"
    assert "state" not in properties
    assert "city" in properties
    assert schema == {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
130  
131  
def test_tool_decorator_state_param_excluded_from_schema():
    """The `@tool` decorator drops a `State` parameter from the schema, as `create_tool_from_function` does."""

    @tool
    def function_with_state(city: str, state: State) -> str:
        """Get weather for a city, with access to agent state."""
        return f"Weather in {city}: sunny"

    assert function_with_state.name == "function_with_state"

    properties = function_with_state.parameters.get("properties", {})
    assert "state" not in properties
    assert "city" in properties
    # Compare the whole schema as well: checking property names alone would not catch
    # `state` leaking into "required" or any other stray schema entry.
    assert function_with_state.parameters == {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }
141  
142  
def test_from_function_optional_state_param_excluded_from_schema():
    """An optional `State | None` parameter is excluded from the schema just like a plain `State` one."""

    def function_with_optional_state(city: str, state: State | None = None) -> str:
        """Get weather for a city, optionally using agent state."""
        return f"Weather in {city}: sunny"

    weather_tool = create_tool_from_function(function=function_with_optional_state)

    assert weather_tool.name == "function_with_optional_state"

    properties = weather_tool.parameters.get("properties", {})
    assert "state" not in properties
    assert "city" in properties
    # The optional state must leave no trace in the schema, so it should equal the
    # schema produced for the non-optional `State` case.
    assert weather_tool.parameters == {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }
153  
154  
def test_tool_decorator():
    """Bare `@tool` yields a Tool named after the function and described by its docstring."""

    @tool
    def get_weather(city: str) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city}: 20°C, sunny"

    expected_schema = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}

    assert get_weather.name == "get_weather"
    assert get_weather.description == "Get weather report for a city."
    assert get_weather.parameters == expected_schema

    # The wrapped callable must still behave like the original function.
    wrapped = get_weather.function
    assert callable(wrapped)
    assert wrapped("Berlin") == "Weather report for Berlin: 20°C, sunny"
170  
171  
# Module-level tool used by test_tool_decorator_deserialization.
# NOTE(review): presumably defined at module scope (rather than inside the test) so that
# Tool.from_dict can re-import the serialized function reference — confirm.
@tool
def weather_tool_with_decorator(city: str) -> str:
    """Get weather report for a city."""
    report = f"Weather report for {city}: 20°C, sunny"
    return report
177  
178  
def test_tool_decorator_deserialization():
    """A Tool created with `@tool` survives a `to_dict` / `Tool.from_dict` round trip."""
    restored = Tool.from_dict(weather_tool_with_decorator.to_dict())

    assert restored.name == "weather_tool_with_decorator"
    assert restored.description == "Get weather report for a city."
    assert restored.parameters == {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }
189  
190  
def test_tool_decorator_with_annotated_params():
    """With bare `@tool`, `Annotated` descriptions, defaults and `Literal` choices end up in the schema."""

    @tool
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    city_property = {"type": "string", "description": "The target city", "default": "Berlin"}
    output_format_property = {
        "type": "string",
        "enum": ["short", "long"],
        "description": "Output format",
        "default": "short",
    }

    assert get_weather.name == "get_weather"
    assert get_weather.description == "Get weather report for a city."
    assert get_weather.parameters == {
        "type": "object",
        "properties": {"city": city_property, "output_format": output_format_property},
    }

    wrapped = get_weather.function
    assert callable(wrapped)
    assert wrapped("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"
216  
217  
def test_tool_decorator_with_parameters():
    """`@tool(name=..., description=...)` overrides name and description but not the generated schema."""

    @tool(name="fetch_weather", description="A tool to check the weather.")
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    assert get_weather.name == "fetch_weather"
    assert get_weather.description == "A tool to check the weather."
    # The overrides must not change how parameters are derived from the signature: this is the
    # same schema bare `@tool` produces for this exact signature.
    assert get_weather.parameters == {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "The target city", "default": "Berlin"},
            "output_format": {
                "type": "string",
                "enum": ["short", "long"],
                "description": "Output format",
                "default": "short",
            },
        },
    }
    # The overrides must also leave the wrapped function itself untouched.
    assert callable(get_weather.function)
    assert get_weather.function("Berlin", "long") == "Weather report for Berlin (long format): 20°C, sunny"
229  
230  
def test_tool_decorator_with_inputs_and_outputs():
    """State mappings given to `@tool` are stored on the Tool, and state-fed inputs leave the schema."""

    @tool(inputs_from_state={"output_format": "output_format"}, outputs_to_state={"output": {"source": "output"}})
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    expected_properties = {"city": {"type": "string", "description": "The target city", "default": "Berlin"}}

    assert get_weather.name == "get_weather"
    assert get_weather.inputs_from_state == {"output_format": "output_format"}
    assert get_weather.outputs_to_state == {"output": {"source": "output"}}
    # `output_format` is mapped from state, so it is excluded from the auto-generated parameters.
    assert get_weather.parameters == {"type": "object", "properties": expected_properties}
248  
249  
def test_remove_title_from_schema():
    """Every `title` keyword, top-level and nested, is stripped in place; all other content is untouched."""
    schema = {
        "properties": {
            "parameter1": {
                "anyOf": [{"type": "string"}, {"type": "integer"}],
                "default": "default_value",
                "title": "Parameter1",
            },
            "parameter2": {
                "default": [1, 2, 3],
                "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                "title": "Parameter2",
                "type": "array",
            },
            "parameter3": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"},
                    {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
                ],
                "default": 42,
                "title": "Parameter3",
            },
            "parameter4": {
                "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
                "default": {"key": "value"},
                "title": "Parameter4",
            },
        },
        "title": "complex_function",
        "type": "object",
    }
    expected = {
        "properties": {
            "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"},
            "parameter2": {
                "default": [1, 2, 3],
                "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                "type": "array",
            },
            "parameter3": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"},
                    {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
                ],
                "default": 42,
            },
            "parameter4": {
                "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
                "default": {"key": "value"},
            },
        },
        "type": "object",
    }

    # The helper mutates its argument; its return value is not used.
    _remove_title_from_schema(schema)

    assert schema == expected
308  
309  
def test_remove_title_from_schema_do_not_remove_title_property():
    """A property that happens to be named `title` survives; only `title` keywords are dropped."""
    schema = {
        "properties": {
            "parameter1": {"type": "string", "title": "Parameter1"},
            "title": {"type": "string", "title": "Title"},
        },
        "title": "complex_function",
        "type": "object",
    }
    expected = {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"}

    _remove_title_from_schema(schema)

    assert schema == expected
324  
325  
def test_remove_title_from_schema_handle_no_title_in_top_level():
    """A schema with no top-level `title` is handled; nested property titles are still removed."""
    schema = {
        "properties": {
            "parameter1": {"type": "string", "title": "Parameter1"},
            "parameter2": {"type": "integer", "title": "Parameter2"},
        },
        "type": "object",
    }
    expected = {
        "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
        "type": "object",
    }

    _remove_title_from_schema(schema)

    assert schema == expected
341  
342  
def test_from_function_async_raises_error():
    """Coroutine functions are rejected by `create_tool_from_function` with a ValueError."""

    async def async_get_weather(city: str) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city}: 20°C, sunny"

    expected_message = "Async functions are not supported as tools"
    with pytest.raises(ValueError, match=expected_message):
        create_tool_from_function(function=async_get_weather)
350  
351  
def test_tool_decorator_async_raises_error():
    """Decorating a coroutine function with `@tool` raises a ValueError."""
    expected_message = "Async functions are not supported as tools"
    # The decorator runs when the definition is evaluated, so the decorated definition itself
    # has to live inside the `pytest.raises` block.
    with pytest.raises(ValueError, match=expected_message):

        @tool
        async def async_get_weather(city: str) -> str:
            """Get weather report for a city."""
            return f"Weather report for {city}: 20°C, sunny"