/ test / tools / test_serde_utils.py
test_serde_utils.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import pytest
  6  
  7  from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset
  8  
  9  
 10  def get_weather_report(city: str) -> str:
 11      return f"Weather report for {city}: 20°C, sunny"
 12  
 13  
 14  def calculate(a: int, b: int, operation: str) -> int:
 15      if operation == "add":
 16          return a + b
 17      if operation == "multiply":
 18          return a * b
 19      return 0
 20  
 21  
 22  def translate_text(text: str, target_language: str) -> str:
 23      return f"Translated '{text}' to {target_language}"
 24  
 25  
 26  def summarize_text(text: str, max_length: int) -> str:
 27      return text[:max_length]
 28  
 29  
 30  def format_text(text: str, style: str) -> str:
 31      return f"Formatted text in {style} style: {text}"
 32  
 33  
 34  weather_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
 35  
 36  calculator_parameters = {
 37      "type": "object",
 38      "properties": {
 39          "a": {"type": "integer"},
 40          "b": {"type": "integer"},
 41          "operation": {"type": "string", "enum": ["add", "multiply"]},
 42      },
 43      "required": ["a", "b", "operation"],
 44  }
 45  
 46  translator_parameters = {
 47      "type": "object",
 48      "properties": {"text": {"type": "string"}, "target_language": {"type": "string"}},
 49      "required": ["text", "target_language"],
 50  }
 51  
 52  summarizer_parameters = {
 53      "type": "object",
 54      "properties": {"text": {"type": "string"}, "max_length": {"type": "integer"}},
 55      "required": ["text", "max_length"],
 56  }
 57  
 58  formatter_parameters = {
 59      "type": "object",
 60      "properties": {"text": {"type": "string"}, "style": {"type": "string"}},
 61      "required": ["text", "style"],
 62  }
 63  
 64  # Legacy name for backward compatibility with existing tests
 65  parameters = weather_parameters
 66  
 67  
 68  class TestToolSerdeUtils:
 69      def test_serialize_toolset(self):
 70          toolset = Toolset(
 71              tools=[
 72                  Tool(
 73                      name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
 74                  )
 75              ]
 76          )
 77  
 78          data = serialize_tools_or_toolset(toolset)
 79          assert data == toolset.to_dict()
 80  
 81      def test_serialize_tool(self):
 82          tool = Tool(
 83              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
 84          )
 85  
 86          data = serialize_tools_or_toolset([tool])
 87          assert data == [tool.to_dict()]
 88  
 89      def test_deserialize_tools_inplace(self):
 90          tool = Tool(
 91              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
 92          )
 93  
 94          data = {"tools": [tool.to_dict()]}
 95          deserialize_tools_or_toolset_inplace(data)
 96          assert data["tools"] == [tool]
 97  
 98          data = {"mytools": [tool.to_dict()]}
 99          deserialize_tools_or_toolset_inplace(data, key="mytools")
100          assert data["mytools"] == [tool]
101  
102          data = {"no_tools": 123}
103          deserialize_tools_or_toolset_inplace(data)
104          assert data == {"no_tools": 123}
105  
106      def test_deserialize_tools_inplace_failures(self):
107          data = {"key": "value"}
108          deserialize_tools_or_toolset_inplace(data)
109          assert data == {"key": "value"}
110  
111          data = {"tools": None}
112          deserialize_tools_or_toolset_inplace(data)
113          assert data == {"tools": None}
114  
115          data = {"tools": "not a list"}
116          with pytest.raises(TypeError):
117              deserialize_tools_or_toolset_inplace(data)
118  
119          data = {"tools": ["not a dictionary"]}
120          with pytest.raises(TypeError):
121              deserialize_tools_or_toolset_inplace(data)
122  
123          # not a subclass of Tool
124          data = {"tools": [{"type": "haystack.dataclasses.ChatMessage", "data": {"irrelevant": "irrelevant"}}]}
125          with pytest.raises(TypeError):
126              deserialize_tools_or_toolset_inplace(data)
127  
128      def test_deserialize_toolset_inplace(self):
129          tool = Tool(
130              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
131          )
132          toolset = Toolset(tools=[tool])
133  
134          data = {"tools": toolset.to_dict()}
135  
136          deserialize_tools_or_toolset_inplace(data)
137  
138          assert data["tools"] == toolset
139          assert isinstance(data["tools"], Toolset)
140          assert data["tools"][0] == tool
141  
142      def test_deserialize_toolset_inplace_failures(self):
143          data = {"tools": {"key": "value"}}
144          with pytest.raises(TypeError):
145              deserialize_tools_or_toolset_inplace(data)
146  
147          data = {"tools": {"type": "haystack.tools.Tool", "data": "some_data"}}
148          with pytest.raises(TypeError):
149              deserialize_tools_or_toolset_inplace(data)
150  
151      def test_serialize_list_of_toolsets(self):
152          """Test serialization of a list of Toolset instances."""
153          tool1 = Tool(
154              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
155          )
156          tool2 = Tool(
157              name="calculator", description="Calculate numbers", parameters=parameters, function=get_weather_report
158          )
159  
160          toolset1 = Toolset([tool1])
161          toolset2 = Toolset([tool2])
162  
163          data = serialize_tools_or_toolset([toolset1, toolset2])
164  
165          assert isinstance(data, list)
166          assert len(data) == 2
167          assert data[0] == toolset1.to_dict()
168          assert data[1] == toolset2.to_dict()
169          assert data[0]["type"] == "haystack.tools.toolset.Toolset"
170          assert data[1]["type"] == "haystack.tools.toolset.Toolset"
171  
172      def test_deserialize_list_of_toolsets_inplace(self):
173          """Test deserialization of a list of Toolset instances."""
174          tool1 = Tool(
175              name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
176          )
177          tool2 = Tool(
178              name="calculator", description="Calculate numbers", parameters=parameters, function=get_weather_report
179          )
180  
181          toolset1 = Toolset([tool1])
182          toolset2 = Toolset([tool2])
183  
184          data = {"tools": [toolset1.to_dict(), toolset2.to_dict()]}
185          deserialize_tools_or_toolset_inplace(data)
186  
187          assert isinstance(data["tools"], list)
188          assert len(data["tools"]) == 2
189          assert all(isinstance(ts, Toolset) for ts in data["tools"])
190          assert data["tools"][0][0].name == "weather"
191          assert data["tools"][1][0].name == "calculator"
192  
193      def test_serialize_mixed_list_tools_and_toolsets(self):
194          """Test serialization of a mixed list of Tool and Toolset instances."""
195          tool1 = Tool(
196              name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
197          )
198          tool2 = Tool(
199              name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
200          )
201  
202          toolset = Toolset([tool2])
203  
204          data = serialize_tools_or_toolset([tool1, toolset])
205  
206          assert isinstance(data, list)
207          assert len(data) == 2
208          assert data[0] == tool1.to_dict()
209          assert data[0]["type"] == "haystack.tools.tool.Tool"
210          assert data[0]["data"]["parameters"] == weather_parameters
211          assert data[1] == toolset.to_dict()
212          assert data[1]["type"] == "haystack.tools.toolset.Toolset"
213          assert data[1]["data"]["tools"][0]["data"]["parameters"] == calculator_parameters
214  
215      def test_serialize_mixed_list_multiple_tools_and_toolsets(self):
216          """Test serialization of a mixed list with multiple Tools and a Toolset containing multiple tools."""
217          tool1 = Tool(
218              name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
219          )
220          tool2 = Tool(
221              name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
222          )
223          tool3 = Tool(
224              name="translator", description="Translate text", parameters=translator_parameters, function=translate_text
225          )
226          tool4 = Tool(
227              name="summarizer", description="Summarize text", parameters=summarizer_parameters, function=summarize_text
228          )
229          tool5 = Tool(name="formatter", description="Format text", parameters=formatter_parameters, function=format_text)
230  
231          toolset = Toolset([tool4, tool5])
232  
233          data = serialize_tools_or_toolset([tool1, tool2, toolset, tool3])
234  
235          assert isinstance(data, list)
236          assert len(data) == 4
237  
238          # Verify Tool 1 (weather)
239          assert data[0] == tool1.to_dict()
240          assert data[0]["type"] == "haystack.tools.tool.Tool"
241          assert data[0]["data"]["name"] == "weather"
242          assert data[0]["data"]["parameters"] == weather_parameters
243  
244          # Verify Tool 2 (calculator)
245          assert data[1] == tool2.to_dict()
246          assert data[1]["type"] == "haystack.tools.tool.Tool"
247          assert data[1]["data"]["name"] == "calculator"
248          assert data[1]["data"]["parameters"] == calculator_parameters
249  
250          # Verify Toolset (with summarizer and formatter)
251          assert data[2] == toolset.to_dict()
252          assert data[2]["type"] == "haystack.tools.toolset.Toolset"
253          assert len(data[2]["data"]["tools"]) == 2
254          assert data[2]["data"]["tools"][0]["data"]["name"] == "summarizer"
255          assert data[2]["data"]["tools"][0]["data"]["parameters"] == summarizer_parameters
256          assert data[2]["data"]["tools"][1]["data"]["name"] == "formatter"
257          assert data[2]["data"]["tools"][1]["data"]["parameters"] == formatter_parameters
258  
259          # Verify Tool 3 (translator)
260          assert data[3] == tool3.to_dict()
261          assert data[3]["type"] == "haystack.tools.tool.Tool"
262          assert data[3]["data"]["name"] == "translator"
263          assert data[3]["data"]["parameters"] == translator_parameters
264  
265      def test_deserialize_mixed_list_tools_and_toolsets_inplace(self):
266          """Test deserialization of a mixed list of Tool and Toolset instances."""
267          tool1 = Tool(
268              name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
269          )
270          tool2 = Tool(
271              name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
272          )
273  
274          toolset = Toolset([tool2])
275  
276          data = {"tools": [tool1.to_dict(), toolset.to_dict()]}
277          deserialize_tools_or_toolset_inplace(data)
278  
279          assert isinstance(data["tools"], list)
280          assert len(data["tools"]) == 2
281  
282          # Verify Tool (weather)
283          assert isinstance(data["tools"][0], Tool)
284          assert data["tools"][0].name == "weather"
285          assert data["tools"][0].parameters == weather_parameters
286          assert data["tools"][0].function("Paris") == "Weather report for Paris: 20°C, sunny"
287  
288          # Verify Toolset with calculator tool
289          assert isinstance(data["tools"][1], Toolset)
290          assert len(data["tools"][1]) == 1
291          assert data["tools"][1][0].name == "calculator"
292          assert data["tools"][1][0].parameters == calculator_parameters
293          assert data["tools"][1][0].function(10, 5, "add") == 15
294          assert data["tools"][1][0].function(10, 5, "multiply") == 50
295  
296      def test_deserialize_mixed_list_multiple_tools_and_toolsets_inplace(self):
297          """Test deserialization of a mixed list with multiple Tools and a Toolset containing multiple tools."""
298          tool1 = Tool(
299              name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
300          )
301          tool2 = Tool(
302              name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
303          )
304          tool3 = Tool(
305              name="translator", description="Translate text", parameters=translator_parameters, function=translate_text
306          )
307          tool4 = Tool(
308              name="summarizer", description="Summarize text", parameters=summarizer_parameters, function=summarize_text
309          )
310          tool5 = Tool(name="formatter", description="Format text", parameters=formatter_parameters, function=format_text)
311  
312          toolset = Toolset([tool4, tool5])
313  
314          data = {"tools": [tool1.to_dict(), tool2.to_dict(), toolset.to_dict(), tool3.to_dict()]}
315          deserialize_tools_or_toolset_inplace(data)
316  
317          assert isinstance(data["tools"], list)
318          assert len(data["tools"]) == 4
319  
320          # Verify Tool 1 (weather)
321          assert isinstance(data["tools"][0], Tool)
322          assert data["tools"][0].name == "weather"
323          assert data["tools"][0].parameters == weather_parameters
324          assert data["tools"][0].function("Berlin") == "Weather report for Berlin: 20°C, sunny"
325  
326          # Verify Tool 2 (calculator)
327          assert isinstance(data["tools"][1], Tool)
328          assert data["tools"][1].name == "calculator"
329          assert data["tools"][1].parameters == calculator_parameters
330          assert data["tools"][1].function(5, 3, "add") == 8
331          assert data["tools"][1].function(5, 3, "multiply") == 15
332  
333          # Verify Toolset (with summarizer and formatter)
334          assert isinstance(data["tools"][2], Toolset)
335          assert len(data["tools"][2]) == 2
336          assert data["tools"][2][0].name == "summarizer"
337          assert data["tools"][2][0].parameters == summarizer_parameters
338          assert data["tools"][2][0].function("Hello World", 5) == "Hello"
339          assert data["tools"][2][1].name == "formatter"
340          assert data["tools"][2][1].parameters == formatter_parameters
341          assert data["tools"][2][1].function("test", "bold") == "Formatted text in bold style: test"
342  
343          # Verify Tool 3 (translator)
344          assert isinstance(data["tools"][3], Tool)
345          assert data["tools"][3].name == "translator"
346          assert data["tools"][3].parameters == translator_parameters
347          assert data["tools"][3].function("Hello", "Spanish") == "Translated 'Hello' to Spanish"
348  
349      def test_serialize_none_returns_none(self):
350          """Test that serializing None returns None."""
351          data = serialize_tools_or_toolset(None)
352          assert data is None
353  
354      def test_serialize_empty_list_of_toolsets(self):
355          """Test that serializing an empty list of Toolsets returns an empty list."""
356          data = serialize_tools_or_toolset([])
357          assert data == []