test_from_function.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Annotated, Literal

import pytest

from haystack.components.agents.state import State
from haystack.tools.errors import SchemaGenerationError
from haystack.tools.from_function import _remove_title_from_schema, create_tool_from_function, tool
from haystack.tools.tool import Tool

# Parameters schema expected whenever the only schema-visible argument is a required ``city: str``.
# Shared by several tests below; only ever compared with ``==``, never mutated.
CITY_ONLY_SCHEMA = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}


def function_with_docstring(city: str) -> str:
    """Get weather report for a city."""
    return f"Weather report for {city}: 20°C, sunny"


# NOTE: the tests below bind created tools to ``weather_tool`` (not ``tool``) so the imported
# ``tool`` decorator is never shadowed inside a test body.


def test_from_function_description_from_docstring():
    """Without an explicit description, the function docstring becomes the tool description."""
    weather_tool = create_tool_from_function(function=function_with_docstring)

    assert weather_tool.name == "function_with_docstring"
    assert weather_tool.description == "Get weather report for a city."
    assert weather_tool.parameters == CITY_ONLY_SCHEMA
    assert weather_tool.function == function_with_docstring


def test_from_function_with_empty_description():
    """An explicit empty description is kept as-is rather than replaced by the docstring."""
    weather_tool = create_tool_from_function(function=function_with_docstring, description="")

    assert weather_tool.name == "function_with_docstring"
    assert weather_tool.description == ""
    assert weather_tool.parameters == CITY_ONLY_SCHEMA
    assert weather_tool.function == function_with_docstring


def test_from_function_with_custom_description():
    """An explicit description takes precedence over the docstring."""
    weather_tool = create_tool_from_function(function=function_with_docstring, description="custom description")

    assert weather_tool.name == "function_with_docstring"
    assert weather_tool.description == "custom description"
    assert weather_tool.parameters == CITY_ONLY_SCHEMA
    assert weather_tool.function == function_with_docstring


def test_from_function_with_custom_name():
    """An explicit name replaces the function name; the description still comes from the docstring."""
    weather_tool = create_tool_from_function(function=function_with_docstring, name="custom_name")

    assert weather_tool.name == "custom_name"
    assert weather_tool.description == "Get weather report for a city."
    assert weather_tool.parameters == CITY_ONLY_SCHEMA
    assert weather_tool.function == function_with_docstring


def test_from_function_annotated():
    """``Annotated`` metadata becomes the property description, ``Literal`` becomes an enum,
    defaults are propagated, and parameters with defaults are not listed as required."""

    def function_with_annotations(
        city: Annotated[str, "the city for which to get the weather"] = "Munich",
        unit: Annotated[Literal["Celsius", "Fahrenheit"], "the unit for the temperature"] = "Celsius",
        nullable_param: Annotated[str | None, "a nullable parameter"] = None,
    ) -> str:
        """A simple function to get the current weather for a location."""
        return f"Weather report for {city}: 20 {unit}, sunny"

    weather_tool = create_tool_from_function(function=function_with_annotations)

    assert weather_tool.name == "function_with_annotations"
    assert weather_tool.description == "A simple function to get the current weather for a location."
    assert weather_tool.parameters == {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "the city for which to get the weather", "default": "Munich"},
            "unit": {
                "type": "string",
                "enum": ["Celsius", "Fahrenheit"],
                "description": "the unit for the temperature",
                "default": "Celsius",
            },
            "nullable_param": {
                "anyOf": [{"type": "string"}, {"type": "null"}],
                "description": "a nullable parameter",
                "default": None,
            },
        },
    }
    assert weather_tool.function == function_with_annotations


def test_from_function_missing_type_hint():
    """A parameter without a type hint is rejected with ``ValueError``."""

    def function_missing_type_hint(city) -> str:
        return f"Weather report for {city}: 20°C, sunny"

    with pytest.raises(ValueError):
        create_tool_from_function(function=function_missing_type_hint)


def test_from_function_schema_generation_error():
    """An unresolvable string annotation surfaces as ``SchemaGenerationError``."""

    def function_with_invalid_type_hint(city: "invalid") -> str:  # noqa: F821
        return f"Weather report for {city}: 20°C, sunny"

    with pytest.raises(SchemaGenerationError):
        create_tool_from_function(function=function_with_invalid_type_hint)


def test_from_function_with_callable_params_skipped():
    """Callable-typed parameters (e.g. callbacks) are left out of the generated schema."""

    def function_with_callback(query: str, callback: Callable[[str], None] | None = None) -> str:
        """A function with a callable parameter."""
        return query

    weather_tool = create_tool_from_function(function=function_with_callback)

    assert weather_tool.name == "function_with_callback"
    properties = weather_tool.parameters.get("properties", {})
    assert "callback" not in properties
    assert "query" in properties


def test_from_function_state_param_excluded_from_schema():
    """A ``State``-typed parameter is excluded from the generated schema."""

    def function_with_state(city: str, state: State) -> str:
        """Get weather for a city, with access to agent state."""
        return f"Weather in {city}: sunny"

    weather_tool = create_tool_from_function(function=function_with_state)

    assert weather_tool.name == "function_with_state"
    properties = weather_tool.parameters.get("properties", {})
    assert "state" not in properties
    assert "city" in properties
    assert weather_tool.parameters == CITY_ONLY_SCHEMA


def test_tool_decorator_state_param_excluded_from_schema():
    """``@tool`` excludes ``State``-typed parameters exactly like ``create_tool_from_function``."""

    @tool
    def function_with_state(city: str, state: State) -> str:
        """Get weather for a city, with access to agent state."""
        return f"Weather in {city}: sunny"

    properties = function_with_state.parameters.get("properties", {})
    assert "state" not in properties
    assert "city" in properties
    # The decorator delegates to create_tool_from_function, so the full schema must match too.
    assert function_with_state.parameters == CITY_ONLY_SCHEMA


def test_from_function_optional_state_param_excluded_from_schema():
    """An optional ``State | None`` parameter is excluded from the schema as well."""

    def function_with_optional_state(city: str, state: State | None = None) -> str:
        """Get weather for a city, optionally using agent state."""
        return f"Weather in {city}: sunny"

    weather_tool = create_tool_from_function(function=function_with_optional_state)

    properties = weather_tool.parameters.get("properties", {})
    assert "state" not in properties
    assert "city" in properties
    # With ``state`` removed, only the required ``city`` parameter should remain.
    assert weather_tool.parameters == CITY_ONLY_SCHEMA


def test_tool_decorator():
    """Bare ``@tool`` derives name, description and schema from the function and keeps it callable."""

    @tool
    def get_weather(city: str) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city}: 20°C, sunny"

    assert get_weather.name == "get_weather"
    assert get_weather.description == "Get weather report for a city."
    assert get_weather.parameters == CITY_ONLY_SCHEMA
    assert callable(get_weather.function)
    assert get_weather.function("Berlin") == "Weather report for Berlin: 20°C, sunny"


# Test function for decorator deserialization. It is defined at module level, presumably so that
# deserialization can resolve it by import path (a test-local function would not be importable).
@tool
def weather_tool_with_decorator(city: str) -> str:
    """Get weather report for a city."""
    return f"Weather report for {city}: 20°C, sunny"


def test_tool_decorator_deserialization():
    """A module-level ``@tool`` survives a ``to_dict`` / ``from_dict`` round trip, including its function."""
    serialized = weather_tool_with_decorator.to_dict()
    deserialized = Tool.from_dict(serialized)

    assert deserialized.name == "weather_tool_with_decorator"
    assert deserialized.description == "Get weather report for a city."
    assert deserialized.parameters == CITY_ONLY_SCHEMA
    # The restored callable must be the underlying function, not merely something serializable.
    assert deserialized.function("Berlin") == "Weather report for Berlin: 20°C, sunny"


def test_tool_decorator_with_annotated_params():
    """``@tool`` honours ``Annotated`` descriptions, ``Literal`` enums and defaults."""

    @tool
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    assert get_weather.name == "get_weather"
    assert get_weather.description == "Get weather report for a city."
    assert get_weather.parameters == {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "The target city", "default": "Berlin"},
            "output_format": {
                "type": "string",
                "enum": ["short", "long"],
                "description": "Output format",
                "default": "short",
            },
        },
    }
    assert callable(get_weather.function)
    assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"


def test_tool_decorator_with_parameters():
    """Explicit ``name``/``description`` passed to ``@tool`` override the function name and docstring."""

    @tool(name="fetch_weather", description="A tool to check the weather.")
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    assert get_weather.name == "fetch_weather"
    assert get_weather.description == "A tool to check the weather."
    # Overriding name/description must not change schema generation or the wrapped callable.
    assert get_weather.parameters == {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "The target city", "default": "Berlin"},
            "output_format": {
                "type": "string",
                "enum": ["short", "long"],
                "description": "Output format",
                "default": "short",
            },
        },
    }
    assert get_weather.function("Berlin", "short") == "Weather report for Berlin (short format): 20°C, sunny"
def test_tool_decorator_with_inputs_and_outputs():
    """State mappings given to ``@tool`` are stored verbatim, and state-fed inputs leave the schema."""

    @tool(inputs_from_state={"output_format": "output_format"}, outputs_to_state={"output": {"source": "output"}})
    def get_weather(
        city: Annotated[str, "The target city"] = "Berlin",
        output_format: Annotated[Literal["short", "long"], "Output format"] = "short",
    ) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city} ({output_format} format): 20°C, sunny"

    expected_inputs = {"output_format": "output_format"}
    expected_outputs = {"output": {"source": "output"}}
    # Only ``city`` is left: ``output_format`` is mapped from state, so it is not part of the schema.
    expected_parameters = {
        "type": "object",
        "properties": {"city": {"type": "string", "description": "The target city", "default": "Berlin"}},
    }

    assert get_weather.name == "get_weather"
    assert get_weather.inputs_from_state == expected_inputs
    assert get_weather.outputs_to_state == expected_outputs
    assert get_weather.parameters == expected_parameters


def test_remove_title_from_schema():
    """Every ``title`` keyword is stripped in place, at the top level and inside each property."""
    schema = {
        "properties": {
            "parameter1": {
                "anyOf": [{"type": "string"}, {"type": "integer"}],
                "default": "default_value",
                "title": "Parameter1",
            },
            "parameter2": {
                "default": [1, 2, 3],
                "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                "title": "Parameter2",
                "type": "array",
            },
            "parameter3": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"},
                    {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
                ],
                "default": 42,
                "title": "Parameter3",
            },
            "parameter4": {
                "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
                "default": {"key": "value"},
                "title": "Parameter4",
            },
        },
        "title": "complex_function",
        "type": "object",
    }
    expected = {
        "properties": {
            "parameter1": {"anyOf": [{"type": "string"}, {"type": "integer"}], "default": "default_value"},
            "parameter2": {
                "default": [1, 2, 3],
                "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
                "type": "array",
            },
            "parameter3": {
                "anyOf": [
                    {"type": "string"},
                    {"type": "integer"},
                    {"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "type": "array"},
                ],
                "default": 42,
            },
            "parameter4": {
                "anyOf": [{"type": "string"}, {"items": {"type": "integer"}, "type": "array"}, {"type": "object"}],
                "default": {"key": "value"},
            },
        },
        "type": "object",
    }

    # The helper mutates its argument; its return value is not inspected.
    _remove_title_from_schema(schema)

    assert schema == expected


def test_remove_title_from_schema_do_not_remove_title_property():
    """A property that happens to be *named* ``title`` survives; only ``title`` keywords are removed."""
    schema = {
        "properties": {
            "parameter1": {"type": "string", "title": "Parameter1"},
            "title": {"type": "string", "title": "Title"},
        },
        "title": "complex_function",
        "type": "object",
    }
    expected = {"properties": {"parameter1": {"type": "string"}, "title": {"type": "string"}}, "type": "object"}

    _remove_title_from_schema(schema)

    assert schema == expected


def test_remove_title_from_schema_handle_no_title_in_top_level():
    """A schema lacking a top-level ``title`` is handled; nested titles are still stripped."""
    schema = {
        "properties": {
            "parameter1": {"type": "string", "title": "Parameter1"},
            "parameter2": {"type": "integer", "title": "Parameter2"},
        },
        "type": "object",
    }
    expected = {
        "properties": {"parameter1": {"type": "string"}, "parameter2": {"type": "integer"}},
        "type": "object",
    }

    _remove_title_from_schema(schema)

    assert schema == expected


def test_from_function_async_raises_error():
    """Coroutine functions are rejected by ``create_tool_from_function``."""

    async def async_get_weather(city: str) -> str:
        """Get weather report for a city."""
        return f"Weather report for {city}: 20°C, sunny"

    with pytest.raises(ValueError, match="Async functions are not supported as tools"):
        create_tool_from_function(function=async_get_weather)


def test_tool_decorator_async_raises_error():
    """Applying ``@tool`` to a coroutine function fails at decoration time."""
    # The decoration itself must raise, so the decorated definition lives inside the context manager.
    with pytest.raises(ValueError, match="Async functions are not supported as tools"):

        @tool
        async def async_get_weather(city: str) -> str:
            """Get weather report for a city."""
            return f"Weather report for {city}: 20°C, sunny"