# test_tool_invoker.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 import datetime 6 import json 7 import time 8 from typing import Any 9 from unittest.mock import patch 10 11 import pytest 12 13 from haystack import Document, Pipeline 14 from haystack.components.agents.state import State 15 from haystack.components.builders.prompt_builder import PromptBuilder 16 from haystack.components.generators.chat.openai import OpenAIChatGenerator 17 from haystack.components.generators.utils import print_streaming_chunk 18 from haystack.components.tools.tool_invoker import ( 19 ResultConversionError, 20 StringConversionError, 21 ToolInvoker, 22 ToolNotFoundException, 23 ToolOutputMergeError, 24 ) 25 from haystack.dataclasses import ( 26 ChatMessage, 27 ChatRole, 28 ImageContent, 29 StreamingChunk, 30 TextContent, 31 ToolCall, 32 ToolCallResult, 33 ) 34 from haystack.tools import ComponentTool, Tool, Toolset 35 from haystack.tools.errors import ToolInvocationError 36 37 38 def weather_function(location): 39 weather_info = { 40 "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}, 41 "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"}, 42 "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"}, 43 } 44 return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"}) 45 46 47 weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]} 48 49 50 @pytest.fixture 51 def weather_tool(): 52 return Tool( 53 name="weather_tool", 54 description="Provides weather information for a given location.", 55 parameters=weather_parameters, 56 function=weather_function, 57 ) 58 59 60 @pytest.fixture 61 def weather_tool_with_outputs_to_state(): 62 return Tool( 63 name="weather_tool", 64 description="Provides weather information for a given location.", 65 parameters=weather_parameters, 66 
function=weather_function, 67 outputs_to_state={"weather": {"source": "weather"}}, 68 ) 69 70 71 @pytest.fixture 72 def faulty_tool(): 73 def faulty_tool_func(location): 74 raise Exception("This tool always fails.") 75 76 faulty_tool_parameters = { 77 "type": "object", 78 "properties": {"location": {"type": "string"}}, 79 "required": ["location"], 80 } 81 82 return Tool( 83 name="faulty_tool", 84 description="A tool that always fails when invoked.", 85 parameters=faulty_tool_parameters, 86 function=faulty_tool_func, 87 ) 88 89 90 def add_function(num1: int, num2: int): 91 return num1 + num2 92 93 94 @pytest.fixture 95 def tool_set(): 96 return Toolset( 97 tools=[ 98 Tool( 99 name="weather_tool", 100 description="Provides weather information for a given location.", 101 parameters=weather_parameters, 102 function=weather_function, 103 ), 104 Tool( 105 name="addition_tool", 106 description="A tool that adds two numbers.", 107 parameters={ 108 "type": "object", 109 "properties": {"num1": {"type": "integer"}, "num2": {"type": "integer"}}, 110 "required": ["num1", "num2"], 111 }, 112 function=add_function, 113 ), 114 ] 115 ) 116 117 118 @pytest.fixture 119 def invoker(weather_tool): 120 return ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False) 121 122 123 @pytest.fixture 124 def faulty_invoker(faulty_tool): 125 return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False) 126 127 128 class WarmupTrackingTool(Tool): 129 """A tool that tracks whether warm_up was called.""" 130 131 def __init__(self, *args, **kwargs): 132 super().__init__(*args, **kwargs) 133 self.was_warmed_up = False 134 135 def warm_up(self): 136 self.was_warmed_up = True 137 138 139 class WarmupTrackingToolset(Toolset): 140 """A toolset that tracks whether warm_up was called.""" 141 142 def __init__(self, tools): 143 super().__init__(tools) 144 self.was_warmed_up = False 145 146 def warm_up(self): 147 self.was_warmed_up = 
True 148 # Call parent to warm up individual tools 149 super().warm_up() 150 151 152 class TestToolInvokerCore: 153 def test_init(self, weather_tool): 154 invoker = ToolInvoker(tools=[weather_tool]) 155 156 assert invoker.tools == [weather_tool] 157 assert invoker._tools_with_names == {"weather_tool": weather_tool} 158 assert invoker.raise_on_failure 159 assert not invoker.convert_result_to_json_string 160 161 def test_validate_and_prepare_tools(self, weather_tool, faulty_tool): 162 result = ToolInvoker._validate_and_prepare_tools([weather_tool, faulty_tool]) 163 assert result == {"weather_tool": weather_tool, "faulty_tool": faulty_tool} 164 165 toolset = Toolset([weather_tool, faulty_tool]) 166 result = ToolInvoker._validate_and_prepare_tools(toolset) 167 assert result == {"weather_tool": weather_tool, "faulty_tool": faulty_tool} 168 169 def test_validate_and_prepare_tools_validation_failures(self, weather_tool): 170 with pytest.raises(ValueError): 171 ToolInvoker._validate_and_prepare_tools([]) 172 173 with pytest.raises(ValueError): 174 ToolInvoker._validate_and_prepare_tools([weather_tool, weather_tool]) 175 176 def test_inject_state_args_no_tool_inputs(self, invoker): 177 weather_tool = Tool( 178 name="weather_tool", 179 description="Provides weather information for a given location.", 180 parameters=weather_parameters, 181 function=weather_function, 182 ) 183 state = State(schema={"location": {"type": str}}, data={"location": "Berlin"}) 184 args = invoker._inject_state_args(tool=weather_tool, llm_args={}, state=state) 185 assert args == {"location": "Berlin"} 186 187 def test_inject_state_args_no_tool_inputs_component_tool(self, invoker): 188 comp = PromptBuilder(template="Hello, {{name}}!") 189 prompt_tool = ComponentTool( 190 component=comp, name="prompt_tool", description="Creates a personalized greeting prompt." 
191 ) 192 state = State(schema={"name": {"type": str}}, data={"name": "James"}) 193 args = invoker._inject_state_args(tool=prompt_tool, llm_args={}, state=state) 194 assert args == {"name": "James"} 195 196 def test_inject_state_args_with_tool_inputs(self, invoker): 197 weather_tool = Tool( 198 name="weather_tool", 199 description="Provides weather information for a given location.", 200 parameters=weather_parameters, 201 function=weather_function, 202 inputs_from_state={"loc": "location"}, 203 ) 204 state = State(schema={"location": {"type": str}}, data={"loc": "Berlin"}) 205 args = invoker._inject_state_args(tool=weather_tool, llm_args={}, state=state) 206 assert args == {"location": "Berlin"} 207 208 def test_inject_state_args_param_in_state_and_llm(self, invoker): 209 weather_tool = Tool( 210 name="weather_tool", 211 description="Provides weather information for a given location.", 212 parameters=weather_parameters, 213 function=weather_function, 214 ) 215 state = State(schema={"location": {"type": str}}, data={"location": "Berlin"}) 216 args = invoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state) 217 assert args == {"location": "Paris"} 218 219 def test_inject_state_args_injects_state_object_for_state_annotated_param(self, invoker): 220 def function_with_state(city: str, state: State) -> str: 221 return f"Weather in {city}" 222 223 state_tool = Tool( 224 name="state_tool", 225 description="A tool that receives the live State object.", 226 parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, 227 function=function_with_state, 228 ) 229 state = State(schema={"city": {"type": str}}, data={"city": "Berlin"}) 230 args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state) 231 assert args["city"] == "Paris" 232 assert args["state"] is state 233 234 def test_inject_state_args_injects_state_object_for_optional_state_annotated_param(self, invoker): 235 def 
function_with_optional_state(city: str, state: State | None = None) -> str: 236 return f"Weather in {city}" 237 238 state_tool = Tool( 239 name="state_tool", 240 description="A tool that receives an optional State object.", 241 parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, 242 function=function_with_optional_state, 243 ) 244 state = State(schema={}) 245 args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state) 246 assert args["city"] == "Paris" 247 assert args["state"] is state 248 249 250 class TestToolInvokerSerde: 251 def test_to_dict(self, invoker, weather_tool): 252 data = invoker.to_dict() 253 assert data == { 254 "type": "haystack.components.tools.tool_invoker.ToolInvoker", 255 "init_parameters": { 256 "tools": [weather_tool.to_dict()], 257 "raise_on_failure": True, 258 "convert_result_to_json_string": False, 259 "enable_streaming_callback_passthrough": False, 260 "streaming_callback": None, 261 "max_workers": 4, 262 }, 263 } 264 265 def test_to_dict_with_params(self, weather_tool): 266 invoker = ToolInvoker( 267 tools=[weather_tool], 268 raise_on_failure=False, 269 convert_result_to_json_string=True, 270 enable_streaming_callback_passthrough=True, 271 streaming_callback=print_streaming_chunk, 272 ) 273 274 assert invoker.to_dict() == { 275 "type": "haystack.components.tools.tool_invoker.ToolInvoker", 276 "init_parameters": { 277 "tools": [weather_tool.to_dict()], 278 "raise_on_failure": False, 279 "convert_result_to_json_string": True, 280 "enable_streaming_callback_passthrough": True, 281 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", 282 "max_workers": 4, 283 }, 284 } 285 286 def test_from_dict(self, weather_tool): 287 data = { 288 "type": "haystack.components.tools.tool_invoker.ToolInvoker", 289 "init_parameters": { 290 "tools": [weather_tool.to_dict()], 291 "raise_on_failure": True, 292 "convert_result_to_json_string": False, 293 
"enable_streaming_callback_passthrough": False, 294 "streaming_callback": None, 295 }, 296 } 297 invoker = ToolInvoker.from_dict(data) 298 assert invoker.tools == [weather_tool] 299 assert invoker._tools_with_names == {"weather_tool": weather_tool} 300 assert invoker.raise_on_failure 301 assert not invoker.convert_result_to_json_string 302 assert invoker.streaming_callback is None 303 assert invoker.enable_streaming_callback_passthrough is False 304 305 def test_from_dict_with_streaming_callback(self, weather_tool): 306 data = { 307 "type": "haystack.components.tools.tool_invoker.ToolInvoker", 308 "init_parameters": { 309 "tools": [weather_tool.to_dict()], 310 "raise_on_failure": True, 311 "convert_result_to_json_string": False, 312 "enable_streaming_callback_passthrough": True, 313 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", 314 }, 315 } 316 invoker = ToolInvoker.from_dict(data) 317 assert invoker.tools == [weather_tool] 318 assert invoker._tools_with_names == {"weather_tool": weather_tool} 319 assert invoker.raise_on_failure 320 assert not invoker.convert_result_to_json_string 321 assert invoker.streaming_callback == print_streaming_chunk 322 assert invoker.enable_streaming_callback_passthrough is True 323 324 def test_serde_in_pipeline(self, invoker, monkeypatch): 325 monkeypatch.setenv("OPENAI_API_KEY", "test-key") 326 327 pipeline = Pipeline() 328 pipeline.add_component("invoker", invoker) 329 pipeline.add_component("chatgenerator", OpenAIChatGenerator()) 330 pipeline.connect("invoker", "chatgenerator") 331 332 pipeline_dict = pipeline.to_dict() 333 # Verify the chatgenerator component structure (model will be whatever the default is) 334 chatgen = pipeline_dict["components"]["chatgenerator"] 335 assert chatgen["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator" 336 assert "model" in chatgen["init_parameters"] 337 model_name = chatgen["init_parameters"]["model"] 338 339 # Build expected dict with 
dynamic model value 340 expected = { 341 "metadata": {}, 342 "connection_type_validation": True, 343 "max_runs_per_component": 100, 344 "components": { 345 "invoker": { 346 "type": "haystack.components.tools.tool_invoker.ToolInvoker", 347 "init_parameters": { 348 "tools": [ 349 { 350 "type": "haystack.tools.tool.Tool", 351 "data": { 352 "name": "weather_tool", 353 "description": "Provides weather information for a given location.", 354 "parameters": { 355 "type": "object", 356 "properties": {"location": {"type": "string"}}, 357 "required": ["location"], 358 }, 359 "function": "tools.test_tool_invoker.weather_function", 360 "outputs_to_string": None, 361 "inputs_from_state": None, 362 "outputs_to_state": None, 363 }, 364 } 365 ], 366 "raise_on_failure": True, 367 "convert_result_to_json_string": False, 368 "enable_streaming_callback_passthrough": False, 369 "streaming_callback": None, 370 "max_workers": 4, 371 }, 372 }, 373 "chatgenerator": { 374 "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator", 375 "init_parameters": { 376 "model": model_name, 377 "streaming_callback": None, 378 "api_base_url": None, 379 "organization": None, 380 "generation_kwargs": {}, 381 "max_retries": None, 382 "timeout": None, 383 "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True}, 384 "tools": None, 385 "tools_strict": False, 386 "http_client_kwargs": None, 387 }, 388 }, 389 }, 390 "connections": [{"sender": "invoker.tool_messages", "receiver": "chatgenerator.messages"}], 391 } 392 assert pipeline_dict == expected 393 394 pipeline_yaml = pipeline.dumps() 395 396 new_pipeline = Pipeline.loads(pipeline_yaml) 397 assert new_pipeline == pipeline 398 399 400 class TestToolInvokerRun: 401 def test_run_with_streaming_callback_finish_reason(self, invoker): 402 streaming_chunks = [] 403 404 def streaming_callback(chunk: StreamingChunk) -> None: 405 streaming_chunks.append(chunk) 406 407 tool_call = ToolCall(tool_name="weather_tool", 
arguments={"location": "Berlin"}) 408 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 409 410 result = invoker.run(messages=[message], streaming_callback=streaming_callback) 411 assert "tool_messages" in result 412 assert len(result["tool_messages"]) == 1 413 414 # Check that we received streaming chunks 415 assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason 416 417 # The last chunk should have finish_reason set to "tool_call_results" 418 final_chunk = streaming_chunks[-1] 419 assert final_chunk.finish_reason == "tool_call_results" 420 assert final_chunk.meta["finish_reason"] == "tool_call_results" 421 assert final_chunk.content == "" 422 423 @pytest.mark.asyncio 424 async def test_run_async_with_streaming_callback_finish_reason(self, weather_tool): 425 streaming_chunks = [] 426 427 async def streaming_callback(chunk: StreamingChunk) -> None: 428 streaming_chunks.append(chunk) 429 430 tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False) 431 432 tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) 433 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 434 435 result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback) 436 assert "tool_messages" in result 437 assert len(result["tool_messages"]) == 1 438 439 # Check that we received streaming chunks 440 assert len(streaming_chunks) >= 2 # At least one for tool result and one for finish reason 441 442 # The last chunk should have finish_reason set to "tool_call_results" 443 final_chunk = streaming_chunks[-1] 444 assert final_chunk.finish_reason == "tool_call_results" 445 assert final_chunk.meta["finish_reason"] == "tool_call_results" 446 assert final_chunk.content == "" 447 448 def test_enable_streaming_callback_passthrough(self, monkeypatch): 449 monkeypatch.setenv("OPENAI_API_KEY", "test-key") 450 llm_tool = ComponentTool( 451 
component=OpenAIChatGenerator(), 452 name="chat_generator_tool", 453 description="A tool that generates chat messages using OpenAI's GPT model.", 454 ) 455 invoker = ToolInvoker( 456 tools=[llm_tool], enable_streaming_callback_passthrough=True, streaming_callback=print_streaming_chunk 457 ) 458 with patch("haystack.components.generators.chat.OpenAIChatGenerator.run") as mock_run: 459 mock_run.return_value = {"replies": [ChatMessage.from_assistant("Hello! How can I help you?")]} 460 invoker.run( 461 messages=[ 462 ChatMessage.from_assistant( 463 tool_calls=[ 464 ToolCall( 465 tool_name="chat_generator_tool", 466 arguments={"messages": [{"role": "user", "content": [{"text": "Hello!"}]}]}, 467 id="12345", 468 ) 469 ] 470 ) 471 ] 472 ) 473 mock_run.assert_called_once_with( 474 messages=[ChatMessage.from_user(text="Hello!")], streaming_callback=print_streaming_chunk 475 ) 476 477 def test_enable_streaming_callback_passthrough_runtime(self, monkeypatch): 478 monkeypatch.setenv("OPENAI_API_KEY", "test-key") 479 llm_tool = ComponentTool( 480 component=OpenAIChatGenerator(), 481 name="chat_generator_tool", 482 description="A tool that generates chat messages using OpenAI's GPT model.", 483 ) 484 invoker = ToolInvoker( 485 tools=[llm_tool], enable_streaming_callback_passthrough=True, streaming_callback=print_streaming_chunk 486 ) 487 with patch("haystack.components.generators.chat.OpenAIChatGenerator.run") as mock_run: 488 mock_run.return_value = {"replies": [ChatMessage.from_assistant("Hello! 
How can I help you?")]} 489 invoker.run( 490 messages=[ 491 ChatMessage.from_assistant( 492 tool_calls=[ 493 ToolCall( 494 tool_name="chat_generator_tool", 495 arguments={"messages": [{"role": "user", "content": [{"text": "Hello!"}]}]}, 496 id="12345", 497 ) 498 ] 499 ) 500 ], 501 enable_streaming_callback_passthrough=False, 502 ) 503 mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")]) 504 505 def test_run_no_messages(self, invoker): 506 result = invoker.run(messages=[]) 507 assert result["tool_messages"] == [] 508 509 def test_run_no_tool_calls(self, invoker): 510 user_message = ChatMessage.from_user(text="Hello!") 511 assistant_message = ChatMessage.from_assistant(text="How can I help you?") 512 513 result = invoker.run(messages=[user_message, assistant_message]) 514 assert result["tool_messages"] == [] 515 516 def test_run_multiple_tool_calls(self, invoker): 517 tool_calls = [ 518 ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), 519 ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}), 520 ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}), 521 ] 522 message = ChatMessage.from_assistant(tool_calls=tool_calls) 523 524 result = invoker.run(messages=[message]) 525 assert "tool_messages" in result 526 assert len(result["tool_messages"]) == 3 527 528 for i, tool_message in enumerate(result["tool_messages"]): 529 assert isinstance(tool_message, ChatMessage) 530 assert tool_message.is_from(ChatRole.TOOL) 531 532 assert tool_message.tool_call_results 533 tool_call_result = tool_message.tool_call_result 534 535 assert isinstance(tool_call_result, ToolCallResult) 536 assert not tool_call_result.error 537 assert tool_call_result.origin == tool_calls[i] 538 539 def test_run_tool_calls_with_empty_args(self): 540 hello_world_tool = Tool( 541 name="hello_world", 542 description="A tool that returns a greeting.", 543 parameters={"type": "object", "properties": {}}, 544 function=lambda: "Hello, 
world!", 545 ) 546 invoker = ToolInvoker(tools=[hello_world_tool]) 547 548 tool_call = ToolCall(tool_name="hello_world", arguments={}) 549 tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call]) 550 551 result = invoker.run(messages=[tool_call_message]) 552 assert "tool_messages" in result 553 assert len(result["tool_messages"]) == 1 554 555 tool_message = result["tool_messages"][0] 556 assert isinstance(tool_message, ChatMessage) 557 assert tool_message.is_from(ChatRole.TOOL) 558 559 assert tool_message.tool_call_results 560 tool_call_result = tool_message.tool_call_result 561 562 assert isinstance(tool_call_result, ToolCallResult) 563 assert not tool_call_result.error 564 565 assert tool_call_result.result == "Hello, world!" 566 567 def test_run_with_tools_override(self, weather_tool, faulty_tool): 568 """Tests that tools passed to run override the tools passed in init""" 569 invoker = ToolInvoker(tools=[faulty_tool]) 570 assert invoker._tools_with_names == {"faulty_tool": faulty_tool} 571 tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) 572 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 573 574 result = invoker.run(messages=[message], tools=[weather_tool]) 575 576 tool_message = result["tool_messages"][0] 577 tool_call_result = tool_message.tool_call_result 578 assert not tool_call_result.error 579 assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}) 580 assert tool_call_result.origin == tool_call 581 582 @pytest.mark.asyncio 583 async def test_run_async_with_tools_override(self, weather_tool, faulty_tool): 584 """Tests that tools passed to run_async override the tools passed in init""" 585 invoker = ToolInvoker(tools=[faulty_tool]) 586 assert invoker._tools_with_names == {"faulty_tool": faulty_tool} 587 tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}) 588 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 589 
590 result = await invoker.run_async(messages=[message], tools=[weather_tool]) 591 tool_message = result["tool_messages"][0] 592 tool_call_result = tool_message.tool_call_result 593 assert not tool_call_result.error 594 assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"}) 595 assert tool_call_result.origin == tool_call 596 597 def test_parallel_tool_calling_with_state_updates(self): 598 """Test that parallel tool execution with state updates works correctly with the state lock.""" 599 # Create a shared counter variable to simulate a state value that gets updated 600 execution_log = [] 601 602 def function_1(): 603 time.sleep(0.01) 604 execution_log.append("tool_1_executed") 605 return {"counter": 1, "tool_name": "tool_1"} 606 607 def function_2(): 608 time.sleep(0.01) 609 execution_log.append("tool_2_executed") 610 return {"counter": 2, "tool_name": "tool_2"} 611 612 def function_3(): 613 time.sleep(0.01) 614 execution_log.append("tool_3_executed") 615 return {"counter": 3, "tool_name": "tool_3"} 616 617 # Create tools that all update the same state key 618 tool_1 = Tool( 619 name="state_tool_1", 620 description="A tool that updates state counter", 621 parameters={"type": "object", "properties": {}}, 622 function=function_1, 623 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 624 ) 625 626 tool_2 = Tool( 627 name="state_tool_2", 628 description="A tool that updates state counter", 629 parameters={"type": "object", "properties": {}}, 630 function=function_2, 631 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 632 ) 633 634 tool_3 = Tool( 635 name="state_tool_3", 636 description="A tool that updates state counter", 637 parameters={"type": "object", "properties": {}}, 638 function=function_3, 639 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 640 ) 641 642 # Create ToolInvoker with all 
three tools 643 invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) 644 645 state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) 646 tool_calls = [ 647 ToolCall(tool_name="state_tool_1", arguments={}), 648 ToolCall(tool_name="state_tool_2", arguments={}), 649 ToolCall(tool_name="state_tool_3", arguments={}), 650 ] 651 message = ChatMessage.from_assistant(tool_calls=tool_calls) 652 _ = invoker.run(messages=[message], state=state) 653 654 # Verify that all three tools were executed 655 assert len(execution_log) == 3 656 assert "tool_1_executed" in execution_log 657 assert "tool_2_executed" in execution_log 658 assert "tool_3_executed" in execution_log 659 660 # Verify that the state was updated correctly 661 # Due to parallel execution, we can't predict which tool will be the last to update 662 assert state.has("counter") 663 assert state.has("last_tool") 664 assert state.get("counter") in [1, 2, 3] # Should be one of the tool values 665 assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names 666 667 @pytest.mark.asyncio 668 async def test_async_parallel_tool_calling_with_state_updates(self): 669 """Test that parallel tool execution with state updates works correctly with the state lock.""" 670 # Create a shared counter variable to simulate a state value that gets updated 671 execution_log = [] 672 673 def function_1(): 674 time.sleep(0.01) 675 execution_log.append("tool_1_executed") 676 return {"counter": 1, "tool_name": "tool_1"} 677 678 def function_2(): 679 time.sleep(0.01) 680 execution_log.append("tool_2_executed") 681 return {"counter": 2, "tool_name": "tool_2"} 682 683 def function_3(): 684 time.sleep(0.01) 685 execution_log.append("tool_3_executed") 686 return {"counter": 3, "tool_name": "tool_3"} 687 688 # Create tools that all update the same state key 689 tool_1 = Tool( 690 name="state_tool_1", 691 description="A tool that updates state counter", 692 
parameters={"type": "object", "properties": {}}, 693 function=function_1, 694 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 695 ) 696 697 tool_2 = Tool( 698 name="state_tool_2", 699 description="A tool that updates state counter", 700 parameters={"type": "object", "properties": {}}, 701 function=function_2, 702 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 703 ) 704 705 tool_3 = Tool( 706 name="state_tool_3", 707 description="A tool that updates state counter", 708 parameters={"type": "object", "properties": {}}, 709 function=function_3, 710 outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}}, 711 ) 712 713 # Create ToolInvoker with all three tools 714 invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True) 715 716 state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}}) 717 tool_calls = [ 718 ToolCall(tool_name="state_tool_1", arguments={}), 719 ToolCall(tool_name="state_tool_2", arguments={}), 720 ToolCall(tool_name="state_tool_3", arguments={}), 721 ] 722 message = ChatMessage.from_assistant(tool_calls=tool_calls) 723 _ = await invoker.run_async(messages=[message], state=state) 724 725 # Verify that all three tools were executed 726 assert len(execution_log) == 3 727 assert "tool_1_executed" in execution_log 728 assert "tool_2_executed" in execution_log 729 assert "tool_3_executed" in execution_log 730 731 # Verify that the state was updated correctly 732 # Due to parallel execution, we can't predict which tool will be the last to update 733 assert state.has("counter") 734 assert state.has("last_tool") 735 assert state.get("counter") in [1, 2, 3] # Should be one of the tool values 736 assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"] # Should be one of the tool names 737 738 def test_call_invoker_two_subsequent_run_calls(self, invoker: ToolInvoker): 739 tool_calls = [ 740 
ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), 741 ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}), 742 ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}), 743 ] 744 message = ChatMessage.from_assistant(tool_calls=tool_calls) 745 746 streaming_callback_called = False 747 748 def streaming_callback(chunk: StreamingChunk) -> None: 749 nonlocal streaming_callback_called 750 streaming_callback_called = True 751 752 # First call 753 result_1 = invoker.run(messages=[message], streaming_callback=streaming_callback) 754 assert "tool_messages" in result_1 755 assert len(result_1["tool_messages"]) == 3 756 757 # Second call 758 result_2 = invoker.run(messages=[message], streaming_callback=streaming_callback) 759 assert "tool_messages" in result_2 760 assert len(result_2["tool_messages"]) == 3 761 762 @pytest.mark.asyncio 763 async def test_call_invoker_two_subsequent_run_async_calls(self, invoker: ToolInvoker): 764 tool_calls = [ 765 ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}), 766 ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}), 767 ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}), 768 ] 769 message = ChatMessage.from_assistant(tool_calls=tool_calls) 770 771 streaming_callback_called = False 772 773 async def streaming_callback(chunk: StreamingChunk) -> None: 774 nonlocal streaming_callback_called 775 streaming_callback_called = True 776 777 # First call 778 result_1 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback) 779 assert "tool_messages" in result_1 780 assert len(result_1["tool_messages"]) == 3 781 782 # Second call 783 result_2 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback) 784 assert "tool_messages" in result_2 785 assert len(result_2["tool_messages"]) == 3 786 787 def test_run_injects_state_object_into_tool(self): 788 received_state = {} 789 790 def 
function_with_state(city: str, state: State) -> str: 791 received_state["state"] = state 792 return f"Weather in {city}: sunny" 793 794 state_tool = Tool( 795 name="state_tool", 796 description="A tool that receives the live State object.", 797 parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, 798 function=function_with_state, 799 ) 800 invoker = ToolInvoker(tools=[state_tool]) 801 state = State(schema={"city": {"type": str}}) 802 803 tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"}) 804 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 805 result = invoker.run(messages=[message], state=state) 806 807 assert len(result["tool_messages"]) == 1 808 assert not result["tool_messages"][0].tool_call_results[0].error 809 assert received_state["state"] is state 810 811 @pytest.mark.asyncio 812 async def test_run_async_injects_state_object_into_tool(self): 813 received_state = {} 814 815 def function_with_state(city: str, state: State) -> str: 816 received_state["state"] = state 817 return f"Weather in {city}: sunny" 818 819 state_tool = Tool( 820 name="state_tool", 821 description="A tool that receives the live State object.", 822 parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}, 823 function=function_with_state, 824 ) 825 invoker = ToolInvoker(tools=[state_tool]) 826 state = State(schema={"city": {"type": str}}) 827 828 tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"}) 829 message = ChatMessage.from_assistant(tool_calls=[tool_call]) 830 result = await invoker.run_async(messages=[message], state=state) 831 832 assert len(result["tool_messages"]) == 1 833 assert not result["tool_messages"][0].tool_call_results[0].error 834 assert received_state["state"] is state 835 836 837 class TestToolInvokerErrorHandling: 838 def test_tool_not_found_error(self, invoker): 839 tool_call = ToolCall(tool_name="non_existent_tool", 
arguments={"location": "Berlin"})  # NOTE(review): continuation of a ToolCall(...) call that begins before this chunk
        tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])

        # With raise_on_failure enabled, an unknown tool name must surface as an exception.
        with pytest.raises(ToolNotFoundException):
            invoker.run(messages=[tool_call_message])

    def test_tool_not_found_does_not_raise_exception(self, weather_tool):
        """With raise_on_failure=False, an unknown tool name yields an error ToolCallResult instead of raising."""
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=False)

        tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"})
        tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])

        result = invoker.run(messages=[tool_call_message])
        tool_message = result["tool_messages"][0]

        # The failure is reported inside the tool message rather than raised.
        assert tool_message.tool_call_results[0].error
        assert "not found" in tool_message.tool_call_results[0].result

    def test_tool_invocation_error(self, faulty_invoker):
        """A tool that fails during invocation raises ToolInvocationError when raise_on_failure is on."""
        tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
        tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])

        with pytest.raises(ToolInvocationError):
            faulty_invoker.run(messages=[tool_call_message])

    def test_tool_invocation_error_does_not_raise_exception(self, faulty_tool):
        """With raise_on_failure=False, a failing tool produces an error result containing 'Failed to invoke'."""
        faulty_invoker = ToolInvoker(tools=[faulty_tool], raise_on_failure=False, convert_result_to_json_string=False)

        tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
        tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])

        result = faulty_invoker.run(messages=[tool_call_message])
        tool_message = result["tool_messages"][0]
        assert tool_message.tool_call_results[0].error
        assert "Failed to invoke" in tool_message.tool_call_results[0].result

    def test_outputs_to_string_with_multiple_outputs(self):
        """outputs_to_string with several source mappings stringifies only the selected keys."""
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            # Pass config that selects two of three outputs
            outputs_to_string={"weather": {"source": "weather"}, "temp": {"source": "temperature"}},
        )
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)

        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        tool_result = {"weather": "sunny", "temperature": 25, "unit": "celsius"}
        chat_message = invoker._prepare_tool_result_message(
            result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
        )
        # Values are stringified ('25'); the unmapped "unit" key is dropped from the output.
        assert chat_message.tool_call_results[0].result == "{'weather': 'sunny', 'temp': '25'}"

    def test_string_conversion_error(self):
        """A string handler that fails (json.dumps on a datetime) raises StringConversionError."""
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            # Pass custom handler that will throw an error when trying to convert tool_result
            outputs_to_string={"handler": json.dumps},
        )
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)

        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        # datetime objects are not JSON-serializable, so json.dumps raises inside the handler.
        tool_result = datetime.datetime.now()
        with pytest.raises(StringConversionError):
            invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool)

    def test_string_conversion_error_does_not_raise_exception(self):
        """With raise_on_failure=False, a failing string handler yields an error result, not an exception."""
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            # Pass custom handler that will throw an error when trying to convert tool_result
            outputs_to_string={"handler": json.dumps},
        )
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False)

        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        tool_result = datetime.datetime.now()
        tool_message = invoker._prepare_tool_result_message(
            result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
        )

        assert tool_message.tool_call_results[0].error
        assert "Failed to convert" in tool_message.tool_call_results[0].result

    def test_result_conversion_error(self):
        """A raw-result handler that raises leads to ResultConversionError when raise_on_failure is on."""
        def handler(result):
            raise ValueError("Handler error")

        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_string={"handler": handler, "raw_result": True},
        )
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)

        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        tool_result = "something"
        with pytest.raises(ResultConversionError):
            invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool)

    def test_result_conversion_error_does_not_raise_exception(self):
        """With raise_on_failure=False, a failing raw-result handler yields an error result instead of raising."""
        def handler(result):
            raise ValueError("Handler error")

        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_string={"handler": handler, "raw_result": True},
        )
        invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False)

        tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        tool_result = "something"
        tool_message = invoker._prepare_tool_result_message(
            result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
        )
        assert tool_message.tool_call_results[0].error
        assert "Failed to convert" in tool_message.tool_call_results[0].result

    def test_run_state_merge_error_handled_gracefully(self, weather_tool_with_outputs_to_state):
        """A State.set failure during output merging is reported as an error result when raise_on_failure=False."""
        class ProblematicState(State):
            def set(self, key: str, value: Any, handler_override=None):
                # Simulate a State error during merging
                raise ValueError("State set operation failed")

        state = ProblematicState(schema={"test_key": {"type": str}})
        invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=False)

        tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        message = ChatMessage.from_assistant(tool_calls=tool_calls)

        result = invoker.run(messages=[message], state=state)

        assert "tool_messages" in result
        assert len(result["tool_messages"]) == 1
        assert result["tool_messages"][0].tool_call_results[0].error is True
        assert (
            "Failed to merge tool outputs from tool weather_tool into State"
            in result["tool_messages"][0].tool_call_results[0].result
        )

    def test_run_state_merge_error_raises_when_configured(self, weather_tool_with_outputs_to_state):
        """A State.set failure during output merging raises ToolOutputMergeError when raise_on_failure=True."""
        class ProblematicState(State):
            def set(self, key: str, value: Any, handler_override=None):
                # Simulate a State error during merging
                raise ValueError("State set operation failed")

        state = ProblematicState(schema={"test_key": {"type": str}})
        invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=True)

        tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        message = ChatMessage.from_assistant(tool_calls=tool_calls)

        with pytest.raises(ToolOutputMergeError, match="Failed to merge"):
            invoker.run(messages=[message], state=state)

    @pytest.mark.asyncio
    async def test_run_async_state_merge_error_handled_gracefully(self, weather_tool_with_outputs_to_state):
        """Async variant: a State.set failure during merging is reported as an error result, not raised."""
        class ProblematicState(State):
            def set(self, key: str, value: Any, handler_override=None):
                # Simulate a State error during merging
                raise ValueError("State set operation failed")

        state = ProblematicState(schema={"test_key": {"type": str}})
        invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=False)

        tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        message = ChatMessage.from_assistant(tool_calls=tool_calls)

        result = await invoker.run_async(messages=[message], state=state)

        assert "tool_messages" in result
        assert len(result["tool_messages"]) == 1
        assert result["tool_messages"][0].tool_call_results[0].error is True
        assert (
            "Failed to merge tool outputs from tool weather_tool into State"
            in result["tool_messages"][0].tool_call_results[0].result
        )

    @pytest.mark.asyncio
    async def test_run_async_state_merge_error_raises_when_configured(self, weather_tool_with_outputs_to_state):
        """Async variant: a State.set failure during merging raises ToolOutputMergeError when configured."""
        class ProblematicState(State):
            def set(self, key: str, value: Any, handler_override=None):
                # Simulate a State error during merging
                raise ValueError("State set operation failed")

        state = ProblematicState(schema={"test_key": {"type": str}})
        invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=True)

        tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        message = ChatMessage.from_assistant(tool_calls=tool_calls)

        with pytest.raises(ToolOutputMergeError, match="Failed to merge"):
            await invoker.run_async(messages=[message], state=state)


class TestToolInvokerUtilities:
    """Tests for ToolInvoker's internal helpers: string conversion, state merging, and output processing."""

    def test_default_output_to_string_handler_basic_types(self, weather_tool):
        """Default handler uses str() semantics when convert_result_to_json_string=False."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=False)

        assert invoker._default_output_to_string_handler("hello") == "hello"
        assert invoker._default_output_to_string_handler(42) == "42"
        assert invoker._default_output_to_string_handler(3.14) == "3.14"
        assert invoker._default_output_to_string_handler(True) == "True"
        assert invoker._default_output_to_string_handler(None) == "None"

        assert invoker._default_output_to_string_handler([1, 2, 3]) == "[1, 2, 3]"
        assert invoker._default_output_to_string_handler({"key": "value"}) == "{'key': 'value'}"

    def test_default_output_to_string_handler_json_string_mode(self, weather_tool):
        """Default handler emits JSON text (quoted strings, true/null) when convert_result_to_json_string=True."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=True)

        assert invoker._default_output_to_string_handler("hello") == '"hello"'
        assert invoker._default_output_to_string_handler(42) == "42"
        assert invoker._default_output_to_string_handler(True) == "true"
        assert invoker._default_output_to_string_handler(None) == "null"

        assert invoker._default_output_to_string_handler([1, 2, 3]) == "[1, 2, 3]"
        assert invoker._default_output_to_string_handler({"key": "value"}) == '{"key": "value"}'

        # Non-ASCII characters survive the conversion unescaped.
        assert invoker._default_output_to_string_handler("Hello 🌍") == '"Hello 🌍"'

    def test_default_output_to_string_handler_with_serializable_objects(self, weather_tool):
        """Objects exposing to_dict() are converted via their dict representation."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=False)

        # Create a mock object with to_dict method
        class MockObject:
            def __init__(self, value):
                self.value = value

            def to_dict(self):
                return {"value": self.value}

        mock_obj = MockObject("test_value")
        result = invoker._default_output_to_string_handler(mock_obj)

        # Should convert to string representation of the dict
        assert "test_value" in result
        assert "value" in result

    def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
        """A non-dict tool result merges nothing into state."""
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"weather": {"type": str}})
        invoker._merge_tool_outputs(tool=weather_tool, result="test", state=state)
        assert state.data == {}

    def test_merge_tool_outputs_empty_dict(self, weather_tool):
        """An empty dict result merges nothing into state."""
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"weather": {"type": str}})
        invoker._merge_tool_outputs(tool=weather_tool, result={}, state=state)
        assert state.data == {}

    def test_merge_tool_outputs_no_output_mapping(self, weather_tool):
        """Without an outputs_to_state mapping on the tool, nothing is merged into state."""
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"weather": {"type": str}})
        invoker._merge_tool_outputs(
            tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
        )
        assert state.data == {}

    def test_merge_tool_outputs_with_output_mapping(self):
        """An outputs_to_state mapping with a source copies only that key into state."""
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_state={"weather": {"source": "weather"}},
        )
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"weather": {"type": str}})
        invoker._merge_tool_outputs(
            tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
        )
        assert state.data == {"weather": "sunny"}

    def test_merge_tool_outputs_with_output_mapping_2(self):
        """An outputs_to_state entry with an empty config stores the whole result under that state key."""
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_state={"all_weather_results": {}},
        )
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"all_weather_results": {"type": str}})
        invoker._merge_tool_outputs(
            tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
        )
        assert state.data == {"all_weather_results": {"weather": "sunny", "temperature": 14, "unit": "celsius"}}

    def test_merge_tool_outputs_source_key_absent_does_not_corrupt_list_state(self):
        """
        Simulates a PipelineTool wrapping a pipeline with a conditional branch that may not execute, resulting in the
        source key being absent from the tool result. The test verifies that in this case, the existing list in state
        is not corrupted by appending None.
        """
        tool = Tool(
            name="retrieval",
            description="mock",
            parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
            function=lambda query: {},
            outputs_to_state={"documents": {"source": "documents_output"}},
        )
        invoker = ToolInvoker(tools=[tool])
        existing_doc = Document(content="from first call")
        state = State(schema={"documents": {"type": list[Document]}})
        state.set("documents", [existing_doc])

        # Tool result where the source key is absent (document extraction branch did not execute)
        invoker._merge_tool_outputs(tool=tool, result={"result": "no web results found"}, state=state)

        assert state.data["documents"] == [existing_doc]
        assert None not in state.data["documents"]

    def test_merge_tool_outputs_with_output_mapping_and_handler(self):
        """A custom handler on the mapping transforms the sourced value before it lands in state."""
        handler = lambda _, new: f"{new}"  # noqa: E731
        weather_tool = Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_state={"temperature": {"source": "temperature", "handler": handler}},
        )
        invoker = ToolInvoker(tools=[weather_tool])
        state = State(schema={"temperature": {"type": str}})
        invoker._merge_tool_outputs(
            tool=weather_tool, result={"weather": "sunny", "temperature": 14, "unit": "celsius"}, state=state
        )
        # The handler stringifies the int 14.
        assert state.data == {"temperature": "14"}

    def test_process_output_empty_config(self, invoker, base64_image_string):
        """With no source/handler configured (only raw_result), the result passes through unchanged."""
        image_content = ImageContent(base64_image=base64_image_string, mime_type="image/png")

        result = invoker._process_output(
            config={"raw_result": True},
            result=[image_content],
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )
        assert result == [image_content]

    def test_process_output_source_only(self, invoker, base64_image_string):
        """A source-only config extracts the named key from a dict result."""
        image_content = ImageContent(base64_image=base64_image_string, mime_type="image/png")

        result = invoker._process_output(
            config={"source": "images", "raw_result": True},
            result={"images": [image_content]},
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )
        assert result == [image_content]

    def test_process_output_handler_only(self, invoker, base64_image_string):
        """A handler-only config applies the handler to the whole result."""
        def handler(result: dict) -> list[ImageContent]:
            return [ImageContent(base64_image=result["base64_image_string"], mime_type=result["mime_type"])]

        result = invoker._process_output(
            config={"handler": handler, "raw_result": True},
            result={"base64_image_string": base64_image_string, "mime_type": "image/png"},
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )
        assert result == [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

    def test_process_output_source_and_handler(self, invoker, base64_image_string):
        """With both source and handler, the source key is extracted first and the handler applied to it."""
        def handler(result: dict) -> list[ImageContent]:
            return [ImageContent(base64_image=result["base64_image_string"], mime_type=result["mime_type"])]

        result = invoker._process_output(
            config={"source": "images", "handler": handler, "raw_result": True},
            result={
                "images": {"base64_image_string": base64_image_string, "mime_type": "image/png"},
                "other_key": "other_value",
            },
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )
        assert result == [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

    def test_output_to_result_e2e(self, weather_tool):
        """End-to-end: a raw_result string handler can return structured content (TextContent list)."""
        def handler(result):
            return [
                TextContent(text=f"weather: {result['weather']}"),
                TextContent(text=f"temperature: {result['temperature']} {result['unit']}"),
            ]

        weather_tool.outputs_to_string = {"handler": handler, "raw_result": True}

        invoker = ToolInvoker(tools=[weather_tool])

        message = ChatMessage.from_assistant(
            tool_calls=[ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
        )

        tool_messages = invoker.run(messages=[message])["tool_messages"]

        # Expected values come from weather_function's entry for Berlin.
        assert tool_messages[0].tool_call_results[0].result == [
            TextContent(text="weather: mostly sunny"),
            TextContent(text="temperature: 7 celsius"),
        ]


class TestWarmUpTools:
    """Tests for Tool/Toolset warm_up through ToolInvoker"""

    def test_tool_invoker_warm_up_with_single_tool(self):
        """Test that ToolInvoker.warm_up() calls warm_up on a single tool."""
        # NOTE(review): WarmupTrackingTool / WarmupTrackingToolset are defined elsewhere in this file.
        tool = WarmupTrackingTool(
            name="test_tool",
            description="A test tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "test",
        )

        invoker = ToolInvoker(tools=[tool])

        assert not tool.was_warmed_up
        invoker.warm_up()
        assert tool.was_warmed_up

    def test_tool_invoker_warm_up_with_multiple_tools(self):
        """Test that ToolInvoker.warm_up() calls warm_up on multiple tools."""
        tool1 = WarmupTrackingTool(
            name="tool1",
            description="First tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool1",
        )
        tool2 = WarmupTrackingTool(
            name="tool2",
            description="Second tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool2",
        )

        invoker = ToolInvoker(tools=[tool1, tool2])

        assert not tool1.was_warmed_up
        assert not tool2.was_warmed_up

        invoker.warm_up()

        assert tool1.was_warmed_up
        assert tool2.was_warmed_up

    def test_tool_invoker_warm_up_with_toolset(self, weather_tool):
        """Test that ToolInvoker.warm_up() calls warm_up on the toolset."""
        toolset = WarmupTrackingToolset([weather_tool])
        invoker = ToolInvoker(tools=toolset)

        assert not toolset.was_warmed_up
        invoker.warm_up()
        assert toolset.was_warmed_up

    def test_tool_invoker_warm_up_with_mixed_toolsets(self):
        """Test that ToolInvoker.warm_up() works with combined toolsets using concatenation."""
        # Create first toolset with a tracking tool
        tool1 = WarmupTrackingTool(
            name="tool1",
            description="First tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool1",
        )
        toolset1 = WarmupTrackingToolset([tool1])

        # Create second toolset with another tracking tool
        tool2 = WarmupTrackingTool(
            name="tool2",
            description="Second tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool2",
        )
        toolset2 = WarmupTrackingToolset([tool2])

        # Combine toolsets using the + operator (creates _ToolsetWrapper)
        combined = toolset1 + toolset2

        # Create invoker with the combined toolset
        invoker = ToolInvoker(tools=combined)

        assert not toolset1.was_warmed_up
        assert not toolset2.was_warmed_up

        invoker.warm_up()

        # Both toolsets should be warmed up
        assert toolset1.was_warmed_up
        assert toolset2.was_warmed_up

    def test_tool_invoker_warm_up_with_mixed_list_of_tools_and_toolsets(self):
        """Test that ToolInvoker.warm_up() works with a mixed list of Tools and Toolsets."""
        # Create standalone tracking tools
        tool1 = WarmupTrackingTool(
            name="standalone_tool1",
            description="First standalone tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool1",
        )
        tool2 = WarmupTrackingTool(
            name="standalone_tool2",
            description="Second standalone tool",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool2",
        )

        # Create toolsets with tracking
        tool3 = WarmupTrackingTool(
            name="toolset_tool1",
            description="Tool in toolset 1",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool3",
        )
        toolset1 = WarmupTrackingToolset([tool3])

        tool4 = WarmupTrackingTool(
            name="toolset_tool2",
            description="Tool in toolset 2",
            parameters={"type": "object", "properties": {}},
            function=lambda: "tool4",
        )
        toolset2 = WarmupTrackingToolset([tool4])

        # Create invoker with mixed list: Tool, Toolset, Tool, Toolset
        invoker = ToolInvoker(tools=[tool1, toolset1, tool2, toolset2])

        # Verify nothing is warmed up initially
        assert not tool1.was_warmed_up
        assert not tool2.was_warmed_up
        assert not toolset1.was_warmed_up
        assert not toolset2.was_warmed_up

        # Warm up
        invoker.warm_up()

        # Verify standalone tools are warmed up
        assert tool1.was_warmed_up
        assert tool2.was_warmed_up

        # Verify toolsets themselves are warmed up (not just their internal tools)
        assert toolset1.was_warmed_up
        assert toolset2.was_warmed_up

    def test_tool_invoker_warm_up_is_idempotent(self):
        """Test that ToolInvoker.warm_up() is idempotent and only warms up once."""

        class WarmupCountingTool(Tool):
            """A tool that counts how many times warm_up was called."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.warm_up_count = 0

            def warm_up(self):
                self.warm_up_count += 1

        tool = WarmupCountingTool(
            name="counting_tool",
            description="A tool that counts warm_up calls",
            parameters={"type": "object", "properties": {}},
            function=lambda: "test",
        )

        invoker = ToolInvoker(tools=[tool])

        # Call warm_up multiple times
        invoker.warm_up()
        invoker.warm_up()
        invoker.warm_up()

        # Should only be warmed up once
        assert tool.warm_up_count == 1

    def test_warm_up_refreshes_tools_with_names(self):
        """
        Test that ToolInvoker.warm_up() refreshes _tools_with_names when using a toolset with lazy connection.
        """
        # Create placeholder tool that simulates MCPToolset behavior of lazy connection
        placeholder_tool = Tool(
            name="mcp_not_connected_placeholder_123",
            description="Placeholder tool before connection",
            parameters={"type": "object", "properties": {}},
            function=lambda: "placeholder",
        )

        # Create the actual tool that will replace the placeholder during warmup
        # This simulates what mcp-server-time mcp would provide
        actual_tool = Tool(
            name="get_time",
            description="Get the current time in ISO format",
            parameters={"type": "object", "properties": {}, "required": []},
            function=lambda: "2024-12-01T12:00:00Z",
        )

        # Create a toolset that simulates MCPToolset with eager_connect=False (lazy connection)
        class MockMCPToolset(Toolset):
            """Simulates MCPToolset behavior with eager_connect=False."""

            def __init__(self):
                # Start with placeholder tools (like MCPToolset does when not eagerly connected)
                super().__init__([placeholder_tool])
                self._warmed_up = False

            def warm_up(self):
                """Simulate connecting to MCP server and replacing placeholder tools with actual tools."""
                if not self._warmed_up:
                    # Replace placeholder tools with actual tools (simulating MCP connection)
                    self.tools = [actual_tool]
                    self._warmed_up = True

        mcp_toolset = MockMCPToolset()
        invoker = ToolInvoker(tools=mcp_toolset)

        # Before warmup: _tools_with_names should contain the placeholder tool
        assert "mcp_not_connected_placeholder_123" in invoker._tools_with_names
        assert "get_time" not in invoker._tools_with_names

        # Call warm_up() directly to trigger tool refresh
        invoker.warm_up()

        # After warmup: _tools_with_names should be refreshed with actual tool names
        assert "mcp_not_connected_placeholder_123" not in invoker._tools_with_names
        assert "get_time" in invoker._tools_with_names