test_serde_utils.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""Tests for the tool / toolset (de)serialization helpers exported by ``haystack.tools``."""

import pytest

from haystack.tools import Tool, Toolset, deserialize_tools_or_toolset_inplace, serialize_tools_or_toolset


def get_weather_report(city: str) -> str:
    """Return a fixed weather report string mentioning ``city``."""
    return f"Weather report for {city}: 20°C, sunny"


def calculate(a: int, b: int, operation: str) -> int:
    """Return ``a + b`` for "add", ``a * b`` for "multiply", and 0 for any other operation."""
    if operation == "add":
        return a + b
    if operation == "multiply":
        return a * b
    return 0


def translate_text(text: str, target_language: str) -> str:
    """Return a placeholder string naming ``text`` and ``target_language``; no real translation happens."""
    return f"Translated '{text}' to {target_language}"


def summarize_text(text: str, max_length: int) -> str:
    """Return the first ``max_length`` characters of ``text``."""
    return text[:max_length]


def format_text(text: str, style: str) -> str:
    """Return a placeholder string naming ``style`` and ``text``."""
    return f"Formatted text in {style} style: {text}"


# JSON-schema style parameter descriptions, one per helper function above.
weather_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}

calculator_parameters = {
    "type": "object",
    "properties": {
        "a": {"type": "integer"},
        "b": {"type": "integer"},
        "operation": {"type": "string", "enum": ["add", "multiply"]},
    },
    "required": ["a", "b", "operation"],
}

translator_parameters = {
    "type": "object",
    "properties": {"text": {"type": "string"}, "target_language": {"type": "string"}},
    "required": ["text", "target_language"],
}

summarizer_parameters = {
    "type": "object",
    "properties": {"text": {"type": "string"}, "max_length": {"type": "integer"}},
    "required": ["text", "max_length"],
}

formatter_parameters = {
    "type": "object",
    "properties": {"text": {"type": "string"}, "style": {"type": "string"}},
    "required": ["text", "style"],
}

# Legacy name for backward compatibility with existing tests
parameters = weather_parameters


class TestToolSerdeUtils:
    """Tests for ``serialize_tools_or_toolset`` and ``deserialize_tools_or_toolset_inplace``."""

    def test_serialize_toolset(self):
        """A single Toolset serializes to the same dictionary as its own ``to_dict()``."""
        toolset = Toolset(
            tools=[
                Tool(
                    name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
                )
            ]
        )

        data = serialize_tools_or_toolset(toolset)
        assert data == toolset.to_dict()

    def test_serialize_tool(self):
        """A list with one Tool serializes to a list holding that tool's ``to_dict()``."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )

        data = serialize_tools_or_toolset([tool])
        assert data == [tool.to_dict()]

    def test_deserialize_tools_inplace(self):
        """Serialized tools are replaced in place by Tool instances, under the default or a custom key."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )

        # Default key ("tools").
        data = {"tools": [tool.to_dict()]}
        deserialize_tools_or_toolset_inplace(data)
        assert data["tools"] == [tool]

        # Custom key.
        data = {"mytools": [tool.to_dict()]}
        deserialize_tools_or_toolset_inplace(data, key="mytools")
        assert data["mytools"] == [tool]

        # Key absent: the dictionary is left unchanged.
        data = {"no_tools": 123}
        deserialize_tools_or_toolset_inplace(data)
        assert data == {"no_tools": 123}

    def test_deserialize_tools_inplace_failures(self):
        """Missing or None values are left untouched; malformed values raise TypeError."""
        # Key absent: nothing to do.
        data = {"key": "value"}
        deserialize_tools_or_toolset_inplace(data)
        assert data == {"key": "value"}

        # Explicit None is preserved.
        data = {"tools": None}
        deserialize_tools_or_toolset_inplace(data)
        assert data == {"tools": None}

        # A plain string is rejected.
        data = {"tools": "not a list"}
        with pytest.raises(TypeError):
            deserialize_tools_or_toolset_inplace(data)

        # A list whose items are not dictionaries is rejected.
        data = {"tools": ["not a dictionary"]}
        with pytest.raises(TypeError):
            deserialize_tools_or_toolset_inplace(data)

        # not a subclass of Tool
        data = {"tools": [{"type": "haystack.dataclasses.ChatMessage", "data": {"irrelevant": "irrelevant"}}]}
        with pytest.raises(TypeError):
            deserialize_tools_or_toolset_inplace(data)

    def test_deserialize_toolset_inplace(self):
        """A serialized Toolset is replaced in place by an equal Toolset instance."""
        tool = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )
        toolset = Toolset(tools=[tool])

        data = {"tools": toolset.to_dict()}

        deserialize_tools_or_toolset_inplace(data)

        assert data["tools"] == toolset
        assert isinstance(data["tools"], Toolset)
        assert data["tools"][0] == tool

    def test_deserialize_toolset_inplace_failures(self):
        """Dictionaries that do not describe a Toolset raise TypeError."""
        # Dictionary without the serialized "type" / "data" layout.
        data = {"tools": {"key": "value"}}
        with pytest.raises(TypeError):
            deserialize_tools_or_toolset_inplace(data)

        # Dictionary whose "type" names Tool rather than a Toolset.
        data = {"tools": {"type": "haystack.tools.Tool", "data": "some_data"}}
        with pytest.raises(TypeError):
            deserialize_tools_or_toolset_inplace(data)

    def test_serialize_list_of_toolsets(self):
        """Test serialization of a list of Toolset instances."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )
        # The calculator tool uses its own schema and function, matching the mixed-list tests below.
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )

        toolset1 = Toolset([tool1])
        toolset2 = Toolset([tool2])

        data = serialize_tools_or_toolset([toolset1, toolset2])

        assert isinstance(data, list)
        assert len(data) == 2
        assert data[0] == toolset1.to_dict()
        assert data[1] == toolset2.to_dict()
        assert data[0]["type"] == "haystack.tools.toolset.Toolset"
        assert data[1]["type"] == "haystack.tools.toolset.Toolset"

    def test_deserialize_list_of_toolsets_inplace(self):
        """Test deserialization of a list of Toolset instances."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=parameters, function=get_weather_report
        )
        # The calculator tool uses its own schema and function, matching the mixed-list tests below.
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )

        toolset1 = Toolset([tool1])
        toolset2 = Toolset([tool2])

        data = {"tools": [toolset1.to_dict(), toolset2.to_dict()]}
        deserialize_tools_or_toolset_inplace(data)

        assert isinstance(data["tools"], list)
        assert len(data["tools"]) == 2
        assert all(isinstance(ts, Toolset) for ts in data["tools"])
        assert data["tools"][0][0].name == "weather"
        assert data["tools"][1][0].name == "calculator"

    def test_serialize_mixed_list_tools_and_toolsets(self):
        """Test serialization of a mixed list of Tool and Toolset instances."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
        )
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )

        toolset = Toolset([tool2])

        data = serialize_tools_or_toolset([tool1, toolset])

        assert isinstance(data, list)
        assert len(data) == 2
        assert data[0] == tool1.to_dict()
        assert data[0]["type"] == "haystack.tools.tool.Tool"
        assert data[0]["data"]["parameters"] == weather_parameters
        assert data[1] == toolset.to_dict()
        assert data[1]["type"] == "haystack.tools.toolset.Toolset"
        assert data[1]["data"]["tools"][0]["data"]["parameters"] == calculator_parameters

    def test_serialize_mixed_list_multiple_tools_and_toolsets(self):
        """Test serialization of a mixed list with multiple Tools and a Toolset containing multiple tools."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
        )
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )
        tool3 = Tool(
            name="translator", description="Translate text", parameters=translator_parameters, function=translate_text
        )
        tool4 = Tool(
            name="summarizer", description="Summarize text", parameters=summarizer_parameters, function=summarize_text
        )
        tool5 = Tool(name="formatter", description="Format text", parameters=formatter_parameters, function=format_text)

        toolset = Toolset([tool4, tool5])

        data = serialize_tools_or_toolset([tool1, tool2, toolset, tool3])

        assert isinstance(data, list)
        assert len(data) == 4

        # Verify Tool 1 (weather)
        assert data[0] == tool1.to_dict()
        assert data[0]["type"] == "haystack.tools.tool.Tool"
        assert data[0]["data"]["name"] == "weather"
        assert data[0]["data"]["parameters"] == weather_parameters

        # Verify Tool 2 (calculator)
        assert data[1] == tool2.to_dict()
        assert data[1]["type"] == "haystack.tools.tool.Tool"
        assert data[1]["data"]["name"] == "calculator"
        assert data[1]["data"]["parameters"] == calculator_parameters

        # Verify Toolset (with summarizer and formatter)
        assert data[2] == toolset.to_dict()
        assert data[2]["type"] == "haystack.tools.toolset.Toolset"
        assert len(data[2]["data"]["tools"]) == 2
        assert data[2]["data"]["tools"][0]["data"]["name"] == "summarizer"
        assert data[2]["data"]["tools"][0]["data"]["parameters"] == summarizer_parameters
        assert data[2]["data"]["tools"][1]["data"]["name"] == "formatter"
        assert data[2]["data"]["tools"][1]["data"]["parameters"] == formatter_parameters

        # Verify Tool 3 (translator)
        assert data[3] == tool3.to_dict()
        assert data[3]["type"] == "haystack.tools.tool.Tool"
        assert data[3]["data"]["name"] == "translator"
        assert data[3]["data"]["parameters"] == translator_parameters

    def test_deserialize_mixed_list_tools_and_toolsets_inplace(self):
        """Test deserialization of a mixed list of Tool and Toolset instances."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
        )
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )

        toolset = Toolset([tool2])

        data = {"tools": [tool1.to_dict(), toolset.to_dict()]}
        deserialize_tools_or_toolset_inplace(data)

        assert isinstance(data["tools"], list)
        assert len(data["tools"]) == 2

        # Verify Tool (weather)
        assert isinstance(data["tools"][0], Tool)
        assert data["tools"][0].name == "weather"
        assert data["tools"][0].parameters == weather_parameters
        assert data["tools"][0].function("Paris") == "Weather report for Paris: 20°C, sunny"

        # Verify Toolset with calculator tool
        assert isinstance(data["tools"][1], Toolset)
        assert len(data["tools"][1]) == 1
        assert data["tools"][1][0].name == "calculator"
        assert data["tools"][1][0].parameters == calculator_parameters
        assert data["tools"][1][0].function(10, 5, "add") == 15
        assert data["tools"][1][0].function(10, 5, "multiply") == 50

    def test_deserialize_mixed_list_multiple_tools_and_toolsets_inplace(self):
        """Test deserialization of a mixed list with multiple Tools and a Toolset containing multiple tools."""
        tool1 = Tool(
            name="weather", description="Get weather report", parameters=weather_parameters, function=get_weather_report
        )
        tool2 = Tool(
            name="calculator", description="Calculate numbers", parameters=calculator_parameters, function=calculate
        )
        tool3 = Tool(
            name="translator", description="Translate text", parameters=translator_parameters, function=translate_text
        )
        tool4 = Tool(
            name="summarizer", description="Summarize text", parameters=summarizer_parameters, function=summarize_text
        )
        tool5 = Tool(name="formatter", description="Format text", parameters=formatter_parameters, function=format_text)

        toolset = Toolset([tool4, tool5])

        data = {"tools": [tool1.to_dict(), tool2.to_dict(), toolset.to_dict(), tool3.to_dict()]}
        deserialize_tools_or_toolset_inplace(data)

        assert isinstance(data["tools"], list)
        assert len(data["tools"]) == 4

        # Verify Tool 1 (weather)
        assert isinstance(data["tools"][0], Tool)
        assert data["tools"][0].name == "weather"
        assert data["tools"][0].parameters == weather_parameters
        assert data["tools"][0].function("Berlin") == "Weather report for Berlin: 20°C, sunny"

        # Verify Tool 2 (calculator)
        assert isinstance(data["tools"][1], Tool)
        assert data["tools"][1].name == "calculator"
        assert data["tools"][1].parameters == calculator_parameters
        assert data["tools"][1].function(5, 3, "add") == 8
        assert data["tools"][1].function(5, 3, "multiply") == 15

        # Verify Toolset (with summarizer and formatter)
        assert isinstance(data["tools"][2], Toolset)
        assert len(data["tools"][2]) == 2
        assert data["tools"][2][0].name == "summarizer"
        assert data["tools"][2][0].parameters == summarizer_parameters
        assert data["tools"][2][0].function("Hello World", 5) == "Hello"
        assert data["tools"][2][1].name == "formatter"
        assert data["tools"][2][1].parameters == formatter_parameters
        assert data["tools"][2][1].function("test", "bold") == "Formatted text in bold style: test"

        # Verify Tool 3 (translator)
        assert isinstance(data["tools"][3], Tool)
        assert data["tools"][3].name == "translator"
        assert data["tools"][3].parameters == translator_parameters
        assert data["tools"][3].function("Hello", "Spanish") == "Translated 'Hello' to Spanish"

    def test_serialize_none_returns_none(self):
        """Test that serializing None returns None."""
        data = serialize_tools_or_toolset(None)
        assert data is None

    def test_serialize_empty_list_of_toolsets(self):
        """Test that serializing an empty list of Toolsets returns an empty list."""
        data = serialize_tools_or_toolset([])
        assert data == []