/ test / tools / test_tool.py
test_tool.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import re
  6  
  7  import pytest
  8  
  9  from haystack.dataclasses import TextContent
 10  from haystack.tools import Tool, _check_duplicate_tool_names
 11  from haystack.tools.errors import ToolInvocationError
 12  from haystack.tools.tool import _deserialize_outputs_to_string, _serialize_outputs_to_string
 13  
 14  
 15  def get_weather_report(city: str) -> str:
 16      return f"Weather report for {city}: 20°C, sunny"
 17  
 18  
 19  def format_string(text: str) -> str:
 20      return f"Formatted: {text}"
 21  
 22  
 23  def outputs_to_result_handler(result):
 24      return [TextContent(text=result["text"])]
 25  
 26  
 27  parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 28  
 29  
 30  async def async_get_weather(city: str) -> str:
 31      return f"Weather report for {city}: 20°C, sunny"
 32  
 33  
 34  class TestTool:
 35      def test_init(self):
 36          tool = Tool(
 37              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
 38          )
 39  
 40          assert tool.name == "weather"
 41          assert tool.description == "Get weather report"
 42          assert tool.parameters == parameters
 43          assert tool.function == get_weather_report
 44          assert tool.inputs_from_state is None
 45          assert tool.outputs_to_state is None
 46  
 47      def test_init_invalid_parameters(self):
 48          params = {"type": "invalid", "properties": {"city": {"type": "string"}}}
 49          with pytest.raises(ValueError):
 50              Tool(name="irrelevant", description="irrelevant", parameters=params, function=get_weather_report)
 51  
 52      def test_init_async_function_raises_error(self):
 53          with pytest.raises(ValueError, match="Async functions are not supported as tools"):
 54              Tool(name="weather", description="Get weather report", parameters=parameters, function=async_get_weather)
 55  
 56      @pytest.mark.parametrize(
 57          "outputs_to_state",
 58          [
 59              pytest.param({"documents": {"source": get_weather_report}}, id="source-not-a-string"),
 60              pytest.param({"documents": {"handler": "some_string", "source": "docs"}}, id="handler-not-callable"),
 61          ],
 62      )
 63      def test_init_invalid_output_structure(self, outputs_to_state):
 64          with pytest.raises(ValueError):
 65              Tool(
 66                  name="irrelevant",
 67                  description="irrelevant",
 68                  parameters={"type": "object", "properties": {"city": {"type": "string"}}},
 69                  function=get_weather_report,
 70                  outputs_to_state=outputs_to_state,
 71              )
 72  
 73      def test_init_invalid_output_structure_config_not_dict(self):
 74          with pytest.raises(TypeError):
 75              Tool(
 76                  name="irrelevant",
 77                  description="irrelevant",
 78                  parameters={"type": "object", "properties": {"city": {"type": "string"}}},
 79                  function=get_weather_report,
 80                  outputs_to_state={"documents": ["some_value"]},
 81              )
 82  
 83      @pytest.mark.parametrize(
 84          "outputs_to_string",
 85          [
 86              pytest.param({"source": get_weather_report}, id="source-not-a-string"),
 87              pytest.param({"handler": "some_string"}, id="handler-not-callable"),
 88              pytest.param({"raw_result": "not-a-bool"}, id="raw_result-not-a-bool"),
 89              pytest.param({"documents": {"source": get_weather_report}}, id="multi-value-source-not-a-string"),
 90              pytest.param({"documents": {"handler": "some_string"}}, id="multi-value-handler-not-callable"),
 91              pytest.param(
 92                  {"documents": {"source": "docs", "raw_result": True}}, id="multi-value-raw_result-not-supported"
 93              ),
 94          ],
 95      )
 96      def test_init_invalid_outputs_to_string_structure(self, outputs_to_string):
 97          with pytest.raises(ValueError):
 98              Tool(
 99                  name="irrelevant",
100                  description="irrelevant",
101                  parameters={"type": "object", "properties": {"city": {"type": "string"}}},
102                  function=get_weather_report,
103                  outputs_to_string=outputs_to_string,
104              )
105  
106      def test_init_invalid_outputs_to_string_structure_config_not_dict(self):
107          with pytest.raises(TypeError):
108              Tool(
109                  name="irrelevant",
110                  description="irrelevant",
111                  parameters={"type": "object", "properties": {"city": {"type": "string"}}},
112                  function=get_weather_report,
113                  outputs_to_string={"documents": ["some_value"]},
114              )
115  
116      def test_tool_spec(self):
117          tool = Tool(
118              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
119          )
120  
121          assert tool.tool_spec == {"name": "weather", "description": "Get weather report", "parameters": parameters}
122  
123      def test_invoke(self):
124          tool = Tool(
125              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
126          )
127  
128          assert tool.invoke(city="Berlin") == "Weather report for Berlin: 20°C, sunny"
129  
130      def test_invoke_fail(self):
131          tool = Tool(
132              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
133          )
134          with pytest.raises(
135              ToolInvocationError,
136              match=re.escape(
137                  "Failed to invoke Tool `weather` with parameters {}. Error: get_weather_report() missing 1 required "
138                  "positional argument: 'city'"
139              ),
140          ):
141              tool.invoke()
142  
143      def test_to_dict(self):
144          tool = Tool(
145              name="weather",
146              description="Get weather report",
147              parameters=parameters,
148              function=get_weather_report,
149              outputs_to_string={"handler": format_string},
150              inputs_from_state={"location": "city"},
151              outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
152          )
153  
154          assert tool.to_dict() == {
155              "type": "haystack.tools.tool.Tool",
156              "data": {
157                  "name": "weather",
158                  "description": "Get weather report",
159                  "parameters": parameters,
160                  "function": "test_tool.get_weather_report",
161                  "outputs_to_string": {"handler": "test_tool.format_string"},
162                  "inputs_from_state": {"location": "city"},
163                  "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
164              },
165          }
166  
167      def test_from_dict(self):
168          tool_dict = {
169              "type": "haystack.tools.tool.Tool",
170              "data": {
171                  "name": "weather",
172                  "description": "Get weather report",
173                  "parameters": parameters,
174                  "function": "test_tool.get_weather_report",
175                  "outputs_to_string": {"handler": "test_tool.format_string"},
176                  "inputs_from_state": {"location": "city"},
177                  "outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
178              },
179          }
180  
181          tool = Tool.from_dict(tool_dict)
182  
183          assert tool.name == "weather"
184          assert tool.description == "Get weather report"
185          assert tool.parameters == parameters
186          assert tool.function == get_weather_report
187          assert tool.outputs_to_string == {"handler": format_string}
188          assert tool.inputs_from_state == {"location": "city"}
189          assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}}
190  
191      def test_serialize_outputs_to_string(self):
192          config = {"handler": format_string, "source": "result", "raw_result": False}
193          serialized = _serialize_outputs_to_string(config)
194          assert serialized == {"handler": "test_tool.format_string", "source": "result", "raw_result": False}
195  
196          config = {"handler": format_string}
197          serialized = _serialize_outputs_to_string(config)
198          assert serialized == {"handler": "test_tool.format_string"}
199  
200          config = {"handler": outputs_to_result_handler, "raw_result": True}
201          serialized = _serialize_outputs_to_string(config)
202          assert serialized == {"handler": "test_tool.outputs_to_result_handler", "raw_result": True}
203  
204          config = {
205              "report": {"source": "report", "handler": format_string},
206              "temp": {"source": "temperature", "handler": format_string},
207          }
208          serialized = _serialize_outputs_to_string(config)
209          assert serialized == {
210              "report": {"source": "report", "handler": "test_tool.format_string"},
211              "temp": {"source": "temperature", "handler": "test_tool.format_string"},
212          }
213  
214      def test_deserialize_outputs_to_string(self):
215          serialized = {"handler": "test_tool.format_string", "source": "result", "raw_result": False}
216          deserialized = _deserialize_outputs_to_string(serialized)
217          assert deserialized == {"handler": format_string, "source": "result", "raw_result": False}
218  
219          serialized = {"handler": "test_tool.format_string"}
220          deserialized = _deserialize_outputs_to_string(serialized)
221          assert deserialized == {"handler": format_string}
222  
223          serialized = {"handler": "test_tool.outputs_to_result_handler", "raw_result": True}
224          deserialized = _deserialize_outputs_to_string(serialized)
225          assert deserialized == {"handler": outputs_to_result_handler, "raw_result": True}
226  
227          serialized = {
228              "report": {"source": "report", "handler": "test_tool.format_string"},
229              "temp": {"source": "temperature", "handler": "test_tool.format_string"},
230          }
231          deserialized = _deserialize_outputs_to_string(serialized)
232          assert deserialized == {
233              "report": {"source": "report", "handler": format_string},
234              "temp": {"source": "temperature", "handler": format_string},
235          }
236  
237      def test_inputs_from_state_validation_with_invalid_parameter(self):
238          """Test that inputs_from_state is validated against the parameters schema"""
239          with pytest.raises(
240              ValueError,
241              match=re.escape(
242                  "inputs_from_state maps 'state_key' to unknown parameter 'nonexistent'. Valid parameters are: {'city'}."
243              ),
244          ):
245              Tool(
246                  name="weather",
247                  description="Get weather report",
248                  parameters=parameters,
249                  function=get_weather_report,
250                  inputs_from_state={"state_key": "nonexistent"},
251              )
252  
253      def test_inputs_from_state_validation_with_non_string_value(self):
254          """Test that inputs_from_state values must be strings"""
255          with pytest.raises(TypeError, match=re.escape("inputs_from_state values must be str, not dict")):
256              Tool(
257                  name="weather",
258                  description="Get weather report",
259                  parameters=parameters,
260                  function=get_weather_report,
261                  inputs_from_state={"state_key": {"source": "city"}},
262              )
263  
264      def test_inputs_from_state_validation_with_valid_parameter(self):
265          """Test that inputs_from_state works with valid parameter names"""
266          tool = Tool(
267              name="weather",
268              description="Get weather report",
269              parameters=parameters,
270              function=get_weather_report,
271              inputs_from_state={"location": "city"},
272          )
273          assert tool.inputs_from_state == {"location": "city"}
274  
275      def test_outputs_to_state_no_validation_when_get_valid_outputs_returns_none(self):
276          """Test that outputs_to_state is not validated when _get_valid_outputs returns None"""
277          # This should not raise an error even though "nonexistent" is not a valid output
278          # because the base Tool class returns None from _get_valid_outputs()
279          tool = Tool(
280              name="weather",
281              description="Get weather report",
282              parameters=parameters,
283              function=get_weather_report,
284              outputs_to_state={"result": {"source": "nonexistent"}},
285          )
286          assert tool.outputs_to_state == {"result": {"source": "nonexistent"}}
287  
288      def test_outputs_to_state_validation_when_subclass_provides_valid_outputs(self):
289          """Test that outputs_to_state is validated when subclass overrides _get_valid_outputs"""
290  
291          class ToolWithOutputs(Tool):
292              def _get_valid_outputs(self):
293                  return {"report", "temperature"}
294  
295          # Valid output should work
296          tool = ToolWithOutputs(
297              name="weather",
298              description="Get weather report",
299              parameters=parameters,
300              function=get_weather_report,
301              outputs_to_state={"result": {"source": "report"}},
302          )
303          assert tool.outputs_to_state == {"result": {"source": "report"}}
304  
305          # Invalid output should raise an error
306          with pytest.raises(
307              ValueError,
308              match=re.escape("outputs_to_state: 'weather' maps state key 'result' to unknown output 'nonexistent'"),
309          ):
310              ToolWithOutputs(
311                  name="weather",
312                  description="Get weather report",
313                  parameters=parameters,
314                  function=get_weather_report,
315                  outputs_to_state={"result": {"source": "nonexistent"}},
316              )
317  
318  
def test_check_duplicate_tool_names():
    """``_check_duplicate_tool_names`` raises ``ValueError`` for repeated names and accepts distinct ones."""

    def make_tool(name, description="Get weather report"):
        # All tools here share the same schema and function; only name/description vary.
        return Tool(name=name, description=description, parameters=parameters, function=get_weather_report)

    # Same name, different description: still counts as a duplicate.
    clashing_tools = [make_tool("weather"), make_tool("weather", description="A different description")]
    with pytest.raises(ValueError):
        _check_duplicate_tool_names(clashing_tools)

    # Distinct names: the check must pass without raising.
    distinct_tools = [make_tool("weather"), make_tool("weather2")]
    _check_duplicate_tool_names(distinct_tools)