strategies.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterator
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any

from haystack.components.agents import State
from haystack.components.tools.tool_invoker import ToolInvoker
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, StreamingCallbackT, ToolCall
from haystack.human_in_the_loop import ToolExecutionDecision
from haystack.human_in_the_loop.types import ConfirmationPolicy, ConfirmationStrategy, ConfirmationUI
from haystack.tools import Tool
from haystack.utils.deserialization import deserialize_component_inplace

# To prevent circular imports
if TYPE_CHECKING:
    from haystack.components.agents.agent import _ExecutionContext

REJECTION_FEEDBACK_TEMPLATE = "Tool execution for '{tool_name}' was rejected by the user."
MODIFICATION_FEEDBACK_TEMPLATE = (
    "The parameters for tool '{tool_name}' were updated by the user to:\n{final_tool_params}"
)
USER_FEEDBACK_TEMPLATE = "With user feedback: {feedback}"


class BlockingConfirmationStrategy:
    """
    Confirmation strategy that blocks execution to gather user feedback.
    """

    def __init__(
        self,
        *,
        confirmation_policy: ConfirmationPolicy,
        confirmation_ui: ConfirmationUI,
        reject_template: str = REJECTION_FEEDBACK_TEMPLATE,
        modify_template: str = MODIFICATION_FEEDBACK_TEMPLATE,
        user_feedback_template: str = USER_FEEDBACK_TEMPLATE,
    ) -> None:
        """
        Initialize the BlockingConfirmationStrategy with a confirmation policy and UI.

        :param confirmation_policy:
            The confirmation policy to determine when to ask for user confirmation.
        :param confirmation_ui:
            The user interface to interact with the user for confirmation.
        :param reject_template:
            Template for rejection feedback messages. It should include a `{tool_name}` placeholder.
        :param modify_template:
            Template for modification feedback messages. It should include `{tool_name}` and `{final_tool_params}`
            placeholders.
        :param user_feedback_template:
            Template for user feedback messages. It should include a `{feedback}` placeholder.
        """
        self.confirmation_policy = confirmation_policy
        self.confirmation_ui = confirmation_ui
        self.reject_template = reject_template
        self.modify_template = modify_template
        self.user_feedback_template = user_feedback_template

    def _with_user_feedback(self, explanation_text: str, feedback: str | None) -> str:
        """
        Append the formatted user feedback to an explanation text.

        :param explanation_text:
            The already formatted rejection or modification message.
        :param feedback:
            Free-text feedback from the confirmation UI result. Falsy values leave the text unchanged.

        :returns:
            The explanation text, followed by a space and the rendered `user_feedback_template` when feedback
            was given.
        """
        if not feedback:
            return explanation_text
        return explanation_text + " " + self.user_feedback_template.format(feedback=feedback)

    def run(
        self,
        *,
        tool_name: str,
        tool_description: str,
        tool_params: dict[str, Any],
        tool_call_id: str | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,  # noqa: ARG002
    ) -> ToolExecutionDecision:
        """
        Run the human-in-the-loop strategy for a given tool and its parameters.

        :param tool_name:
            The name of the tool to be executed.
        :param tool_description:
            The description of the tool.
        :param tool_params:
            The parameters to be passed to the tool.
        :param tool_call_id:
            Optional unique identifier for the tool call. This can be used to track and correlate the decision with a
            specific tool invocation.
        :param confirmation_strategy_context:
            Optional dictionary for passing request-scoped resources. Useful in web/server environments
            to provide per-request objects (e.g., WebSocket connections, async queues, Redis pub/sub clients)
            that strategies can use for non-blocking user interaction. This strategy does not use it.

        :returns:
            A ToolExecutionDecision indicating whether to execute the tool with the given parameters, or a
            feedback message if rejected.
        """
        # Check if we should ask based on policy
        if not self.confirmation_policy.should_ask(
            tool_name=tool_name, tool_description=tool_description, tool_params=tool_params
        ):
            return ToolExecutionDecision(
                tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
            )

        # Get user confirmation through UI
        confirmation_ui_result = self.confirmation_ui.get_user_confirmation(tool_name, tool_description, tool_params)

        # Pass back the result to the policy for any learning/updating
        self.confirmation_policy.update_after_confirmation(
            tool_name, tool_description, tool_params, confirmation_ui_result
        )

        # Process the confirmation result
        if confirmation_ui_result.action == "reject":
            explanation_text = self._with_user_feedback(
                self.reject_template.format(tool_name=tool_name), confirmation_ui_result.feedback
            )
            return ToolExecutionDecision(
                tool_name=tool_name, execute=False, tool_call_id=tool_call_id, feedback=explanation_text
            )

        if confirmation_ui_result.action == "modify" and confirmation_ui_result.new_tool_params:
            # The user supplied parameters replace the original ones entirely (they are not merged).
            final_args = dict(confirmation_ui_result.new_tool_params)
            explanation_text = self._with_user_feedback(
                self.modify_template.format(tool_name=tool_name, final_tool_params=final_args),
                confirmation_ui_result.feedback,
            )
            return ToolExecutionDecision(
                tool_name=tool_name,
                tool_call_id=tool_call_id,
                execute=True,
                feedback=explanation_text,
                final_tool_params=final_args,
            )

        # action == "confirm" (also reached for "modify" without new parameters)
        return ToolExecutionDecision(
            tool_name=tool_name, execute=True, tool_call_id=tool_call_id, final_tool_params=tool_params
        )

    async def run_async(
        self,
        *,
        tool_name: str,
        tool_description: str,
        tool_params: dict[str, Any],
        tool_call_id: str | None = None,
        confirmation_strategy_context: dict[str, Any] | None = None,
    ) -> ToolExecutionDecision:
        """
        Async version of run. Calls the sync run() method by default.

        NOTE: the sync `run()` is invoked directly (not awaited or offloaded), so the event loop is blocked for as
        long as the confirmation UI waits for the user.

        :param tool_name:
            The name of the tool to be executed.
        :param tool_description:
            The description of the tool.
        :param tool_params:
            The parameters to be passed to the tool.
        :param tool_call_id:
            Optional unique identifier for the tool call.
        :param confirmation_strategy_context:
            Optional dictionary for passing request-scoped resources.

        :returns:
            A ToolExecutionDecision indicating whether to execute the tool with the given parameters.
        """
        return self.run(
            tool_name=tool_name,
            tool_description=tool_description,
            tool_params=tool_params,
            tool_call_id=tool_call_id,
            confirmation_strategy_context=confirmation_strategy_context,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the BlockingConfirmationStrategy to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            confirmation_policy=self.confirmation_policy.to_dict(),
            confirmation_ui=self.confirmation_ui.to_dict(),
            reject_template=self.reject_template,
            modify_template=self.modify_template,
            user_feedback_template=self.user_feedback_template,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BlockingConfirmationStrategy":
        """
        Deserializes the BlockingConfirmationStrategy from a dictionary.

        :param data:
            Dictionary to deserialize from.

        :returns:
            Deserialized BlockingConfirmationStrategy.
        """
        deserialize_component_inplace(data["init_parameters"], key="confirmation_policy")
        deserialize_component_inplace(data["init_parameters"], key="confirmation_ui")
        return default_from_dict(cls, data)


def _get_confirmation_strategy(
    *, tool_name: str, confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy]
) -> ConfirmationStrategy | None:
    """
    Get the confirmation strategy for a given tool name.

    An exact string key takes precedence; otherwise the first tuple key containing the tool name wins.

    :param tool_name:
        The name of the tool to look up.
    :param confirmation_strategies:
        Dictionary of confirmation strategies with string or tuple keys.
    :returns:
        The confirmation strategy if found, None otherwise.
    """
    if tool_name in confirmation_strategies:
        return confirmation_strategies[tool_name]

    for key, strategy in confirmation_strategies.items():
        if isinstance(key, tuple) and tool_name in key:
            return strategy

    return None


def _prepare_tool_args(
    *,
    tool: Tool,
    tool_call_arguments: dict[str, Any],
    state: State,
    streaming_callback: StreamingCallbackT | None = None,
    enable_streaming_passthrough: bool = False,
) -> dict[str, Any]:
    """
    Prepare the final arguments for a tool by injecting state inputs and optionally a streaming callback.

    :param tool:
        The tool instance to prepare arguments for.
    :param tool_call_arguments:
        The initial arguments provided for the tool call. A copy is passed on, the input dict is not modified here.
    :param state:
        The current state containing inputs to be injected into the tool arguments.
    :param streaming_callback:
        Optional streaming callback to be injected if enabled and applicable.
    :param enable_streaming_passthrough:
        Flag indicating whether to inject the streaming callback into the tool arguments.

    :returns:
        A dictionary of final arguments ready for tool invocation.
    """
    # Combine user + state inputs
    final_args = ToolInvoker._inject_state_args(tool, tool_call_arguments.copy(), state)
    # Only inject the streaming callback when it is enabled, available, not already provided,
    # and actually accepted by the tool function.
    if (
        enable_streaming_passthrough
        and streaming_callback is not None
        and "streaming_callback" not in final_args
        and "streaming_callback" in ToolInvoker._get_func_params(tool)
    ):
        final_args["streaming_callback"] = streaming_callback
    return final_args


def _finalize_confirmation_results(
    *,
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
    tool_execution_decisions: list[ToolExecutionDecision],
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Apply tool execution decisions and rebuild the chat history (shared by the sync and async entry points).

    :param messages_with_tool_calls: Chat messages containing tool calls
    :param execution_context: The current execution context of the agent
    :param tool_execution_decisions: The decisions collected for the tool calls
    :returns:
        Tuple of modified messages with confirmed tool calls and updated chat history
    """
    # Apply tool execution decisions to messages_with_tool_calls
    rejection_messages, modified_tool_call_messages = _apply_tool_execution_decisions(
        tool_call_messages=messages_with_tool_calls, tool_execution_decisions=tool_execution_decisions
    )

    # Update the chat history with rejection messages and new tool call messages
    new_chat_history = _update_chat_history(
        chat_history=execution_context.state.get("messages"),
        rejection_messages=rejection_messages,
        tool_call_and_explanation_messages=modified_tool_call_messages,
    )

    return modified_tool_call_messages, new_chat_history


def _process_confirmation_strategies(
    *,
    confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Run the confirmation strategies and return modified tool call messages and updated chat history.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
    :param messages_with_tool_calls: Chat messages containing tool calls
    :param execution_context: The current execution context of the agent
    :returns:
        Tuple of modified messages with confirmed tool calls and updated chat history
    """
    # If confirmation strategies is empty, return original messages and chat history
    if not confirmation_strategies:
        return messages_with_tool_calls, execution_context.state.get("messages")

    # Run confirmation strategies and get tool execution decisions
    teds = _run_confirmation_strategies(
        confirmation_strategies=confirmation_strategies,
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
    )

    return _finalize_confirmation_results(
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
        tool_execution_decisions=teds,
    )


async def _process_confirmation_strategies_async(
    *,
    confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Async version of _process_confirmation_strategies.

    Run the confirmation strategies and return modified tool call messages and updated chat history.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
    :param messages_with_tool_calls: Chat messages containing tool calls
    :param execution_context: The current execution context of the agent
    :returns:
        Tuple of modified messages with confirmed tool calls and updated chat history
    """
    # If confirmation strategies is empty, return original messages and chat history
    if not confirmation_strategies:
        return messages_with_tool_calls, execution_context.state.get("messages")

    # Run confirmation strategies and get tool execution decisions (async version)
    teds = await _run_confirmation_strategies_async(
        confirmation_strategies=confirmation_strategies,
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
    )

    return _finalize_confirmation_results(
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
        tool_execution_decisions=teds,
    )


@dataclass(frozen=True)
class _PendingConfirmation:
    """
    A tool call for which the confirmation strategy still has to be run.
    """

    # Strategy responsible for the tool call.
    strategy: ConfirmationStrategy
    # Keyword arguments to pass to `strategy.run` / `strategy.run_async`.
    run_kwargs: dict[str, Any]


def _plan_tool_execution_decisions(
    *,
    confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> Iterator[ToolExecutionDecision | _PendingConfirmation]:
    """
    Lazily walk all tool calls and yield, in order, either a ready decision or a pending confirmation.

    Shared by the sync and async runners so that both resolve tool calls identically. Because this is a generator,
    the arguments of a tool call are only prepared after the previous tool call has been fully handled by the caller.

    Known limitation: if no stored decision matches the tool call id, a stored decision with the same tool name is
    reused, even if it was recorded for a different call of that tool.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
    :param messages_with_tool_calls: Messages containing tool calls to process
    :param execution_context: The current execution context containing state and inputs
    :returns:
        An iterator yielding a ToolExecutionDecision when no user interaction is needed (no strategy configured or a
        decision already exists in the execution context), otherwise a _PendingConfirmation.
    :raises KeyError:
        If a tool call references a tool that is not part of the tool invoker inputs.
    """
    state = execution_context.state
    tool_invoker_inputs = execution_context.tool_invoker_inputs
    tools_with_names = {tool.name: tool for tool in tool_invoker_inputs["tools"]}
    existing_teds = execution_context.tool_execution_decisions or []
    existing_teds_by_name = {ted.tool_name: ted for ted in existing_teds if ted.tool_name}
    existing_teds_by_id = {ted.tool_call_id: ted for ted in existing_teds if ted.tool_call_id}

    for message in messages_with_tool_calls:
        if not message.tool_calls:
            continue

        for tool_call in message.tool_calls:
            tool_name = tool_call.tool_name
            try:
                tool_to_invoke = tools_with_names[tool_name]
            except KeyError:
                # Same exception type as a plain dict lookup, but with an actionable message.
                raise KeyError(
                    f"Tool '{tool_name}' referenced by a tool call was not found in the available tools: "
                    f"{sorted(tools_with_names)}"
                ) from None

            # Prepare final tool args
            final_args = _prepare_tool_args(
                tool=tool_to_invoke,
                tool_call_arguments=tool_call.arguments,
                state=state,
                streaming_callback=tool_invoker_inputs.get("streaming_callback"),
                enable_streaming_passthrough=tool_invoker_inputs.get("enable_streaming_passthrough", False),
            )

            # If no confirmation strategy is defined for this tool, proceed with execution
            strategy = _get_confirmation_strategy(tool_name=tool_name, confirmation_strategies=confirmation_strategies)
            if strategy is None:
                yield ToolExecutionDecision(
                    tool_call_id=tool_call.id, tool_name=tool_name, execute=True, final_tool_params=final_args
                )
                continue

            # Check if there's already a decision for this tool call in the execution context
            existing_ted = existing_teds_by_id.get(tool_call.id or "") or existing_teds_by_name.get(tool_name)
            if existing_ted:
                yield existing_ted
                continue

            # Otherwise the caller has to run the confirmation strategy
            yield _PendingConfirmation(
                strategy=strategy,
                run_kwargs={
                    "tool_name": tool_name,
                    "tool_description": tool_to_invoke.description,
                    "tool_params": final_args,
                    "tool_call_id": tool_call.id,
                    "confirmation_strategy_context": execution_context.confirmation_strategy_context,
                },
            )


def _run_confirmation_strategies(
    confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> list[ToolExecutionDecision]:
    """
    Run confirmation strategies for tool calls in the provided chat messages.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
        String keys map individual tools, tuple keys map multiple tools to the same strategy.
    :param messages_with_tool_calls: Messages containing tool calls to process
    :param execution_context: The current execution context containing state and inputs
    :returns:
        A list of ToolExecutionDecision objects representing the decisions made for each tool call.
    :raises KeyError:
        If a tool call references a tool that is not part of the tool invoker inputs.
    """
    teds: list[ToolExecutionDecision] = []
    for planned in _plan_tool_execution_decisions(
        confirmation_strategies=confirmation_strategies,
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
    ):
        if isinstance(planned, _PendingConfirmation):
            decision = planned.strategy.run(**planned.run_kwargs)
        else:
            decision = planned
        teds.append(decision)

    return teds


async def _run_confirmation_strategies_async(
    confirmation_strategies: dict[str | tuple[str, ...], ConfirmationStrategy],
    messages_with_tool_calls: list[ChatMessage],
    execution_context: "_ExecutionContext",
) -> list[ToolExecutionDecision]:
    """
    Async version of _run_confirmation_strategies.

    Run confirmation strategies for tool calls in the provided chat messages.

    :param confirmation_strategies: Mapping of tool names to their corresponding confirmation strategies
        String keys map individual tools, tuple keys map multiple tools to the same strategy.
    :param messages_with_tool_calls: Messages containing tool calls to process
    :param execution_context: The current execution context containing state and inputs
    :returns:
        A list of ToolExecutionDecision objects representing the decisions made for each tool call.
    :raises KeyError:
        If a tool call references a tool that is not part of the tool invoker inputs.
    """
    teds: list[ToolExecutionDecision] = []
    for planned in _plan_tool_execution_decisions(
        confirmation_strategies=confirmation_strategies,
        messages_with_tool_calls=messages_with_tool_calls,
        execution_context=execution_context,
    ):
        if not isinstance(planned, _PendingConfirmation):
            teds.append(planned)
            continue

        # Use run_async if available, otherwise fall back to sync run
        if hasattr(planned.strategy, "run_async"):
            decision = await planned.strategy.run_async(**planned.run_kwargs)
        else:
            decision = planned.strategy.run(**planned.run_kwargs)
        teds.append(decision)

    return teds


def _apply_tool_execution_decisions(
    tool_call_messages: list[ChatMessage], tool_execution_decisions: list[ToolExecutionDecision]
) -> tuple[list[ChatMessage], list[ChatMessage]]:
    """
    Apply the tool execution decisions to the tool call messages.

    :param tool_call_messages: The tool call messages to apply the decisions to.
    :param tool_execution_decisions: The tool execution decisions to apply.
    :returns:
        A tuple containing:
        - A list of rejection messages for rejected tool calls. These are pairs of tool call and tool call result
          messages.
        - A list of tool call messages for confirmed or modified tool calls. If tool parameters were modified,
          a user message explaining the modification is included before the tool call message.
    :raises ValueError:
        If none of the decisions carries a tool_call_id and several decisions share the same tool name.
    """
    decision_by_id = {d.tool_call_id: d for d in tool_execution_decisions if d.tool_call_id}
    decision_by_name = {d.tool_name: d for d in tool_execution_decisions if d.tool_name}

    # Known limitation: If tool calls are missing IDs, we rely on tool names to match decisions to tool calls.
    # This can lead to incorrect matches if there are multiple tool calls in the provided messages with duplicate names.
    if not decision_by_id and len(decision_by_name) < len(tool_execution_decisions):
        raise ValueError(
            "ToolExecutionDecisions are missing tool_call_id fields and there are multiple tool calls with the same "
            "name. When multiple tool calls with the same name are present, tool_call_id is required to correctly "
            "match decisions to tool calls."
        )

    def make_assistant_message(chat_message: ChatMessage, tool_calls: list[ToolCall]) -> ChatMessage:
        # Copy of the assistant message that carries only the given tool calls.
        return ChatMessage.from_assistant(
            text=chat_message.text,
            meta=chat_message.meta,
            name=chat_message.name,
            tool_calls=tool_calls,
            reasoning=chat_message.reasoning,
        )

    new_tool_call_messages = []
    rejection_messages = []

    for chat_msg in tool_call_messages:
        new_tool_calls = []
        for tc in chat_msg.tool_calls or []:
            ted = decision_by_id.get(tc.id or "") or decision_by_name.get(tc.tool_name)
            if not ted:
                # This shouldn't happen, if so something went wrong in _run_confirmation_strategies
                continue

            if not ted.execute:
                # Rejected tool call: emit the tool call together with an error tool result
                tool_result_text = ted.feedback or REJECTION_FEEDBACK_TEMPLATE.format(tool_name=tc.tool_name)
                rejection_messages.extend(
                    [
                        make_assistant_message(chat_msg, [tc]),
                        ChatMessage.from_tool(tool_result=tool_result_text, origin=tc, error=True),
                    ]
                )
                continue

            # Covers confirm and modify cases
            final_args = ted.final_tool_params or {}
            if tc.arguments != final_args:
                # In the modify case we add a user message explaining the modification otherwise the LLM won't know
                # why the tool parameters changed and will likely just try and call the tool again with the
                # original parameters.
                user_text = ted.feedback or MODIFICATION_FEEDBACK_TEMPLATE.format(
                    tool_name=tc.tool_name, final_tool_params=final_args
                )
                new_tool_call_messages.append(ChatMessage.from_user(text=user_text))
            new_tool_calls.append(replace(tc, arguments=final_args))

        # Only add the tool call message if there are any tool calls left (i.e. not all were rejected)
        if new_tool_calls:
            new_tool_call_messages.append(make_assistant_message(chat_msg, new_tool_calls))

    return rejection_messages, new_tool_call_messages


def _update_chat_history(
    chat_history: list[ChatMessage],
    rejection_messages: list[ChatMessage],
    tool_call_and_explanation_messages: list[ChatMessage],
) -> list[ChatMessage]:
    """
    Update the chat history to include rejection messages and tool call messages at the appropriate positions.

    Steps:
    1. Find the index of the last message in the current chat history that is from the user or from a tool.
    2. Use that index as the insertion point (-1 if there is no such message).
    3. Create a new chat history that includes:
       - All messages up to and including the insertion point.
       - Any rejection messages (pairs of tool call and tool call result messages).
       - Any tool call messages for confirmed or modified tool calls, including user messages explaining modifications.

    Messages after the insertion point are dropped from the returned history. The input list is not modified.

    :param chat_history: The current chat history.
    :param rejection_messages: Chat messages to add for rejected tool calls (pairs of tool call and tool call result
        messages).
    :param tool_call_and_explanation_messages: Tool call messages for confirmed or modified tool calls, which may
        include user messages explaining modifications.
    :returns:
        The updated chat history.
    """
    insertion_point = max(
        (
            index
            for index, message in enumerate(chat_history)
            if message.is_from("user") or message.is_from("tool")
        ),
        default=-1,
    )

    return chat_history[: insertion_point + 1] + rejection_messages + tool_call_and_explanation_messages