test_tool.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for ``haystack.tools.Tool`` and its serialization helpers."""

import re

import pytest

from haystack.dataclasses import TextContent
from haystack.tools import Tool, _check_duplicate_tool_names
from haystack.tools.errors import ToolInvocationError
from haystack.tools.tool import _deserialize_outputs_to_string, _serialize_outputs_to_string


def get_weather_report(city: str) -> str:
    """Return a canned weather report string for ``city``; used as the sample tool function."""
    return f"Weather report for {city}: 20°C, sunny"


def format_string(text: str) -> str:
    """Return ``text`` prefixed with ``"Formatted: "``; used as a sample output handler."""
    return f"Formatted: {text}"


def outputs_to_result_handler(result):
    """Wrap ``result["text"]`` in a single-element list of ``TextContent``."""
    return [TextContent(text=result["text"])]


# JSON-schema-style parameter spec shared by most tests: one required string property, "city".
parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}


async def async_get_weather(city: str) -> str:
    """Coroutine counterpart of ``get_weather_report``; used to check that async functions are rejected."""
    return f"Weather report for {city}: 20°C, sunny"


class TestTool:
    """Tests for ``Tool`` construction, validation, invocation and (de)serialization."""

    def test_init(self):
        """A Tool built with only the required arguments exposes them and leaves the state mappings as None."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )

        assert tool.name == "weather"
        assert tool.description == "Get weather report"
        assert tool.parameters == parameters
        assert tool.function == get_weather_report
        assert tool.inputs_from_state is None
        assert tool.outputs_to_state is None

    def test_init_invalid_parameters(self):
        """A parameters spec with ``"type": "invalid"`` is expected to raise ValueError."""
        params = {"type": "invalid", "properties": {"city": {"type": "string"}}}
        with pytest.raises(ValueError):
            Tool(name="irrelevant", description="irrelevant", parameters=params, function=get_weather_report)

    def test_init_async_function_raises_error(self):
        """Passing a coroutine function as ``function`` is expected to raise ValueError."""
        with pytest.raises(ValueError, match="Async functions are not supported as tools"):
            Tool(name="weather", description="Get weather report", parameters=parameters, function=async_get_weather)

    @pytest.mark.parametrize(
        "outputs_to_state",
        [
            pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"),
            pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"),
        ],
    )
    def test_init_invalid_output_structure(self, outputs_to_state):
        """Malformed ``outputs_to_state`` entries (non-string source, non-callable handler) raise ValueError."""
        with pytest.raises(ValueError):
            Tool(
                name="irrelevant",
                description="irrelevant",
                parameters={"type": "object", "properties": {"city": {"type": "string"}}},
                function=get_weather_report,
                outputs_to_state=outputs_to_state,
            )

    def test_init_invalid_output_structure_config_not_dict(self):
        """An ``outputs_to_state`` entry whose config is a list instead of a dict raises TypeError."""
        with pytest.raises(TypeError):
            Tool(
                name="irrelevant",
                description="irrelevant",
                parameters={"type": "object", "properties": {"city": {"type": "string"}}},
                function=get_weather_report,
                outputs_to_state={"documents": ["some_value"]},
            )

    @pytest.mark.parametrize(
        "outputs_to_string",
        [
            pytest.param({"source": get_weather_report}, id="source-not-a-string"),
            pytest.param({"handler": "some_string"}, id="handler-not-callable"),
            pytest.param({"raw_result": "not-a-bool"}, id="raw_result-not-a-bool"),
            pytest.param({"documents": {"source": get_weather_report}}, id="multi-value-source-not-a-string"),
            pytest.param({"documents": {"handler": "some_string"}}, id="multi-value-handler-not-callable"),
            pytest.param(
                {"documents": {"source": "docs", "raw_result": True}}, id="multi-value-raw_result-not-supported"
            ),
        ],
    )
    def test_init_invalid_outputs_to_string_structure(self, outputs_to_string):
        """Malformed ``outputs_to_string`` configs, in single- and multi-value form, raise ValueError."""
        with pytest.raises(ValueError):
            Tool(
                name="irrelevant",
                description="irrelevant",
                parameters={"type": "object", "properties": {"city": {"type": "string"}}},
                function=get_weather_report,
                outputs_to_string=outputs_to_string,
            )

    def test_init_invalid_outputs_to_string_structure_config_not_dict(self):
        """An ``outputs_to_string`` entry whose config is a list instead of a dict raises TypeError."""
        with pytest.raises(TypeError):
            Tool(
                name="irrelevant",
                description="irrelevant",
                parameters={"type": "object", "properties": {"city": {"type": "string"}}},
                function=get_weather_report,
                outputs_to_string={"documents": ["some_value"]},
            )

    def test_tool_spec(self):
        """``tool_spec`` contains exactly the name, description and parameters."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )

        assert tool.tool_spec == {"name": "weather", "description": "Get weather report", "parameters": parameters}

    def test_invoke(self):
        """``invoke`` forwards keyword arguments to the wrapped function and returns its result."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )

        assert tool.invoke(city="Berlin") == "Weather report for Berlin: 20°C, sunny"

    def test_invoke_fail(self):
        """Invoking without the required argument raises ToolInvocationError carrying the underlying message."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )
        # re.escape is needed because the expected message contains regex metacharacters ("{}", "()", ".").
        with pytest.raises(
            ToolInvocationError,
            match=re.escape(
                "Failed to invoke Tool `weather` with parameters {}. Error: get_weather_report() missing 1 required "
                "positional argument: 'city'"
            ),
        ):
            tool.invoke()

    def test_to_dict(self):
        """``to_dict`` emits the type tag and data, with callables replaced by dotted-path strings."""
        tool = Tool(
            name="weather",
            description="Get weather report",
            parameters=parameters,
            function=get_weather_report,
            outputs_to_string={"handler": format_string},
            inputs_from_state={"location": "city"},
            outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
        )

        # NOTE(review): the "test_tool." prefix is presumably this module's import name; confirm if the file moves.
        assert tool.to_dict() == {
            "type": "haystack.tools.tool.Tool",
            "data": {
                "name": "weather",
                "description": "Get weather report",
                "parameters": parameters,
                "function": "test_tool.get_weather_report",
                "outputs_to_string": {"handler": "test_tool.format_string"},
                "inputs_from_state": {"location": "city"},
                "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
            },
        }

    def test_from_dict(self):
        """``from_dict`` rebuilds a Tool and resolves dotted-path strings back to the callables."""
        tool_dict = {
            "type": "haystack.tools.tool.Tool",
            "data": {
                "name": "weather",
                "description": "Get weather report",
                "parameters": parameters,
                "function": "test_tool.get_weather_report",
                "outputs_to_string": {"handler": "test_tool.format_string"},
                "inputs_from_state": {"location": "city"},
                "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
            },
        }

        tool = Tool.from_dict(tool_dict)

        assert tool.name == "weather"
        assert tool.description == "Get weather report"
        assert tool.parameters == parameters
        assert tool.function == get_weather_report
        assert tool.outputs_to_string == {"handler": format_string}
        assert tool.inputs_from_state == {"location": "city"}
        assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}}

    def test_serialize_outputs_to_string(self):
        """``_serialize_outputs_to_string`` replaces handlers with dotted paths and keeps the other keys."""
        # Single-value config with all three keys.
        config = {"handler": format_string, "source": "result", "raw_result": False}
        serialized = _serialize_outputs_to_string(config)
        assert serialized == {"handler": "test_tool.format_string", "source": "result", "raw_result": False}

        # Handler only.
        config = {"handler": format_string}
        serialized = _serialize_outputs_to_string(config)
        assert serialized == {"handler": "test_tool.format_string"}

        # Handler combined with raw_result=True.
        config = {"handler": outputs_to_result_handler, "raw_result": True}
        serialized = _serialize_outputs_to_string(config)
        assert serialized == {"handler": "test_tool.outputs_to_result_handler", "raw_result": True}

        # Multi-value config: each named entry has its own source and handler.
        config = {
            "report": {"source": "report", "handler": format_string},
            "temp": {"source": "temperature", "handler": format_string},
        }
        serialized = _serialize_outputs_to_string(config)
        assert serialized == {
            "report": {"source": "report", "handler": "test_tool.format_string"},
            "temp": {"source": "temperature", "handler": "test_tool.format_string"},
        }

    def test_deserialize_outputs_to_string(self):
        """``_deserialize_outputs_to_string`` resolves dotted-path handlers back to the callables."""
        # Single-value config with all three keys.
        serialized = {"handler": "test_tool.format_string", "source": "result", "raw_result": False}
        deserialized = _deserialize_outputs_to_string(serialized)
        assert deserialized == {"handler": format_string, "source": "result", "raw_result": False}

        # Handler only.
        serialized = {"handler": "test_tool.format_string"}
        deserialized = _deserialize_outputs_to_string(serialized)
        assert deserialized == {"handler": format_string}

        # Handler combined with raw_result=True.
        serialized = {"handler": "test_tool.outputs_to_result_handler", "raw_result": True}
        deserialized = _deserialize_outputs_to_string(serialized)
        assert deserialized == {"handler": outputs_to_result_handler, "raw_result": True}

        # Multi-value config: each named entry has its own source and handler.
        serialized = {
            "report": {"source": "report", "handler": "test_tool.format_string"},
            "temp": {"source": "temperature", "handler": "test_tool.format_string"},
        }
        deserialized = _deserialize_outputs_to_string(serialized)
        assert deserialized == {
            "report": {"source": "report", "handler": format_string},
            "temp": {"source": "temperature", "handler": format_string},
        }

    def test_inputs_from_state_validation_with_invalid_parameter(self):
        """Test that inputs_from_state is validated against the parameters schema"""
        with pytest.raises(
            ValueError,
            match=re.escape(
                "inputs_from_state maps 'state_key' to unknown parameter 'nonexistent'. Valid parameters are: {'city'}."
            ),
        ):
            Tool(
                name="weather",
                description="Get weather report",
                parameters=parameters,
                function=get_weather_report,
                inputs_from_state={"state_key": "nonexistent"},
            )

    def test_inputs_from_state_validation_with_non_string_value(self):
        """Test that inputs_from_state values must be strings"""
        with pytest.raises(TypeError, match=re.escape("inputs_from_state values must be str, not dict")):
            Tool(
                name="weather",
                description="Get weather report",
                parameters=parameters,
                function=get_weather_report,
                inputs_from_state={"state_key": {"source": "city"}},
            )

    def test_inputs_from_state_validation_with_valid_parameter(self):
        """Test that inputs_from_state works with valid parameter names"""
        tool = Tool(
            name="weather",
            description="Get weather report",
            parameters=parameters,
            function=get_weather_report,
            inputs_from_state={"location": "city"},
        )
        assert tool.inputs_from_state == {"location": "city"}

    def test_outputs_to_state_no_validation_when_get_valid_outputs_returns_none(self):
        """Test that outputs_to_state is not validated when _get_valid_outputs returns None"""
        # This should not raise an error even though "nonexistent" is not a valid output
        # because the base Tool class returns None from _get_valid_outputs()
        tool = Tool(
            name="weather",
            description="Get weather report",
            parameters=parameters,
            function=get_weather_report,
            outputs_to_state={"result": {"source": "nonexistent"}},
        )
        assert tool.outputs_to_state == {"result": {"source": "nonexistent"}}

    def test_outputs_to_state_validation_when_subclass_provides_valid_outputs(self):
        """Test that outputs_to_state is validated when subclass overrides _get_valid_outputs"""

        class ToolWithOutputs(Tool):
            def _get_valid_outputs(self):
                return {"report", "temperature"}

        # Valid output should work
        tool = ToolWithOutputs(
            name="weather",
            description="Get weather report",
            parameters=parameters,
            function=get_weather_report,
            outputs_to_state={"result": {"source": "report"}},
        )
        assert tool.outputs_to_state == {"result": {"source": "report"}}

        # Invalid output should raise an error
        with pytest.raises(
            ValueError,
            match=re.escape("outputs_to_state: 'weather' maps state key 'result' to unknown output 'nonexistent'"),
        ):
            ToolWithOutputs(
                name="weather",
                description="Get weather report",
                parameters=parameters,
                function=get_weather_report,
                outputs_to_state={"result": {"source": "nonexistent"}},
            )


def test_check_duplicate_tool_names():
    """``_check_duplicate_tool_names`` raises ValueError on repeated names and accepts distinct ones."""
    # Two tools sharing the name "weather" (descriptions differ) must be rejected.
    tools = [
        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
        Tool(name="weather", description="A different description", parameters=parameters, function=get_weather_report),
    ]
    with pytest.raises(ValueError):
        _check_duplicate_tool_names(tools)

    # Distinct names must pass without raising, even with identical descriptions.
    tools = [
        Tool(name="weather", description="Get weather report", parameters=parameters, function=get_weather_report),
        Tool(name="weather2", description="Get weather report", parameters=parameters, function=get_weather_report),
    ]
    _check_duplicate_tool_names(tools)