/ test / components / tools / test_tool_invoker.py
test_tool_invoker.py
   1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
   2  #
   3  # SPDX-License-Identifier: Apache-2.0
   4  
   5  import datetime
   6  import json
   7  import time
   8  from typing import Any
   9  from unittest.mock import patch
  10  
  11  import pytest
  12  
  13  from haystack import Document, Pipeline
  14  from haystack.components.agents.state import State
  15  from haystack.components.builders.prompt_builder import PromptBuilder
  16  from haystack.components.generators.chat.openai import OpenAIChatGenerator
  17  from haystack.components.generators.utils import print_streaming_chunk
  18  from haystack.components.tools.tool_invoker import (
  19      ResultConversionError,
  20      StringConversionError,
  21      ToolInvoker,
  22      ToolNotFoundException,
  23      ToolOutputMergeError,
  24  )
  25  from haystack.dataclasses import (
  26      ChatMessage,
  27      ChatRole,
  28      ImageContent,
  29      StreamingChunk,
  30      TextContent,
  31      ToolCall,
  32      ToolCallResult,
  33  )
  34  from haystack.tools import ComponentTool, Tool, Toolset
  35  from haystack.tools.errors import ToolInvocationError
  36  
  37  
  38  def weather_function(location):
  39      weather_info = {
  40          "Berlin": {"weather": "mostly sunny", "temperature": 7, "unit": "celsius"},
  41          "Paris": {"weather": "mostly cloudy", "temperature": 8, "unit": "celsius"},
  42          "Rome": {"weather": "sunny", "temperature": 14, "unit": "celsius"},
  43      }
  44      return weather_info.get(location, {"weather": "unknown", "temperature": 0, "unit": "celsius"})
  45  
  46  
  47  weather_parameters = {"type": "object", "properties": {"location": {"type": "string"}}, "required": ["location"]}
  48  
  49  
  50  @pytest.fixture
  51  def weather_tool():
  52      return Tool(
  53          name="weather_tool",
  54          description="Provides weather information for a given location.",
  55          parameters=weather_parameters,
  56          function=weather_function,
  57      )
  58  
  59  
  60  @pytest.fixture
  61  def weather_tool_with_outputs_to_state():
  62      return Tool(
  63          name="weather_tool",
  64          description="Provides weather information for a given location.",
  65          parameters=weather_parameters,
  66          function=weather_function,
  67          outputs_to_state={"weather": {"source": "weather"}},
  68      )
  69  
  70  
  71  @pytest.fixture
  72  def faulty_tool():
  73      def faulty_tool_func(location):
  74          raise Exception("This tool always fails.")
  75  
  76      faulty_tool_parameters = {
  77          "type": "object",
  78          "properties": {"location": {"type": "string"}},
  79          "required": ["location"],
  80      }
  81  
  82      return Tool(
  83          name="faulty_tool",
  84          description="A tool that always fails when invoked.",
  85          parameters=faulty_tool_parameters,
  86          function=faulty_tool_func,
  87      )
  88  
  89  
  90  def add_function(num1: int, num2: int):
  91      return num1 + num2
  92  
  93  
  94  @pytest.fixture
  95  def tool_set():
  96      return Toolset(
  97          tools=[
  98              Tool(
  99                  name="weather_tool",
 100                  description="Provides weather information for a given location.",
 101                  parameters=weather_parameters,
 102                  function=weather_function,
 103              ),
 104              Tool(
 105                  name="addition_tool",
 106                  description="A tool that adds two numbers.",
 107                  parameters={
 108                      "type": "object",
 109                      "properties": {"num1": {"type": "integer"}, "num2": {"type": "integer"}},
 110                      "required": ["num1", "num2"],
 111                  },
 112                  function=add_function,
 113              ),
 114          ]
 115      )
 116  
 117  
 118  @pytest.fixture
 119  def invoker(weather_tool):
 120      return ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
 121  
 122  
 123  @pytest.fixture
 124  def faulty_invoker(faulty_tool):
 125      return ToolInvoker(tools=[faulty_tool], raise_on_failure=True, convert_result_to_json_string=False)
 126  
 127  
 128  class WarmupTrackingTool(Tool):
 129      """A tool that tracks whether warm_up was called."""
 130  
 131      def __init__(self, *args, **kwargs):
 132          super().__init__(*args, **kwargs)
 133          self.was_warmed_up = False
 134  
 135      def warm_up(self):
 136          self.was_warmed_up = True
 137  
 138  
 139  class WarmupTrackingToolset(Toolset):
 140      """A toolset that tracks whether warm_up was called."""
 141  
 142      def __init__(self, tools):
 143          super().__init__(tools)
 144          self.was_warmed_up = False
 145  
 146      def warm_up(self):
 147          self.was_warmed_up = True
 148          # Call parent to warm up individual tools
 149          super().warm_up()
 150  
 151  
 152  class TestToolInvokerCore:
 153      def test_init(self, weather_tool):
 154          invoker = ToolInvoker(tools=[weather_tool])
 155  
 156          assert invoker.tools == [weather_tool]
 157          assert invoker._tools_with_names == {"weather_tool": weather_tool}
 158          assert invoker.raise_on_failure
 159          assert not invoker.convert_result_to_json_string
 160  
 161      def test_validate_and_prepare_tools(self, weather_tool, faulty_tool):
 162          result = ToolInvoker._validate_and_prepare_tools([weather_tool, faulty_tool])
 163          assert result == {"weather_tool": weather_tool, "faulty_tool": faulty_tool}
 164  
 165          toolset = Toolset([weather_tool, faulty_tool])
 166          result = ToolInvoker._validate_and_prepare_tools(toolset)
 167          assert result == {"weather_tool": weather_tool, "faulty_tool": faulty_tool}
 168  
 169      def test_validate_and_prepare_tools_validation_failures(self, weather_tool):
 170          with pytest.raises(ValueError):
 171              ToolInvoker._validate_and_prepare_tools([])
 172  
 173          with pytest.raises(ValueError):
 174              ToolInvoker._validate_and_prepare_tools([weather_tool, weather_tool])
 175  
 176      def test_inject_state_args_no_tool_inputs(self, invoker):
 177          weather_tool = Tool(
 178              name="weather_tool",
 179              description="Provides weather information for a given location.",
 180              parameters=weather_parameters,
 181              function=weather_function,
 182          )
 183          state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
 184          args = invoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
 185          assert args == {"location": "Berlin"}
 186  
 187      def test_inject_state_args_no_tool_inputs_component_tool(self, invoker):
 188          comp = PromptBuilder(template="Hello, {{name}}!")
 189          prompt_tool = ComponentTool(
 190              component=comp, name="prompt_tool", description="Creates a personalized greeting prompt."
 191          )
 192          state = State(schema={"name": {"type": str}}, data={"name": "James"})
 193          args = invoker._inject_state_args(tool=prompt_tool, llm_args={}, state=state)
 194          assert args == {"name": "James"}
 195  
 196      def test_inject_state_args_with_tool_inputs(self, invoker):
 197          weather_tool = Tool(
 198              name="weather_tool",
 199              description="Provides weather information for a given location.",
 200              parameters=weather_parameters,
 201              function=weather_function,
 202              inputs_from_state={"loc": "location"},
 203          )
 204          state = State(schema={"location": {"type": str}}, data={"loc": "Berlin"})
 205          args = invoker._inject_state_args(tool=weather_tool, llm_args={}, state=state)
 206          assert args == {"location": "Berlin"}
 207  
 208      def test_inject_state_args_param_in_state_and_llm(self, invoker):
 209          weather_tool = Tool(
 210              name="weather_tool",
 211              description="Provides weather information for a given location.",
 212              parameters=weather_parameters,
 213              function=weather_function,
 214          )
 215          state = State(schema={"location": {"type": str}}, data={"location": "Berlin"})
 216          args = invoker._inject_state_args(tool=weather_tool, llm_args={"location": "Paris"}, state=state)
 217          assert args == {"location": "Paris"}
 218  
 219      def test_inject_state_args_injects_state_object_for_state_annotated_param(self, invoker):
 220          def function_with_state(city: str, state: State) -> str:
 221              return f"Weather in {city}"
 222  
 223          state_tool = Tool(
 224              name="state_tool",
 225              description="A tool that receives the live State object.",
 226              parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
 227              function=function_with_state,
 228          )
 229          state = State(schema={"city": {"type": str}}, data={"city": "Berlin"})
 230          args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state)
 231          assert args["city"] == "Paris"
 232          assert args["state"] is state
 233  
 234      def test_inject_state_args_injects_state_object_for_optional_state_annotated_param(self, invoker):
 235          def function_with_optional_state(city: str, state: State | None = None) -> str:
 236              return f"Weather in {city}"
 237  
 238          state_tool = Tool(
 239              name="state_tool",
 240              description="A tool that receives an optional State object.",
 241              parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
 242              function=function_with_optional_state,
 243          )
 244          state = State(schema={})
 245          args = invoker._inject_state_args(tool=state_tool, llm_args={"city": "Paris"}, state=state)
 246          assert args["city"] == "Paris"
 247          assert args["state"] is state
 248  
 249  
 250  class TestToolInvokerSerde:
 251      def test_to_dict(self, invoker, weather_tool):
 252          data = invoker.to_dict()
 253          assert data == {
 254              "type": "haystack.components.tools.tool_invoker.ToolInvoker",
 255              "init_parameters": {
 256                  "tools": [weather_tool.to_dict()],
 257                  "raise_on_failure": True,
 258                  "convert_result_to_json_string": False,
 259                  "enable_streaming_callback_passthrough": False,
 260                  "streaming_callback": None,
 261                  "max_workers": 4,
 262              },
 263          }
 264  
 265      def test_to_dict_with_params(self, weather_tool):
 266          invoker = ToolInvoker(
 267              tools=[weather_tool],
 268              raise_on_failure=False,
 269              convert_result_to_json_string=True,
 270              enable_streaming_callback_passthrough=True,
 271              streaming_callback=print_streaming_chunk,
 272          )
 273  
 274          assert invoker.to_dict() == {
 275              "type": "haystack.components.tools.tool_invoker.ToolInvoker",
 276              "init_parameters": {
 277                  "tools": [weather_tool.to_dict()],
 278                  "raise_on_failure": False,
 279                  "convert_result_to_json_string": True,
 280                  "enable_streaming_callback_passthrough": True,
 281                  "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
 282                  "max_workers": 4,
 283              },
 284          }
 285  
 286      def test_from_dict(self, weather_tool):
 287          data = {
 288              "type": "haystack.components.tools.tool_invoker.ToolInvoker",
 289              "init_parameters": {
 290                  "tools": [weather_tool.to_dict()],
 291                  "raise_on_failure": True,
 292                  "convert_result_to_json_string": False,
 293                  "enable_streaming_callback_passthrough": False,
 294                  "streaming_callback": None,
 295              },
 296          }
 297          invoker = ToolInvoker.from_dict(data)
 298          assert invoker.tools == [weather_tool]
 299          assert invoker._tools_with_names == {"weather_tool": weather_tool}
 300          assert invoker.raise_on_failure
 301          assert not invoker.convert_result_to_json_string
 302          assert invoker.streaming_callback is None
 303          assert invoker.enable_streaming_callback_passthrough is False
 304  
 305      def test_from_dict_with_streaming_callback(self, weather_tool):
 306          data = {
 307              "type": "haystack.components.tools.tool_invoker.ToolInvoker",
 308              "init_parameters": {
 309                  "tools": [weather_tool.to_dict()],
 310                  "raise_on_failure": True,
 311                  "convert_result_to_json_string": False,
 312                  "enable_streaming_callback_passthrough": True,
 313                  "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
 314              },
 315          }
 316          invoker = ToolInvoker.from_dict(data)
 317          assert invoker.tools == [weather_tool]
 318          assert invoker._tools_with_names == {"weather_tool": weather_tool}
 319          assert invoker.raise_on_failure
 320          assert not invoker.convert_result_to_json_string
 321          assert invoker.streaming_callback == print_streaming_chunk
 322          assert invoker.enable_streaming_callback_passthrough is True
 323  
 324      def test_serde_in_pipeline(self, invoker, monkeypatch):
 325          monkeypatch.setenv("OPENAI_API_KEY", "test-key")
 326  
 327          pipeline = Pipeline()
 328          pipeline.add_component("invoker", invoker)
 329          pipeline.add_component("chatgenerator", OpenAIChatGenerator())
 330          pipeline.connect("invoker", "chatgenerator")
 331  
 332          pipeline_dict = pipeline.to_dict()
 333          # Verify the chatgenerator component structure (model will be whatever the default is)
 334          chatgen = pipeline_dict["components"]["chatgenerator"]
 335          assert chatgen["type"] == "haystack.components.generators.chat.openai.OpenAIChatGenerator"
 336          assert "model" in chatgen["init_parameters"]
 337          model_name = chatgen["init_parameters"]["model"]
 338  
 339          # Build expected dict with dynamic model value
 340          expected = {
 341              "metadata": {},
 342              "connection_type_validation": True,
 343              "max_runs_per_component": 100,
 344              "components": {
 345                  "invoker": {
 346                      "type": "haystack.components.tools.tool_invoker.ToolInvoker",
 347                      "init_parameters": {
 348                          "tools": [
 349                              {
 350                                  "type": "haystack.tools.tool.Tool",
 351                                  "data": {
 352                                      "name": "weather_tool",
 353                                      "description": "Provides weather information for a given location.",
 354                                      "parameters": {
 355                                          "type": "object",
 356                                          "properties": {"location": {"type": "string"}},
 357                                          "required": ["location"],
 358                                      },
 359                                      "function": "tools.test_tool_invoker.weather_function",
 360                                      "outputs_to_string": None,
 361                                      "inputs_from_state": None,
 362                                      "outputs_to_state": None,
 363                                  },
 364                              }
 365                          ],
 366                          "raise_on_failure": True,
 367                          "convert_result_to_json_string": False,
 368                          "enable_streaming_callback_passthrough": False,
 369                          "streaming_callback": None,
 370                          "max_workers": 4,
 371                      },
 372                  },
 373                  "chatgenerator": {
 374                      "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
 375                      "init_parameters": {
 376                          "model": model_name,
 377                          "streaming_callback": None,
 378                          "api_base_url": None,
 379                          "organization": None,
 380                          "generation_kwargs": {},
 381                          "max_retries": None,
 382                          "timeout": None,
 383                          "api_key": {"type": "env_var", "env_vars": ["OPENAI_API_KEY"], "strict": True},
 384                          "tools": None,
 385                          "tools_strict": False,
 386                          "http_client_kwargs": None,
 387                      },
 388                  },
 389              },
 390              "connections": [{"sender": "invoker.tool_messages", "receiver": "chatgenerator.messages"}],
 391          }
 392          assert pipeline_dict == expected
 393  
 394          pipeline_yaml = pipeline.dumps()
 395  
 396          new_pipeline = Pipeline.loads(pipeline_yaml)
 397          assert new_pipeline == pipeline
 398  
 399  
 400  class TestToolInvokerRun:
 401      def test_run_with_streaming_callback_finish_reason(self, invoker):
 402          streaming_chunks = []
 403  
 404          def streaming_callback(chunk: StreamingChunk) -> None:
 405              streaming_chunks.append(chunk)
 406  
 407          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 408          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 409  
 410          result = invoker.run(messages=[message], streaming_callback=streaming_callback)
 411          assert "tool_messages" in result
 412          assert len(result["tool_messages"]) == 1
 413  
 414          # Check that we received streaming chunks
 415          assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
 416  
 417          # The last chunk should have finish_reason set to "tool_call_results"
 418          final_chunk = streaming_chunks[-1]
 419          assert final_chunk.finish_reason == "tool_call_results"
 420          assert final_chunk.meta["finish_reason"] == "tool_call_results"
 421          assert final_chunk.content == ""
 422  
 423      @pytest.mark.asyncio
 424      async def test_run_async_with_streaming_callback_finish_reason(self, weather_tool):
 425          streaming_chunks = []
 426  
 427          async def streaming_callback(chunk: StreamingChunk) -> None:
 428              streaming_chunks.append(chunk)
 429  
 430          tool_invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True, convert_result_to_json_string=False)
 431  
 432          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 433          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 434  
 435          result = await tool_invoker.run_async(messages=[message], streaming_callback=streaming_callback)
 436          assert "tool_messages" in result
 437          assert len(result["tool_messages"]) == 1
 438  
 439          # Check that we received streaming chunks
 440          assert len(streaming_chunks) >= 2  # At least one for tool result and one for finish reason
 441  
 442          # The last chunk should have finish_reason set to "tool_call_results"
 443          final_chunk = streaming_chunks[-1]
 444          assert final_chunk.finish_reason == "tool_call_results"
 445          assert final_chunk.meta["finish_reason"] == "tool_call_results"
 446          assert final_chunk.content == ""
 447  
 448      def test_enable_streaming_callback_passthrough(self, monkeypatch):
 449          monkeypatch.setenv("OPENAI_API_KEY", "test-key")
 450          llm_tool = ComponentTool(
 451              component=OpenAIChatGenerator(),
 452              name="chat_generator_tool",
 453              description="A tool that generates chat messages using OpenAI's GPT model.",
 454          )
 455          invoker = ToolInvoker(
 456              tools=[llm_tool], enable_streaming_callback_passthrough=True, streaming_callback=print_streaming_chunk
 457          )
 458          with patch("haystack.components.generators.chat.OpenAIChatGenerator.run") as mock_run:
 459              mock_run.return_value = {"replies": [ChatMessage.from_assistant("Hello! How can I help you?")]}
 460              invoker.run(
 461                  messages=[
 462                      ChatMessage.from_assistant(
 463                          tool_calls=[
 464                              ToolCall(
 465                                  tool_name="chat_generator_tool",
 466                                  arguments={"messages": [{"role": "user", "content": [{"text": "Hello!"}]}]},
 467                                  id="12345",
 468                              )
 469                          ]
 470                      )
 471                  ]
 472              )
 473              mock_run.assert_called_once_with(
 474                  messages=[ChatMessage.from_user(text="Hello!")], streaming_callback=print_streaming_chunk
 475              )
 476  
 477      def test_enable_streaming_callback_passthrough_runtime(self, monkeypatch):
 478          monkeypatch.setenv("OPENAI_API_KEY", "test-key")
 479          llm_tool = ComponentTool(
 480              component=OpenAIChatGenerator(),
 481              name="chat_generator_tool",
 482              description="A tool that generates chat messages using OpenAI's GPT model.",
 483          )
 484          invoker = ToolInvoker(
 485              tools=[llm_tool], enable_streaming_callback_passthrough=True, streaming_callback=print_streaming_chunk
 486          )
 487          with patch("haystack.components.generators.chat.OpenAIChatGenerator.run") as mock_run:
 488              mock_run.return_value = {"replies": [ChatMessage.from_assistant("Hello! How can I help you?")]}
 489              invoker.run(
 490                  messages=[
 491                      ChatMessage.from_assistant(
 492                          tool_calls=[
 493                              ToolCall(
 494                                  tool_name="chat_generator_tool",
 495                                  arguments={"messages": [{"role": "user", "content": [{"text": "Hello!"}]}]},
 496                                  id="12345",
 497                              )
 498                          ]
 499                      )
 500                  ],
 501                  enable_streaming_callback_passthrough=False,
 502              )
 503              mock_run.assert_called_once_with(messages=[ChatMessage.from_user(text="Hello!")])
 504  
 505      def test_run_no_messages(self, invoker):
 506          result = invoker.run(messages=[])
 507          assert result["tool_messages"] == []
 508  
 509      def test_run_no_tool_calls(self, invoker):
 510          user_message = ChatMessage.from_user(text="Hello!")
 511          assistant_message = ChatMessage.from_assistant(text="How can I help you?")
 512  
 513          result = invoker.run(messages=[user_message, assistant_message])
 514          assert result["tool_messages"] == []
 515  
 516      def test_run_multiple_tool_calls(self, invoker):
 517          tool_calls = [
 518              ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
 519              ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
 520              ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
 521          ]
 522          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 523  
 524          result = invoker.run(messages=[message])
 525          assert "tool_messages" in result
 526          assert len(result["tool_messages"]) == 3
 527  
 528          for i, tool_message in enumerate(result["tool_messages"]):
 529              assert isinstance(tool_message, ChatMessage)
 530              assert tool_message.is_from(ChatRole.TOOL)
 531  
 532              assert tool_message.tool_call_results
 533              tool_call_result = tool_message.tool_call_result
 534  
 535              assert isinstance(tool_call_result, ToolCallResult)
 536              assert not tool_call_result.error
 537              assert tool_call_result.origin == tool_calls[i]
 538  
 539      def test_run_tool_calls_with_empty_args(self):
 540          hello_world_tool = Tool(
 541              name="hello_world",
 542              description="A tool that returns a greeting.",
 543              parameters={"type": "object", "properties": {}},
 544              function=lambda: "Hello, world!",
 545          )
 546          invoker = ToolInvoker(tools=[hello_world_tool])
 547  
 548          tool_call = ToolCall(tool_name="hello_world", arguments={})
 549          tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
 550  
 551          result = invoker.run(messages=[tool_call_message])
 552          assert "tool_messages" in result
 553          assert len(result["tool_messages"]) == 1
 554  
 555          tool_message = result["tool_messages"][0]
 556          assert isinstance(tool_message, ChatMessage)
 557          assert tool_message.is_from(ChatRole.TOOL)
 558  
 559          assert tool_message.tool_call_results
 560          tool_call_result = tool_message.tool_call_result
 561  
 562          assert isinstance(tool_call_result, ToolCallResult)
 563          assert not tool_call_result.error
 564  
 565          assert tool_call_result.result == "Hello, world!"
 566  
 567      def test_run_with_tools_override(self, weather_tool, faulty_tool):
 568          """Tests that tools passed to run override the tools passed in init"""
 569          invoker = ToolInvoker(tools=[faulty_tool])
 570          assert invoker._tools_with_names == {"faulty_tool": faulty_tool}
 571          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 572          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 573  
 574          result = invoker.run(messages=[message], tools=[weather_tool])
 575  
 576          tool_message = result["tool_messages"][0]
 577          tool_call_result = tool_message.tool_call_result
 578          assert not tool_call_result.error
 579          assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"})
 580          assert tool_call_result.origin == tool_call
 581  
 582      @pytest.mark.asyncio
 583      async def test_run_async_with_tools_override(self, weather_tool, faulty_tool):
 584          """Tests that tools passed to run_async override the tools passed in init"""
 585          invoker = ToolInvoker(tools=[faulty_tool])
 586          assert invoker._tools_with_names == {"faulty_tool": faulty_tool}
 587          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 588          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 589  
 590          result = await invoker.run_async(messages=[message], tools=[weather_tool])
 591          tool_message = result["tool_messages"][0]
 592          tool_call_result = tool_message.tool_call_result
 593          assert not tool_call_result.error
 594          assert tool_call_result.result == str({"weather": "mostly sunny", "temperature": 7, "unit": "celsius"})
 595          assert tool_call_result.origin == tool_call
 596  
 597      def test_parallel_tool_calling_with_state_updates(self):
 598          """Test that parallel tool execution with state updates works correctly with the state lock."""
 599          # Create a shared counter variable to simulate a state value that gets updated
 600          execution_log = []
 601  
 602          def function_1():
 603              time.sleep(0.01)
 604              execution_log.append("tool_1_executed")
 605              return {"counter": 1, "tool_name": "tool_1"}
 606  
 607          def function_2():
 608              time.sleep(0.01)
 609              execution_log.append("tool_2_executed")
 610              return {"counter": 2, "tool_name": "tool_2"}
 611  
 612          def function_3():
 613              time.sleep(0.01)
 614              execution_log.append("tool_3_executed")
 615              return {"counter": 3, "tool_name": "tool_3"}
 616  
 617          # Create tools that all update the same state key
 618          tool_1 = Tool(
 619              name="state_tool_1",
 620              description="A tool that updates state counter",
 621              parameters={"type": "object", "properties": {}},
 622              function=function_1,
 623              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 624          )
 625  
 626          tool_2 = Tool(
 627              name="state_tool_2",
 628              description="A tool that updates state counter",
 629              parameters={"type": "object", "properties": {}},
 630              function=function_2,
 631              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 632          )
 633  
 634          tool_3 = Tool(
 635              name="state_tool_3",
 636              description="A tool that updates state counter",
 637              parameters={"type": "object", "properties": {}},
 638              function=function_3,
 639              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 640          )
 641  
 642          # Create ToolInvoker with all three tools
 643          invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
 644  
 645          state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
 646          tool_calls = [
 647              ToolCall(tool_name="state_tool_1", arguments={}),
 648              ToolCall(tool_name="state_tool_2", arguments={}),
 649              ToolCall(tool_name="state_tool_3", arguments={}),
 650          ]
 651          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 652          _ = invoker.run(messages=[message], state=state)
 653  
 654          # Verify that all three tools were executed
 655          assert len(execution_log) == 3
 656          assert "tool_1_executed" in execution_log
 657          assert "tool_2_executed" in execution_log
 658          assert "tool_3_executed" in execution_log
 659  
 660          # Verify that the state was updated correctly
 661          # Due to parallel execution, we can't predict which tool will be the last to update
 662          assert state.has("counter")
 663          assert state.has("last_tool")
 664          assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
 665          assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
 666  
 667      @pytest.mark.asyncio
 668      async def test_async_parallel_tool_calling_with_state_updates(self):
 669          """Test that parallel tool execution with state updates works correctly with the state lock."""
 670          # Create a shared counter variable to simulate a state value that gets updated
 671          execution_log = []
 672  
 673          def function_1():
 674              time.sleep(0.01)
 675              execution_log.append("tool_1_executed")
 676              return {"counter": 1, "tool_name": "tool_1"}
 677  
 678          def function_2():
 679              time.sleep(0.01)
 680              execution_log.append("tool_2_executed")
 681              return {"counter": 2, "tool_name": "tool_2"}
 682  
 683          def function_3():
 684              time.sleep(0.01)
 685              execution_log.append("tool_3_executed")
 686              return {"counter": 3, "tool_name": "tool_3"}
 687  
 688          # Create tools that all update the same state key
 689          tool_1 = Tool(
 690              name="state_tool_1",
 691              description="A tool that updates state counter",
 692              parameters={"type": "object", "properties": {}},
 693              function=function_1,
 694              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 695          )
 696  
 697          tool_2 = Tool(
 698              name="state_tool_2",
 699              description="A tool that updates state counter",
 700              parameters={"type": "object", "properties": {}},
 701              function=function_2,
 702              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 703          )
 704  
 705          tool_3 = Tool(
 706              name="state_tool_3",
 707              description="A tool that updates state counter",
 708              parameters={"type": "object", "properties": {}},
 709              function=function_3,
 710              outputs_to_state={"counter": {"source": "counter"}, "last_tool": {"source": "tool_name"}},
 711          )
 712  
 713          # Create ToolInvoker with all three tools
 714          invoker = ToolInvoker(tools=[tool_1, tool_2, tool_3], raise_on_failure=True)
 715  
 716          state = State(schema={"counter": {"type": int}, "last_tool": {"type": str}})
 717          tool_calls = [
 718              ToolCall(tool_name="state_tool_1", arguments={}),
 719              ToolCall(tool_name="state_tool_2", arguments={}),
 720              ToolCall(tool_name="state_tool_3", arguments={}),
 721          ]
 722          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 723          _ = await invoker.run_async(messages=[message], state=state)
 724  
 725          # Verify that all three tools were executed
 726          assert len(execution_log) == 3
 727          assert "tool_1_executed" in execution_log
 728          assert "tool_2_executed" in execution_log
 729          assert "tool_3_executed" in execution_log
 730  
 731          # Verify that the state was updated correctly
 732          # Due to parallel execution, we can't predict which tool will be the last to update
 733          assert state.has("counter")
 734          assert state.has("last_tool")
 735          assert state.get("counter") in [1, 2, 3]  # Should be one of the tool values
 736          assert state.get("last_tool") in ["tool_1", "tool_2", "tool_3"]  # Should be one of the tool names
 737  
 738      def test_call_invoker_two_subsequent_run_calls(self, invoker: ToolInvoker):
 739          tool_calls = [
 740              ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
 741              ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
 742              ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
 743          ]
 744          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 745  
 746          streaming_callback_called = False
 747  
 748          def streaming_callback(chunk: StreamingChunk) -> None:
 749              nonlocal streaming_callback_called
 750              streaming_callback_called = True
 751  
 752          # First call
 753          result_1 = invoker.run(messages=[message], streaming_callback=streaming_callback)
 754          assert "tool_messages" in result_1
 755          assert len(result_1["tool_messages"]) == 3
 756  
 757          # Second call
 758          result_2 = invoker.run(messages=[message], streaming_callback=streaming_callback)
 759          assert "tool_messages" in result_2
 760          assert len(result_2["tool_messages"]) == 3
 761  
 762      @pytest.mark.asyncio
 763      async def test_call_invoker_two_subsequent_run_async_calls(self, invoker: ToolInvoker):
 764          tool_calls = [
 765              ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"}),
 766              ToolCall(tool_name="weather_tool", arguments={"location": "Paris"}),
 767              ToolCall(tool_name="weather_tool", arguments={"location": "Rome"}),
 768          ]
 769          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 770  
 771          streaming_callback_called = False
 772  
 773          async def streaming_callback(chunk: StreamingChunk) -> None:
 774              nonlocal streaming_callback_called
 775              streaming_callback_called = True
 776  
 777          # First call
 778          result_1 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
 779          assert "tool_messages" in result_1
 780          assert len(result_1["tool_messages"]) == 3
 781  
 782          # Second call
 783          result_2 = await invoker.run_async(messages=[message], streaming_callback=streaming_callback)
 784          assert "tool_messages" in result_2
 785          assert len(result_2["tool_messages"]) == 3
 786  
 787      def test_run_injects_state_object_into_tool(self):
 788          received_state = {}
 789  
 790          def function_with_state(city: str, state: State) -> str:
 791              received_state["state"] = state
 792              return f"Weather in {city}: sunny"
 793  
 794          state_tool = Tool(
 795              name="state_tool",
 796              description="A tool that receives the live State object.",
 797              parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
 798              function=function_with_state,
 799          )
 800          invoker = ToolInvoker(tools=[state_tool])
 801          state = State(schema={"city": {"type": str}})
 802  
 803          tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"})
 804          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 805          result = invoker.run(messages=[message], state=state)
 806  
 807          assert len(result["tool_messages"]) == 1
 808          assert not result["tool_messages"][0].tool_call_results[0].error
 809          assert received_state["state"] is state
 810  
 811      @pytest.mark.asyncio
 812      async def test_run_async_injects_state_object_into_tool(self):
 813          received_state = {}
 814  
 815          def function_with_state(city: str, state: State) -> str:
 816              received_state["state"] = state
 817              return f"Weather in {city}: sunny"
 818  
 819          state_tool = Tool(
 820              name="state_tool",
 821              description="A tool that receives the live State object.",
 822              parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
 823              function=function_with_state,
 824          )
 825          invoker = ToolInvoker(tools=[state_tool])
 826          state = State(schema={"city": {"type": str}})
 827  
 828          tool_call = ToolCall(tool_name="state_tool", arguments={"city": "Berlin"})
 829          message = ChatMessage.from_assistant(tool_calls=[tool_call])
 830          result = await invoker.run_async(messages=[message], state=state)
 831  
 832          assert len(result["tool_messages"]) == 1
 833          assert not result["tool_messages"][0].tool_call_results[0].error
 834          assert received_state["state"] is state
 835  
 836  
 837  class TestToolInvokerErrorHandling:
 838      def test_tool_not_found_error(self, invoker):
 839          tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"})
 840          tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
 841  
 842          with pytest.raises(ToolNotFoundException):
 843              invoker.run(messages=[tool_call_message])
 844  
 845      def test_tool_not_found_does_not_raise_exception(self, weather_tool):
 846          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False, convert_result_to_json_string=False)
 847  
 848          tool_call = ToolCall(tool_name="non_existent_tool", arguments={"location": "Berlin"})
 849          tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
 850  
 851          result = invoker.run(messages=[tool_call_message])
 852          tool_message = result["tool_messages"][0]
 853  
 854          assert tool_message.tool_call_results[0].error
 855          assert "not found" in tool_message.tool_call_results[0].result
 856  
 857      def test_tool_invocation_error(self, faulty_invoker):
 858          tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
 859          tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
 860  
 861          with pytest.raises(ToolInvocationError):
 862              faulty_invoker.run(messages=[tool_call_message])
 863  
 864      def test_tool_invocation_error_does_not_raise_exception(self, faulty_tool):
 865          faulty_invoker = ToolInvoker(tools=[faulty_tool], raise_on_failure=False, convert_result_to_json_string=False)
 866  
 867          tool_call = ToolCall(tool_name="faulty_tool", arguments={"location": "Berlin"})
 868          tool_call_message = ChatMessage.from_assistant(tool_calls=[tool_call])
 869  
 870          result = faulty_invoker.run(messages=[tool_call_message])
 871          tool_message = result["tool_messages"][0]
 872          assert tool_message.tool_call_results[0].error
 873          assert "Failed to invoke" in tool_message.tool_call_results[0].result
 874  
 875      def test_outputs_to_string_with_multiple_outputs(self):
 876          weather_tool = Tool(
 877              name="weather_tool",
 878              description="Provides weather information for a given location.",
 879              parameters=weather_parameters,
 880              function=weather_function,
 881              # Pass config that selects two of three outputs
 882              outputs_to_string={"weather": {"source": "weather"}, "temp": {"source": "temperature"}},
 883          )
 884          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)
 885  
 886          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 887  
 888          tool_result = {"weather": "sunny", "temperature": 25, "unit": "celsius"}
 889          chat_message = invoker._prepare_tool_result_message(
 890              result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
 891          )
 892          assert chat_message.tool_call_results[0].result == "{'weather': 'sunny', 'temp': '25'}"
 893  
 894      def test_string_conversion_error(self):
 895          weather_tool = Tool(
 896              name="weather_tool",
 897              description="Provides weather information for a given location.",
 898              parameters=weather_parameters,
 899              function=weather_function,
 900              # Pass custom handler that will throw an error when trying to convert tool_result
 901              outputs_to_string={"handler": json.dumps},
 902          )
 903          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)
 904  
 905          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 906  
 907          tool_result = datetime.datetime.now()
 908          with pytest.raises(StringConversionError):
 909              invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool)
 910  
 911      def test_string_conversion_error_does_not_raise_exception(self):
 912          weather_tool = Tool(
 913              name="weather_tool",
 914              description="Provides weather information for a given location.",
 915              parameters=weather_parameters,
 916              function=weather_function,
 917              # Pass custom handler that will throw an error when trying to convert tool_result
 918              outputs_to_string={"handler": json.dumps},
 919          )
 920          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False)
 921  
 922          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 923  
 924          tool_result = datetime.datetime.now()
 925          tool_message = invoker._prepare_tool_result_message(
 926              result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
 927          )
 928  
 929          assert tool_message.tool_call_results[0].error
 930          assert "Failed to convert" in tool_message.tool_call_results[0].result
 931  
 932      def test_result_conversion_error(self):
 933          def handler(result):
 934              raise ValueError("Handler error")
 935  
 936          weather_tool = Tool(
 937              name="weather_tool",
 938              description="Provides weather information for a given location.",
 939              parameters=weather_parameters,
 940              function=weather_function,
 941              outputs_to_string={"handler": handler, "raw_result": True},
 942          )
 943          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=True)
 944  
 945          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 946  
 947          tool_result = "something"
 948          with pytest.raises(ResultConversionError):
 949              invoker._prepare_tool_result_message(result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool)
 950  
 951      def test_result_conversion_error_does_not_raise_exception(self):
 952          def handler(result):
 953              raise ValueError("Handler error")
 954  
 955          weather_tool = Tool(
 956              name="weather_tool",
 957              description="Provides weather information for a given location.",
 958              parameters=weather_parameters,
 959              function=weather_function,
 960              outputs_to_string={"handler": handler, "raw_result": True},
 961          )
 962          invoker = ToolInvoker(tools=[weather_tool], raise_on_failure=False)
 963  
 964          tool_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})
 965  
 966          tool_result = "something"
 967          tool_message = invoker._prepare_tool_result_message(
 968              result=tool_result, tool_call=tool_call, tool_to_invoke=weather_tool
 969          )
 970          assert tool_message.tool_call_results[0].error
 971          assert "Failed to convert" in tool_message.tool_call_results[0].result
 972  
 973      def test_run_state_merge_error_handled_gracefully(self, weather_tool_with_outputs_to_state):
 974          class ProblematicState(State):
 975              def set(self, key: str, value: Any, handler_override=None):
 976                  # Simulate a State error during merging
 977                  raise ValueError("State set operation failed")
 978  
 979          state = ProblematicState(schema={"test_key": {"type": str}})
 980          invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=False)
 981  
 982          tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
 983          message = ChatMessage.from_assistant(tool_calls=tool_calls)
 984  
 985          result = invoker.run(messages=[message], state=state)
 986  
 987          assert "tool_messages" in result
 988          assert len(result["tool_messages"]) == 1
 989          assert result["tool_messages"][0].tool_call_results[0].error is True
 990          assert (
 991              "Failed to merge tool outputs from tool weather_tool into State"
 992              in result["tool_messages"][0].tool_call_results[0].result
 993          )
 994  
 995      def test_run_state_merge_error_raises_when_configured(self, weather_tool_with_outputs_to_state):
 996          class ProblematicState(State):
 997              def set(self, key: str, value: Any, handler_override=None):
 998                  # Simulate a State error during merging
 999                  raise ValueError("State set operation failed")
1000  
1001          state = ProblematicState(schema={"test_key": {"type": str}})
1002          invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=True)
1003  
1004          tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
1005          message = ChatMessage.from_assistant(tool_calls=tool_calls)
1006  
1007          with pytest.raises(ToolOutputMergeError, match="Failed to merge"):
1008              invoker.run(messages=[message], state=state)
1009  
1010      @pytest.mark.asyncio
1011      async def test_run_async_state_merge_error_handled_gracefully(self, weather_tool_with_outputs_to_state):
1012          class ProblematicState(State):
1013              def set(self, key: str, value: Any, handler_override=None):
1014                  # Simulate a State error during merging
1015                  raise ValueError("State set operation failed")
1016  
1017          state = ProblematicState(schema={"test_key": {"type": str}})
1018          invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=False)
1019  
1020          tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
1021          message = ChatMessage.from_assistant(tool_calls=tool_calls)
1022  
1023          result = await invoker.run_async(messages=[message], state=state)
1024  
1025          assert "tool_messages" in result
1026          assert len(result["tool_messages"]) == 1
1027          assert result["tool_messages"][0].tool_call_results[0].error is True
1028          assert (
1029              "Failed to merge tool outputs from tool weather_tool into State"
1030              in result["tool_messages"][0].tool_call_results[0].result
1031          )
1032  
1033      @pytest.mark.asyncio
1034      async def test_run_async_state_merge_error_raises_when_configured(self, weather_tool_with_outputs_to_state):
1035          class ProblematicState(State):
1036              def set(self, key: str, value: Any, handler_override=None):
1037                  # Simulate a State error during merging
1038                  raise ValueError("State set operation failed")
1039  
1040          state = ProblematicState(schema={"test_key": {"type": str}})
1041          invoker = ToolInvoker(tools=[weather_tool_with_outputs_to_state], raise_on_failure=True)
1042  
1043          tool_calls = [ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})]
1044          message = ChatMessage.from_assistant(tool_calls=tool_calls)
1045  
1046          with pytest.raises(ToolOutputMergeError, match="Failed to merge"):
1047              await invoker.run_async(messages=[message], state=state)
1048  
1049  
class TestToolInvokerUtilities:
    """Tests for ToolInvoker's private helpers: default string conversion, state merging and output processing."""

    @staticmethod
    def _sample_weather_result() -> dict[str, Any]:
        """Return a fresh weather result dict to feed into ``_merge_tool_outputs``."""
        return {"weather": "sunny", "temperature": 14, "unit": "celsius"}

    @staticmethod
    def _weather_tool_with_outputs_to_state(outputs_to_state: dict[str, Any]) -> Tool:
        """Return a weather tool configured with the given ``outputs_to_state`` mapping."""
        return Tool(
            name="weather_tool",
            description="Provides weather information for a given location.",
            parameters=weather_parameters,
            function=weather_function,
            outputs_to_state=outputs_to_state,
        )

    def test_default_output_to_string_handler_basic_types(self, weather_tool):
        """Without JSON conversion the default handler returns the expected plain-string form of each value."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=False)

        expectations = [
            ("hello", "hello"),
            (42, "42"),
            (3.14, "3.14"),
            (True, "True"),
            (None, "None"),
            ([1, 2, 3], "[1, 2, 3]"),
            ({"key": "value"}, "{'key': 'value'}"),
        ]
        for value, expected_text in expectations:
            assert invoker._default_output_to_string_handler(value) == expected_text

    def test_default_output_to_string_handler_json_string_mode(self, weather_tool):
        """With JSON conversion the default handler returns the expected JSON text, keeping non-ASCII characters."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=True)

        expectations = [
            ("hello", '"hello"'),
            (42, "42"),
            (True, "true"),
            (None, "null"),
            ([1, 2, 3], "[1, 2, 3]"),
            ({"key": "value"}, '{"key": "value"}'),
            ("Hello 🌍", '"Hello 🌍"'),
        ]
        for value, expected_text in expectations:
            assert invoker._default_output_to_string_handler(value) == expected_text

    def test_default_output_to_string_handler_with_serializable_objects(self, weather_tool):
        """The string produced for an object exposing ``to_dict`` contains that dict's key and value."""
        invoker = ToolInvoker(tools=[weather_tool], convert_result_to_json_string=False)

        # Minimal object offering a ``to_dict`` method
        class MockObject:
            def __init__(self, value):
                self.value = value

            def to_dict(self):
                return {"value": self.value}

        text = invoker._default_output_to_string_handler(MockObject("test_value"))

        assert "test_value" in text
        assert "value" in text

    def test_merge_tool_outputs_result_not_a_dict(self, weather_tool):
        """A non-dict result leaves the state empty."""
        state = State(schema={"weather": {"type": str}})
        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(tool=weather_tool, result="test", state=state)
        assert state.data == {}

    def test_merge_tool_outputs_empty_dict(self, weather_tool):
        """An empty dict result leaves the state empty."""
        state = State(schema={"weather": {"type": str}})
        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(tool=weather_tool, result={}, state=state)
        assert state.data == {}

    def test_merge_tool_outputs_no_output_mapping(self, weather_tool):
        """Merging a populated result through the ``weather_tool`` fixture leaves the state empty."""
        state = State(schema={"weather": {"type": str}})
        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(
            tool=weather_tool, result=self._sample_weather_result(), state=state
        )
        assert state.data == {}

    def test_merge_tool_outputs_with_output_mapping(self):
        """A mapping with a ``source`` copies only that key of the result into the state."""
        weather_tool = self._weather_tool_with_outputs_to_state({"weather": {"source": "weather"}})
        state = State(schema={"weather": {"type": str}})

        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(
            tool=weather_tool, result=self._sample_weather_result(), state=state
        )

        assert state.data == {"weather": "sunny"}

    def test_merge_tool_outputs_with_output_mapping_2(self):
        """A mapping without a ``source`` stores the whole result dict under the state key."""
        weather_tool = self._weather_tool_with_outputs_to_state({"all_weather_results": {}})
        state = State(schema={"all_weather_results": {"type": str}})

        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(
            tool=weather_tool, result=self._sample_weather_result(), state=state
        )

        assert state.data == {"all_weather_results": {"weather": "sunny", "temperature": 14, "unit": "celsius"}}

    def test_merge_tool_outputs_source_key_absent_does_not_corrupt_list_state(self):
        """
        Mimics a PipelineTool around a pipeline with a conditional branch that may be skipped, so the configured
        source key is missing from the tool result. The list already stored in the state must stay as it was, with
        no None appended to it.
        """
        retrieval_tool = Tool(
            name="retrieval",
            description="mock",
            parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
            function=lambda query: {},
            outputs_to_state={"documents": {"source": "documents_output"}},
        )
        invoker = ToolInvoker(tools=[retrieval_tool])
        earlier_doc = Document(content="from first call")
        state = State(schema={"documents": {"type": list[Document]}})
        state.set("documents", [earlier_doc])

        # The result carries no "documents_output" key (the document extraction branch did not run)
        invoker._merge_tool_outputs(tool=retrieval_tool, result={"result": "no web results found"}, state=state)

        stored_documents = state.data["documents"]
        assert stored_documents == [earlier_doc]
        assert None not in stored_documents

    def test_merge_tool_outputs_with_output_mapping_and_handler(self):
        """A mapping with ``source`` and ``handler`` stores the handler's output for the selected key."""

        def stringify_new_value(_, new):
            return f"{new}"

        weather_tool = self._weather_tool_with_outputs_to_state(
            {"temperature": {"source": "temperature", "handler": stringify_new_value}}
        )
        state = State(schema={"temperature": {"type": str}})

        ToolInvoker(tools=[weather_tool])._merge_tool_outputs(
            tool=weather_tool, result=self._sample_weather_result(), state=state
        )

        assert state.data == {"temperature": "14"}

    def test_process_output_empty_config(self, invoker, base64_image_string):
        """With only ``raw_result`` set, the result is returned unchanged."""
        images = [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

        processed = invoker._process_output(
            config={"raw_result": True}, result=images, tool_call=ToolCall(tool_name="retrieve_image", arguments={})
        )

        assert processed == images

    def test_process_output_source_only(self, invoker, base64_image_string):
        """With a ``source``, the value stored under that key of the result is returned."""
        images = [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

        processed = invoker._process_output(
            config={"source": "images", "raw_result": True},
            result={"images": images},
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )

        assert processed == images

    def test_process_output_handler_only(self, invoker, base64_image_string):
        """With a ``handler`` only, the handler's return value for the whole result is returned."""

        def handler(result: dict) -> list[ImageContent]:
            return [ImageContent(base64_image=result["base64_image_string"], mime_type=result["mime_type"])]

        processed = invoker._process_output(
            config={"handler": handler, "raw_result": True},
            result={"base64_image_string": base64_image_string, "mime_type": "image/png"},
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )

        assert processed == [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

    def test_process_output_source_and_handler(self, invoker, base64_image_string):
        """With ``source`` and ``handler``, the handler's return value for the selected entry is returned."""

        def handler(result: dict) -> list[ImageContent]:
            return [ImageContent(base64_image=result["base64_image_string"], mime_type=result["mime_type"])]

        tool_output = {
            "images": {"base64_image_string": base64_image_string, "mime_type": "image/png"},
            "other_key": "other_value",
        }
        processed = invoker._process_output(
            config={"source": "images", "handler": handler, "raw_result": True},
            result=tool_output,
            tool_call=ToolCall(tool_name="retrieve_image", arguments={}),
        )

        assert processed == [ImageContent(base64_image=base64_image_string, mime_type="image/png")]

    def test_output_to_result_e2e(self, weather_tool):
        """End to end: a ``raw_result`` handler returning TextContent items ends up as the tool call result."""

        def handler(result):
            weather_line = TextContent(text=f"weather: {result['weather']}")
            temperature_line = TextContent(text=f"temperature: {result['temperature']} {result['unit']}")
            return [weather_line, temperature_line]

        weather_tool.outputs_to_string = {"handler": handler, "raw_result": True}
        invoker = ToolInvoker(tools=[weather_tool])
        berlin_call = ToolCall(tool_name="weather_tool", arguments={"location": "Berlin"})

        outcome = invoker.run(messages=[ChatMessage.from_assistant(tool_calls=[berlin_call])])
        first_result = outcome["tool_messages"][0].tool_call_results[0]

        assert first_result.result == [
            TextContent(text="weather: mostly sunny"),
            TextContent(text="temperature: 7 celsius"),
        ]
1250  
1251  
1252  class TestWarmUpTools:
1253      """Tests for Tool/Toolset warm_up through ToolInvoker"""
1254  
1255      def test_tool_invoker_warm_up_with_single_tool(self):
1256          """Test that ToolInvoker.warm_up() calls warm_up on a single tool."""
1257          tool = WarmupTrackingTool(
1258              name="test_tool",
1259              description="A test tool",
1260              parameters={"type": "object", "properties": {}},
1261              function=lambda: "test",
1262          )
1263  
1264          invoker = ToolInvoker(tools=[tool])
1265  
1266          assert not tool.was_warmed_up
1267          invoker.warm_up()
1268          assert tool.was_warmed_up
1269  
1270      def test_tool_invoker_warm_up_with_multiple_tools(self):
1271          """Test that ToolInvoker.warm_up() calls warm_up on multiple tools."""
1272          tool1 = WarmupTrackingTool(
1273              name="tool1",
1274              description="First tool",
1275              parameters={"type": "object", "properties": {}},
1276              function=lambda: "tool1",
1277          )
1278          tool2 = WarmupTrackingTool(
1279              name="tool2",
1280              description="Second tool",
1281              parameters={"type": "object", "properties": {}},
1282              function=lambda: "tool2",
1283          )
1284  
1285          invoker = ToolInvoker(tools=[tool1, tool2])
1286  
1287          assert not tool1.was_warmed_up
1288          assert not tool2.was_warmed_up
1289  
1290          invoker.warm_up()
1291  
1292          assert tool1.was_warmed_up
1293          assert tool2.was_warmed_up
1294  
1295      def test_tool_invoker_warm_up_with_toolset(self, weather_tool):
1296          """Test that ToolInvoker.warm_up() calls warm_up on the toolset."""
1297          toolset = WarmupTrackingToolset([weather_tool])
1298          invoker = ToolInvoker(tools=toolset)
1299  
1300          assert not toolset.was_warmed_up
1301          invoker.warm_up()
1302          assert toolset.was_warmed_up
1303  
1304      def test_tool_invoker_warm_up_with_mixed_toolsets(self):
1305          """Test that ToolInvoker.warm_up() works with combined toolsets using concatenation."""
1306          # Create first toolset with a tracking tool
1307          tool1 = WarmupTrackingTool(
1308              name="tool1",
1309              description="First tool",
1310              parameters={"type": "object", "properties": {}},
1311              function=lambda: "tool1",
1312          )
1313          toolset1 = WarmupTrackingToolset([tool1])
1314  
1315          # Create second toolset with another tracking tool
1316          tool2 = WarmupTrackingTool(
1317              name="tool2",
1318              description="Second tool",
1319              parameters={"type": "object", "properties": {}},
1320              function=lambda: "tool2",
1321          )
1322          toolset2 = WarmupTrackingToolset([tool2])
1323  
1324          # Combine toolsets using the + operator (creates _ToolsetWrapper)
1325          combined = toolset1 + toolset2
1326  
1327          # Create invoker with the combined toolset
1328          invoker = ToolInvoker(tools=combined)
1329  
1330          assert not toolset1.was_warmed_up
1331          assert not toolset2.was_warmed_up
1332  
1333          invoker.warm_up()
1334  
1335          # Both toolsets should be warmed up
1336          assert toolset1.was_warmed_up
1337          assert toolset2.was_warmed_up
1338  
1339      def test_tool_invoker_warm_up_with_mixed_list_of_tools_and_toolsets(self):
1340          """Test that ToolInvoker.warm_up() works with a mixed list of Tools and Toolsets."""
1341          # Create standalone tracking tools
1342          tool1 = WarmupTrackingTool(
1343              name="standalone_tool1",
1344              description="First standalone tool",
1345              parameters={"type": "object", "properties": {}},
1346              function=lambda: "tool1",
1347          )
1348          tool2 = WarmupTrackingTool(
1349              name="standalone_tool2",
1350              description="Second standalone tool",
1351              parameters={"type": "object", "properties": {}},
1352              function=lambda: "tool2",
1353          )
1354  
1355          # Create toolsets with tracking
1356          tool3 = WarmupTrackingTool(
1357              name="toolset_tool1",
1358              description="Tool in toolset 1",
1359              parameters={"type": "object", "properties": {}},
1360              function=lambda: "tool3",
1361          )
1362          toolset1 = WarmupTrackingToolset([tool3])
1363  
1364          tool4 = WarmupTrackingTool(
1365              name="toolset_tool2",
1366              description="Tool in toolset 2",
1367              parameters={"type": "object", "properties": {}},
1368              function=lambda: "tool4",
1369          )
1370          toolset2 = WarmupTrackingToolset([tool4])
1371  
1372          # Create invoker with mixed list: Tool, Toolset, Tool, Toolset
1373          invoker = ToolInvoker(tools=[tool1, toolset1, tool2, toolset2])
1374  
1375          # Verify nothing is warmed up initially
1376          assert not tool1.was_warmed_up
1377          assert not tool2.was_warmed_up
1378          assert not toolset1.was_warmed_up
1379          assert not toolset2.was_warmed_up
1380  
1381          # Warm up
1382          invoker.warm_up()
1383  
1384          # Verify standalone tools are warmed up
1385          assert tool1.was_warmed_up
1386          assert tool2.was_warmed_up
1387  
1388          # Verify toolsets themselves are warmed up (not just their internal tools)
1389          assert toolset1.was_warmed_up
1390          assert toolset2.was_warmed_up
1391  
1392      def test_tool_invoker_warm_up_is_idempotent(self):
1393          """Test that ToolInvoker.warm_up() is idempotent and only warms up once."""
1394  
1395          class WarmupCountingTool(Tool):
1396              """A tool that counts how many times warm_up was called."""
1397  
1398              def __init__(self, *args, **kwargs):
1399                  super().__init__(*args, **kwargs)
1400                  self.warm_up_count = 0
1401  
1402              def warm_up(self):
1403                  self.warm_up_count += 1
1404  
1405          tool = WarmupCountingTool(
1406              name="counting_tool",
1407              description="A tool that counts warm_up calls",
1408              parameters={"type": "object", "properties": {}},
1409              function=lambda: "test",
1410          )
1411  
1412          invoker = ToolInvoker(tools=[tool])
1413  
1414          # Call warm_up multiple times
1415          invoker.warm_up()
1416          invoker.warm_up()
1417          invoker.warm_up()
1418  
1419          # Should only be warmed up once
1420          assert tool.warm_up_count == 1
1421  
1422      def test_warm_up_refreshes_tools_with_names(self):
1423          """
1424          Test that ToolInvoker.warm_up() refreshes _tools_with_names when using a toolset with lazy connection.
1425          """
1426          # Create placeholder tool that simulates MCPToolset behavior of lazy connection
1427          placeholder_tool = Tool(
1428              name="mcp_not_connected_placeholder_123",
1429              description="Placeholder tool before connection",
1430              parameters={"type": "object", "properties": {}},
1431              function=lambda: "placeholder",
1432          )
1433  
1434          # Create the actual tool that will replace the placeholder during warmup
1435          # This simulates what mcp-server-time mcp would provide
1436          actual_tool = Tool(
1437              name="get_time",
1438              description="Get the current time in ISO format",
1439              parameters={"type": "object", "properties": {}, "required": []},
1440              function=lambda: "2024-12-01T12:00:00Z",
1441          )
1442  
1443          # Create a toolset that simulates MCPToolset with eager_connect=False (lazy connection)
1444          class MockMCPToolset(Toolset):
1445              """Simulates MCPToolset behavior with eager_connect=False."""
1446  
1447              def __init__(self):
1448                  # Start with placeholder tools (like MCPToolset does when not eagerly connected)
1449                  super().__init__([placeholder_tool])
1450                  self._warmed_up = False
1451  
1452              def warm_up(self):
1453                  """Simulate connecting to MCP server and replacing placeholder tools with actual tools."""
1454                  if not self._warmed_up:
1455                      # Replace placeholder tools with actual tools (simulating MCP connection)
1456                      self.tools = [actual_tool]
1457                      self._warmed_up = True
1458  
1459          mcp_toolset = MockMCPToolset()
1460          invoker = ToolInvoker(tools=mcp_toolset)
1461  
1462          # Before warmup: _tools_with_names should contain the placeholder tool
1463          assert "mcp_not_connected_placeholder_123" in invoker._tools_with_names
1464          assert "get_time" not in invoker._tools_with_names
1465  
1466          # Call warm_up() directly to trigger tool refresh
1467          invoker.warm_up()
1468  
1469          # After warmup: _tools_with_names should be refreshed with actual tool names
1470          assert "mcp_not_connected_placeholder_123" not in invoker._tools_with_names
1471          assert "get_time" in invoker._tools_with_names