component.py
1 """ 2 Custom Solace AI Connector Component to Host Google ADK Agents via A2A Protocol. 3 """ 4 5 import asyncio 6 import concurrent.futures 7 import fnmatch 8 import inspect 9 import json 10 import logging 11 import os 12 import threading 13 import time 14 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 15 16 from a2a.types import ( 17 AgentCard, 18 MessageSendParams, 19 SendMessageRequest, 20 TaskState, 21 TaskStatus, 22 TaskStatusUpdateEvent, 23 ) 24 from a2a.types import Artifact as A2AArtifact 25 from a2a.types import Message as A2AMessage 26 from google.adk.agents import LlmAgent, RunConfig 27 from google.adk.agents.callback_context import CallbackContext 28 from google.adk.agents.invocation_context import LlmCallsLimitExceededError 29 from google.adk.agents.readonly_context import ReadonlyContext 30 from google.adk.agents.run_config import StreamingMode 31 from google.adk.artifacts import BaseArtifactService 32 from google.adk.auth.credential_service.base_credential_service import ( 33 BaseCredentialService, 34 ) 35 from google.adk.events import Event as ADKEvent 36 from google.adk.memory import BaseMemoryService 37 from google.adk.models import LlmResponse 38 from google.adk.models.llm_request import LlmRequest 39 from google.adk.runners import Runner 40 from google.adk.sessions import BaseSessionService 41 from google.adk.tools import FunctionTool 42 from google.adk.tools.mcp_tool import MCPToolset 43 from google.adk.tools.openapi_tool import OpenAPIToolset 44 from google.genai import types as adk_types 45 from litellm.exceptions import BadRequestError 46 from pydantic import BaseModel, ValidationError 47 from solace_ai_connector.common.event import Event, EventType 48 from solace_ai_connector.common.message import Message as SolaceMessage 49 from solace_ai_connector.common.utils import import_module 50 51 from ...agent.adk.runner import TaskCancelledError, run_adk_async_task_thread_wrapper 52 from ...agent.adk.session_compaction import SessionCompactionState 53 from ...agent.adk.services import ( 54 initialize_artifact_service, 55 initialize_credential_service, 56 initialize_memory_service, 57 initialize_session_service, 58 ) 59 from ...agent.adk.callbacks import _generate_tool_instructions_from_registry 60 from ...agent.adk.setup import ( 61 initialize_adk_agent, 62 initialize_adk_runner, 63 load_adk_tools, 64 ) 65 from ...agent.adk.tool_wrapper import ADKToolWrapper 66 from ...agent.protocol.event_handlers import process_event, publish_agent_card 67 from ...agent.tools.peer_agent_tool import ( 68 CORRELATION_DATA_PREFIX, 69 PEER_TOOL_PREFIX, 70 PeerAgentTool, 71 ) 72 from ...agent.tools.workflow_tool import WorkflowAgentTool 73 from ...agent.tools.registry import tool_registry 74 from ...agent.utils.config_parser import resolve_instruction_provider 75 from ...common import a2a 76 from ...common.a2a.translation import format_and_route_adk_event 77 from ...common.a2a.types import ArtifactInfo 78 from ...common.agent_registry import AgentRegistry 79 from ...common.constants import ( 80 DEFAULT_COMMUNICATION_TIMEOUT, 81 HEALTH_CHECK_INTERVAL_SECONDS, 82 HEALTH_CHECK_TTL_SECONDS, 83 EXTENSION_URI_AGENT_TYPE, 84 EXTENSION_URI_SCHEMAS, 85 ) 86 from ...common.data_parts import AgentProgressUpdateData, ArtifactSavedData 87 from ...common.error_handlers import get_error_message, is_llm_exception 88 from ...common.middleware.registry import MiddlewareRegistry 89 from ...common.sac.sam_component_base import SamComponentBase 90 from ...common.utils.rbac_utils import validate_agent_access 91 from .structured_invocation.handler import StructuredInvocationHandler 92 93 log = logging.getLogger(__name__) 94 95 if TYPE_CHECKING: 96 from .app import AgentInitCleanupConfig 97 from .task_execution_context import TaskExecutionContext 98 99 info = { 100 "class_name": "SamAgentComponent", 101 "description": ( 102 "Hosts a Google ADK agent and bridges communication via the A2A protocol over Solace. " 103 "NOTE: Configuration is defined in the app-level 'app_config' block " 104 "and validated by 'SamAgentApp.app_schema' when using the associated App class." 105 ), 106 "config_parameters": [], 107 "input_schema": { 108 "type": "object", 109 "description": "Not typically used; component reacts to events.", 110 "properties": {}, 111 }, 112 "output_schema": { 113 "type": "object", 114 "description": "Not typically used; component publishes results to Solace.", 115 "properties": {}, 116 }, 117 } 118 InstructionProvider = Callable[[ReadonlyContext], str] 119 120 121 class SamAgentComponent(SamComponentBase): 122 """ 123 A Solace AI Connector component that hosts a Google ADK agent, 124 communicating via the A2A protocol over Solace. 125 """ 126 127 CORRELATION_DATA_PREFIX = CORRELATION_DATA_PREFIX 128 HOST_COMPONENT_VERSION = "1.0.0-alpha" 129 HEALTH_CHECK_TIMER_ID = "agent_health_check" 130 enable_inline_vision = False 131 132 def __init__(self, **kwargs): 133 """ 134 Initializes the A2A_ADK_HostComponent. 135 Args: 136 **kwargs: Configuration parameters passed from the SAC framework. 137 Expects configuration under app_config. 138 """ 139 if "component_config" in kwargs and "app_config" in kwargs["component_config"]: 140 name = kwargs["component_config"]["app_config"].get("agent_name") 141 if name: 142 kwargs.setdefault("name", name) 143 144 super().__init__(info, **kwargs) 145 self.agent_name = self.get_config("agent_name") 146 log.info( 147 "%s Initializing agent: %s (A2A ADK Host Component)...", 148 self.log_identifier, 149 self.agent_name, 150 ) 151 152 # Initialize the agent registry for health tracking 153 self.agent_registry = AgentRegistry() 154 try: 155 self.namespace = self.get_config("namespace") 156 if not self.namespace: 157 raise ValueError("Internal Error: Namespace missing after validation.") 158 self.supports_streaming = self.get_config("supports_streaming", False) 159 self.stream_batching_threshold_bytes = self.get_config( 160 "stream_batching_threshold_bytes", 0 161 ) 162 self.agent_name = self.get_config("agent_name") 163 if not self.agent_name: 164 raise ValueError("Internal Error: Agent name missing after validation.") 165 self.model_config = self.get_config("model") 166 167 if not self._lazy_model_mode and not self.model_config: 168 raise ValueError( 169 "Internal Error: Model config missing after validation." 170 ) 171 if self._lazy_model_mode: 172 log.info( 173 "%s Lazy model mode enabled. Agent will start without model config.", 174 self.log_identifier, 175 ) 176 177 self.instruction_config = self.get_config("instruction", "") 178 self.global_instruction_config = self.get_config("global_instruction", "") 179 self.tools_config = self.get_config("tools", []) 180 self.planner_config = self.get_config("planner") 181 self.code_executor_config = self.get_config("code_executor") 182 self.session_service_config = self.get_config("session_service") 183 if not self.session_service_config: 184 raise ValueError( 185 "Internal Error: Session service config missing after validation." 186 ) 187 self.default_session_behavior = self.session_service_config.get( 188 "default_behavior", "PERSISTENT" 189 ).upper() 190 if self.default_session_behavior not in ["PERSISTENT", "RUN_BASED"]: 191 log.warning( 192 "%s Invalid 'default_behavior' in session_service_config: '%s'. Defaulting to PERSISTENT.", 193 self.log_identifier, 194 self.default_session_behavior, 195 ) 196 self.default_session_behavior = "PERSISTENT" 197 log.info( 198 "%s Default session behavior set to: %s", 199 self.log_identifier, 200 self.default_session_behavior, 201 ) 202 self.artifact_service_config = self.get_config( 203 "artifact_service", {"type": "memory"} 204 ) 205 self.memory_service_config = self.get_config( 206 "memory_service", {"type": "memory"} 207 ) 208 self.auto_summarization_config = self.get_config( 209 "auto_summarization", { 210 "enabled": True, 211 "compaction_percentage": 0.25 212 } 213 ) 214 self.artifact_handling_mode = self.get_config( 215 "artifact_handling_mode", "ignore" 216 ).lower() 217 if self.artifact_handling_mode not in ["ignore", "embed", "reference"]: 218 log.warning( 219 "%s Invalid artifact_handling_mode '%s'. Defaulting to 'ignore'.", 220 self.log_identifier, 221 self.artifact_handling_mode, 222 ) 223 self.artifact_handling_mode = "ignore" 224 log.info( 225 "%s Artifact Handling Mode: %s", 226 self.log_identifier, 227 self.artifact_handling_mode, 228 ) 229 self.enable_inline_vision = self.get_config( 230 "enable_inline_vision", False 231 ) 232 self.max_inline_vision_images = self.get_config( 233 "max_inline_vision_images", 5 234 ) 235 self.max_inline_vision_bytes = self.get_config( 236 "max_inline_vision_bytes", 20971520 # 20MB 237 ) 238 if self.enable_inline_vision: 239 log.info( 240 "%s Inline vision enabled: image files will be passed directly to the LLM " 241 "(max %d images, max %d bytes).", 242 self.log_identifier, 243 self.max_inline_vision_images, 244 self.max_inline_vision_bytes, 245 ) 246 if self.artifact_handling_mode == "reference": 247 log.warning( 248 "%s Artifact handling mode 'reference' selected, but this component does not currently host an endpoint to serve artifacts. Clients may not be able to retrieve referenced artifacts.", 249 self.log_identifier, 250 ) 251 self.agent_card_config = self.get_config("agent_card") 252 if not self.agent_card_config: 253 raise ValueError( 254 "Internal Error: Agent card config missing after validation." 255 ) 256 self.agent_card_publishing_config = self.get_config("agent_card_publishing") 257 if not self.agent_card_publishing_config: 258 raise ValueError( 259 "Internal Error: Agent card publishing config missing after validation." 260 ) 261 self.agent_discovery_config = self.get_config("agent_discovery") 262 if not self.agent_discovery_config: 263 raise ValueError( 264 "Internal Error: Agent discovery config missing after validation." 265 ) 266 self.inter_agent_communication_config = self.get_config( 267 "inter_agent_communication" 268 ) 269 if not self.inter_agent_communication_config: 270 raise ValueError( 271 "Internal Error: Inter-agent comms config missing after validation." 272 ) 273 274 self.max_message_size_bytes = self.get_config( 275 "max_message_size_bytes", 10_000_000 276 ) 277 278 except Exception as e: 279 log.error( 280 "%s Failed to retrieve configuration via get_config: %s", 281 self.log_identifier, 282 e, 283 ) 284 raise ValueError(f"Configuration retrieval error: {e}") from e 285 self.session_service: BaseSessionService = None 286 self.artifact_service: BaseArtifactService = None 287 self.memory_service: BaseMemoryService = None 288 self.credential_service: Optional[BaseCredentialService] = None 289 self.adk_agent: LlmAgent = None 290 self.runner: Runner = None 291 self.agent_card_tool_manifest: List[Dict[str, Any]] = [] 292 self.tool_scopes_map: Dict[str, List[str]] = {} # Maps tool names to required scopes 293 self.peer_agents: Dict[str, Any] = {} # Keep for backward compatibility 294 self._card_publish_timer_id: str = f"publish_card_{self.agent_name}" 295 self._async_init_future = None 296 self.peer_response_queues: Dict[str, asyncio.Queue] = {} 297 self.peer_response_queue_lock = threading.Lock() 298 self.agent_specific_state: Dict[str, Any] = {} 299 self.active_tasks: Dict[str, "TaskExecutionContext"] = {} 300 self.active_tasks_lock = threading.Lock() 301 self._tool_cleanup_hooks: List[Callable] = [] 302 self._agent_system_instruction_string: Optional[str] = None 303 self._agent_system_instruction_callback: Optional[ 304 Callable[[CallbackContext, LlmRequest], Optional[str]] 305 ] = None 306 self._active_background_tasks = set() 307 308 # Initialize session compaction state for parallel task coordination 309 # Agent-scoped to ensure isolation when multiple agents run in the same process 310 self.session_compaction_state = SessionCompactionState() 311 312 # Initialize structured invocation support 313 self.structured_invocation_handler = StructuredInvocationHandler(self) 314 315 try: 316 self.agent_specific_state: Dict[str, Any] = {} 317 init_func_details = self.get_config("agent_init_function") 318 319 try: 320 log.info( 321 "%s Initializing synchronous ADK services...", self.log_identifier 322 ) 323 self.session_service = initialize_session_service(self) 324 self.artifact_service = initialize_artifact_service(self) 325 self.memory_service = initialize_memory_service(self) 326 self.credential_service = initialize_credential_service(self) 327 328 log.info( 329 "%s Initialized Synchronous ADK services.", self.log_identifier 330 ) 331 except Exception as service_err: 332 log.exception( 333 "%s Failed to initialize synchronous ADK services: %s", 334 self.log_identifier, 335 service_err, 336 ) 337 raise RuntimeError( 338 f"Failed to initialize synchronous ADK services: {service_err}" 339 ) from service_err 340 341 # initialize enterprise features if available 342 try: 343 from solace_agent_mesh_enterprise.init_enterprise_component import ( 344 init_enterprise_component_features, 345 ) 346 347 init_enterprise_component_features(self) 348 except ImportError: 349 # Community edition 350 # Contact Solace support for enterprise features 351 pass 352 353 from .app import ( 354 AgentInitCleanupConfig, 355 ) # delayed import to avoid circular dependency 356 357 if init_func_details and isinstance( 358 init_func_details, AgentInitCleanupConfig 359 ): 360 module_name = init_func_details.get("module") 361 func_name = init_func_details.get("name") 362 base_path = init_func_details.get("base_path") 363 specific_init_params_dict = init_func_details.get("config", {}) 364 if module_name and func_name: 365 log.info( 366 "%s Attempting to load init_function: %s.%s", 367 self.log_identifier, 368 module_name, 369 func_name, 370 ) 371 try: 372 module = import_module(module_name, base_path=base_path) 373 init_function = getattr(module, func_name) 374 if not callable(init_function): 375 raise TypeError( 376 f"Init function '{func_name}' in module '{module_name}' is not callable." 377 ) 378 sig = inspect.signature(init_function) 379 pydantic_config_model = None 380 config_param_name = None 381 validated_config_arg = specific_init_params_dict 382 for param_name_sig, param_sig in sig.parameters.items(): 383 if ( 384 param_sig.annotation is not inspect.Parameter.empty 385 and isinstance(param_sig.annotation, type) 386 and issubclass(param_sig.annotation, BaseModel) 387 ): 388 pydantic_config_model = param_sig.annotation 389 config_param_name = param_name_sig 390 break 391 if pydantic_config_model and config_param_name: 392 log.info( 393 "%s Found Pydantic config model '%s' for init_function parameter '%s'.", 394 self.log_identifier, 395 pydantic_config_model.__name__, 396 config_param_name, 397 ) 398 try: 399 validated_config_arg = pydantic_config_model( 400 **specific_init_params_dict 401 ) 402 except ValidationError as ve: 403 log.error( 404 "%s Validation error for init_function config using Pydantic model '%s': %s", 405 self.log_identifier, 406 pydantic_config_model.__name__, 407 ve, 408 ) 409 raise ValueError( 410 f"Invalid configuration for init_function '{func_name}': {ve}" 411 ) from ve 412 elif ( 413 config_param_name 414 and param_sig.annotation is not inspect.Parameter.empty 415 ): 416 log.warning( 417 "%s Config parameter '%s' for init_function '%s' has a type hint '%s', but it's not a Pydantic BaseModel. Passing raw dict.", 418 self.log_identifier, 419 config_param_name, 420 func_name, 421 param_sig.annotation, 422 ) 423 else: 424 log.info( 425 "%s No Pydantic model type hint found for a config parameter of init_function '%s'. Passing raw dict if a config param exists, or only host_component.", 426 self.log_identifier, 427 func_name, 428 ) 429 func_params_list = list(sig.parameters.values()) 430 num_actual_params = len(func_params_list) 431 if num_actual_params == 1: 432 if specific_init_params_dict: 433 log.warning( 434 "%s Init function '%s' takes 1 argument, but 'config' was provided in YAML. Config will be ignored.", 435 self.log_identifier, 436 func_name, 437 ) 438 init_function(self) 439 elif num_actual_params == 2: 440 actual_config_param_name_in_signature = func_params_list[ 441 1 442 ].name 443 init_function( 444 self, 445 **{ 446 actual_config_param_name_in_signature: validated_config_arg 447 }, 448 ) 449 else: 450 raise TypeError( 451 f"Init function '{func_name}' has an unsupported signature. " 452 f"Expected (host_component_instance) or (host_component_instance, config_param), " 453 f"but got {num_actual_params} parameters." 454 ) 455 log.info( 456 "%s Successfully executed init_function: %s.%s", 457 self.log_identifier, 458 module_name, 459 func_name, 460 ) 461 except Exception as e: 462 log.exception( 463 "%s Fatal error during agent initialization via init_function '%s.%s': %s", 464 self.log_identifier, 465 module_name, 466 func_name, 467 e, 468 ) 469 raise RuntimeError( 470 f"Agent custom initialization failed: {e}" 471 ) from e 472 473 # Async init is now handled by the base class `run` method. 474 # We still need a future to signal completion from the async thread. 475 self._async_init_future = concurrent.futures.Future() 476 477 # Set up health check timer if enabled 478 health_check_interval_seconds = self.agent_discovery_config.get( 479 "health_check_interval_seconds", HEALTH_CHECK_INTERVAL_SECONDS 480 ) 481 if health_check_interval_seconds > 0: 482 log.info( 483 "%s Scheduling agent health check every %d seconds.", 484 self.log_identifier, 485 health_check_interval_seconds, 486 ) 487 self.add_timer( 488 delay_ms=health_check_interval_seconds * 1000, 489 timer_id=self.HEALTH_CHECK_TIMER_ID, 490 interval_ms=health_check_interval_seconds * 1000, 491 callback=lambda timer_data: self._check_agent_health(), 492 ) 493 else: 494 log.warning( 495 "%s Agent health check interval not configured or invalid, health checks will not run periodically.", 496 self.log_identifier, 497 ) 498 499 log.info( 500 "%s Initialized agent: %s", 501 self.log_identifier, 502 self.agent_name, 503 ) 504 except Exception as e: 505 log.exception("%s Initialization failed: %s", self.log_identifier, e) 506 raise 507 508 def _get_component_id(self) -> str: 509 """Returns the agent name as the component identifier.""" 510 return self.agent_name 511 512 def _get_component_type(self) -> str: 513 """Returns 'agent' as the component type.""" 514 return "agent" 515 516 def invoke(self, message: SolaceMessage, data: dict) -> dict: 517 """Placeholder invoke method. Primary logic resides in _handle_message.""" 518 log.warning( 519 "%s 'invoke' method called, but primary logic resides in '_handle_message'. This should not happen in normal operation.", 520 self.log_identifier, 521 ) 522 return None 523 524 async def _handle_message_async(self, message: SolaceMessage, topic: str) -> None: 525 """ 526 Async handler for incoming messages. 527 528 Routes the message to the async event handler. 529 530 Args: 531 message: The Solace message 532 topic: The topic the message was received on 533 """ 534 # Create event and process asynchronously 535 event = Event(EventType.MESSAGE, message) 536 await process_event(self, event) 537 538 def handle_timer_event(self, timer_data: Dict[str, Any]): 539 """Handles timer events for agent card publishing and health checks.""" 540 log.debug("%s Received timer event: %s", self.log_identifier, timer_data) 541 timer_id = timer_data.get("timer_id") 542 543 if timer_id == self._card_publish_timer_id: 544 publish_agent_card(self) 545 elif timer_id == self.HEALTH_CHECK_TIMER_ID: 546 self._check_agent_health() 547 548 async def handle_cache_expiry_event(self, cache_data: Dict[str, Any]): 549 """ 550 Handles cache expiry events for peer timeouts by calling the atomic claim helper. 551 """ 552 log.debug("%s Received cache expiry event: %s", self.log_identifier, cache_data) 553 sub_task_id = cache_data.get("key") 554 logical_task_id = cache_data.get("expired_data") 555 556 if not ( 557 sub_task_id 558 and sub_task_id.startswith(CORRELATION_DATA_PREFIX) 559 and logical_task_id 560 ): 561 log.debug( 562 "%s Cache expiry for key '%s' is not a peer sub-task timeout or is missing data.", 563 self.log_identifier, 564 sub_task_id, 565 ) 566 return 567 568 correlation_data = await self._claim_peer_sub_task_completion( 569 sub_task_id=sub_task_id, logical_task_id_from_event=logical_task_id 570 ) 571 572 if correlation_data: 573 log.warning( 574 "%s Detected timeout for sub-task %s (Main Task: %s). Claimed successfully.", 575 self.log_identifier, 576 sub_task_id, 577 logical_task_id, 578 ) 579 await self._handle_peer_timeout(sub_task_id, correlation_data) 580 else: 581 log.info( 582 "%s Ignoring timeout event for sub-task %s as it was already completed.", 583 self.log_identifier, 584 sub_task_id, 585 ) 586 587 async def get_main_task_context( 588 self, logical_task_id: str 589 ) -> Optional["TaskExecutionContext"]: 590 """ 591 Retrieves the main task context for a given logical task ID. 592 593 This method is used when the current agent is the target agent for the task. 594 It returns the TaskExecutionContext which contains the full task state including 595 a2a_context, active_peer_sub_tasks, and other task execution details. 596 597 Args: 598 logical_task_id: The unique logical ID of the task 599 600 Returns: 601 The TaskExecutionContext if the task is active, None otherwise 602 603 Raises: 604 ValueError: If logical_task_id is None or empty 605 """ 606 if not logical_task_id: 607 raise ValueError("logical_task_id cannot be None or empty") 608 609 with self.active_tasks_lock: 610 active_task_context = self.active_tasks.get(logical_task_id) 611 if active_task_context is None: 612 log.warning( 613 f"No active task context found for logical_task_id: {logical_task_id}" 614 ) 615 return None 616 617 return active_task_context 618 619 async def get_all_sub_task_correlation_data_from_logical_task_id( 620 self, logical_task_id: str 621 ) -> list[dict[str, Any]]: 622 """ 623 Retrieves correlation data for all active peer sub-tasks of a given logical task. 624 625 This method is used when forwarding requests to other agents in an A2A workflow. 626 It returns a list of correlation data dictionaries, each containing information 627 about a peer sub-task including peer_task_id, peer_agent_name, and original_task_context. 628 629 Args: 630 logical_task_id: The unique logical ID of the parent task 631 632 Returns: 633 List of correlation data dictionaries for active peer sub-tasks. 634 Returns empty list if no active peer sub-tasks exist. 635 636 Raises: 637 ValueError: If logical_task_id is None or empty 638 """ 639 if not logical_task_id: 640 raise ValueError("logical_task_id cannot be None or empty") 641 642 with self.active_tasks_lock: 643 active_task_context = self.active_tasks.get(logical_task_id) 644 if active_task_context is None: 645 log.warning( 646 f"No active task context found for logical_task_id: {logical_task_id}" 647 ) 648 return [] 649 650 active_peer_sub_tasks = active_task_context.active_peer_sub_tasks 651 if not active_peer_sub_tasks: 652 log.debug( 653 f"No active peer sub-tasks found for logical_task_id: {logical_task_id}" 654 ) 655 return [] 656 657 results = [] 658 for sub_task_id, correlation_data in active_peer_sub_tasks.items(): 659 if sub_task_id is not None and correlation_data is not None: 660 results.append(correlation_data) 661 662 return results 663 664 async def _get_correlation_data_for_sub_task( 665 self, sub_task_id: str 666 ) -> Optional[Dict[str, Any]]: 667 """ 668 Non-destructively retrieves correlation data for a sub-task. 669 Used for intermediate events where the sub-task should remain active. 670 """ 671 logical_task_id = self.cache_service.get_data(sub_task_id) 672 if not logical_task_id: 673 log.warning( 674 "%s No cache entry for sub-task %s. Cannot get correlation data.", 675 self.log_identifier, 676 sub_task_id, 677 ) 678 return None 679 680 with self.active_tasks_lock: 681 task_context = self.active_tasks.get(logical_task_id) 682 683 if not task_context: 684 log.error( 685 "%s TaskExecutionContext not found for task %s, but cache entry existed for sub-task %s. This may indicate a cleanup issue.", 686 self.log_identifier, 687 logical_task_id, 688 sub_task_id, 689 ) 690 return None 691 692 with task_context.lock: 693 return task_context.active_peer_sub_tasks.get(sub_task_id) 694 695 async def _claim_peer_sub_task_completion( 696 self, sub_task_id: str, logical_task_id_from_event: Optional[str] = None 697 ) -> Optional[Dict[str, Any]]: 698 """ 699 Atomically claims a sub-task as complete, preventing race conditions. 700 This is a destructive operation that removes state. 701 702 Args: 703 sub_task_id: The ID of the sub-task to claim. 704 logical_task_id_from_event: The parent task ID, if provided by the event (e.g., a timeout). 705 If not provided, it will be looked up from the cache. 706 """ 707 log_id = f"{self.log_identifier}[ClaimSubTask:{sub_task_id}]" 708 logical_task_id = logical_task_id_from_event 709 710 if not logical_task_id: 711 logical_task_id = self.cache_service.get_data(sub_task_id) 712 if not logical_task_id: 713 log.warning( 714 "%s No cache entry found. Task has likely timed out and been cleaned up. Cannot claim.", 715 log_id, 716 ) 717 return None 718 719 with self.active_tasks_lock: 720 task_context = self.active_tasks.get(logical_task_id) 721 722 if not task_context: 723 log.error( 724 "%s TaskExecutionContext not found for task %s. Cleaning up stale cache entry.", 725 log_id, 726 logical_task_id, 727 ) 728 self.cache_service.remove_data(sub_task_id) 729 return None 730 731 correlation_data = task_context.claim_sub_task_completion(sub_task_id) 732 733 if correlation_data: 734 # If we successfully claimed the task, remove the timeout tracker from the cache. 735 self.cache_service.remove_data(sub_task_id) 736 log.info("%s Successfully claimed completion.", log_id) 737 return correlation_data 738 else: 739 # This means the task was already claimed by a competing event (e.g., timeout vs. response). 740 log.warning("%s Failed to claim; it was already completed.", log_id) 741 return None 742 743 async def reset_peer_timeout(self, sub_task_id: str): 744 """ 745 Resets the timeout for a given peer sub-task. 746 """ 747 log_id = f"{self.log_identifier}[ResetTimeout:{sub_task_id}]" 748 log.debug("%s Resetting timeout for peer sub-task.", log_id) 749 750 # Get the original logical task ID from the cache without removing it 751 logical_task_id = self.cache_service.get_data(sub_task_id) 752 if not logical_task_id: 753 log.warning( 754 "%s No active task found for sub-task %s. Cannot reset timeout.", 755 log_id, 756 sub_task_id, 757 ) 758 return 759 760 # Get the configured timeout 761 timeout_sec = self.inter_agent_communication_config.get( 762 "request_timeout_seconds", DEFAULT_COMMUNICATION_TIMEOUT 763 ) 764 765 # Update the cache with a new expiry 766 self.cache_service.add_data( 767 key=sub_task_id, 768 value=logical_task_id, 769 expiry=timeout_sec, 770 component=self, 771 ) 772 log.info( 773 "%s Timeout for sub-task %s has been reset to %d seconds.", 774 log_id, 775 sub_task_id, 776 timeout_sec, 777 ) 778 779 async def _retrigger_agent_with_peer_responses( 780 self, 781 results_to_inject: list, 782 correlation_data: dict, 783 task_context: "TaskExecutionContext", 784 ): 785 """ 786 Injects peer tool responses into the session history and re-triggers the ADK runner. 787 This function contains the logic to correctly merge parallel tool call responses. 788 """ 789 original_task_context = correlation_data.get("original_task_context") 790 logical_task_id = correlation_data.get("logical_task_id") 791 paused_invocation_id = correlation_data.get("invocation_id") 792 log_retrigger = f"{self.log_identifier}[RetriggerManager:{logical_task_id}]" 793 794 # Clear paused state - task is resuming now 795 task_context.set_paused(False) 796 log.debug( 797 "%s Task %s resuming from paused state with peer responses.", 798 log_retrigger, 799 logical_task_id, 800 ) 801 802 try: 803 effective_session_id = original_task_context.get("effective_session_id") 804 user_id = original_task_context.get("user_id") 805 session = await self.session_service.get_session( 806 app_name=self.agent_name, 807 user_id=user_id, 808 session_id=effective_session_id, 809 ) 810 if not session: 811 raise RuntimeError( 812 f"Could not find ADK session '{effective_session_id}'" 813 ) 814 815 new_response_parts = [] 816 for result in results_to_inject: 817 part = adk_types.Part.from_function_response( 818 name=result["peer_tool_name"], 819 response=result["payload"], 820 ) 821 part.function_response.id = result["adk_function_call_id"] 822 new_response_parts.append(part) 823 824 # Always create a new event for the incoming peer responses. 825 # The ADK's `contents` processor is responsible for merging multiple 826 # tool responses into a single message before the next LLM call. 827 log.info( 828 "%s Creating a new tool response event for %d peer responses.", 829 log_retrigger, 830 len(new_response_parts), 831 ) 832 new_tool_response_content = adk_types.Content( 833 role="tool", parts=new_response_parts 834 ) 835 836 # Always use SSE streaming mode for the ADK runner, even on re-trigger. 837 # This ensures that real-time callbacks for status updates and artifact 838 # creation can function correctly for all turns of a task. 839 streaming_mode = StreamingMode.SSE 840 max_llm_calls = self.get_config("max_llm_calls_per_task", 20) 841 run_config = RunConfig( 842 streaming_mode=streaming_mode, max_llm_calls=max_llm_calls 843 ) 844 845 log.info( 846 "%s Re-triggering ADK runner for main task %s.", 847 log_retrigger, 848 logical_task_id, 849 ) 850 try: 851 await run_adk_async_task_thread_wrapper( 852 self, 853 session, 854 new_tool_response_content, 855 run_config, 856 original_task_context, 857 append_context_event=False, 858 ) 859 finally: 860 log.info( 861 "%s Cleaning up parallel invocation state for invocation %s.", 862 log_retrigger, 863 paused_invocation_id, 864 ) 865 task_context.clear_parallel_invocation_state(paused_invocation_id) 866 867 except Exception as e: 868 log.exception( 869 "%s Failed to re-trigger ADK runner for task %s: %s", 870 log_retrigger, 871 logical_task_id, 872 e, 873 ) 874 if original_task_context: 875 loop = self.get_async_loop() 876 if loop and loop.is_running(): 877 # For structured invocation tasks, route the error through 878 # the SI handler so it can send a proper structured error 879 # result and clean up. 880 if task_context.get_flag("structured_invocation"): 881 asyncio.run_coroutine_threadsafe( 882 self.structured_invocation_handler.finalize_deferred_structured_invocation( 883 task_context, original_task_context, e 884 ), 885 loop, 886 ) 887 else: 888 asyncio.run_coroutine_threadsafe( 889 self.finalize_task_error(e, original_task_context), loop 890 ) 891 else: 892 log.error( 893 "%s Async loop not available. Cannot schedule error finalization for task %s.", 894 log_retrigger, 895 logical_task_id, 896 ) 897 898 async def _handle_peer_timeout( 899 self, 900 sub_task_id: str, 901 correlation_data: Dict[str, Any], 902 ): 903 """ 904 Handles the timeout of a peer agent task. It sends a cancellation request 905 to the peer, updates the local completion counter, and potentially 906 re-triggers the runner if all parallel tasks are now complete. 907 """ 908 logical_task_id = correlation_data.get("logical_task_id") 909 invocation_id = correlation_data.get("invocation_id") 910 log_retrigger = f"{self.log_identifier}[RetriggerManager:{logical_task_id}]" 911 912 log.warning( 913 "%s Peer request timed out for sub-task: %s (Invocation: %s)", 914 log_retrigger, 915 sub_task_id, 916 invocation_id, 917 ) 918 919 # Proactively send a cancellation request to the peer agent. 920 peer_agent_name = correlation_data.get("peer_agent_name") 921 if peer_agent_name: 922 try: 923 log.info( 924 "%s Sending CancelTaskRequest to peer '%s' for timed-out sub-task %s.", 925 log_retrigger, 926 peer_agent_name, 927 sub_task_id, 928 ) 929 task_id_for_peer = sub_task_id.replace(CORRELATION_DATA_PREFIX, "", 1) 930 cancel_request = a2a.create_cancel_task_request( 931 task_id=task_id_for_peer 932 ) 933 user_props = {"clientId": self.agent_name} 934 peer_topic = self._get_agent_request_topic(peer_agent_name) 935 self.publish_a2a_message( 936 payload=cancel_request.model_dump(exclude_none=True), 937 topic=peer_topic, 938 user_properties=user_props, 939 ) 940 except Exception as e: 941 log.error( 942 "%s Failed to send CancelTaskRequest to peer '%s' for sub-task %s: %s", 943 log_retrigger, 944 peer_agent_name, 945 sub_task_id, 946 e, 947 ) 948 949 # Process the timeout locally. 950 with self.active_tasks_lock: 951 task_context = self.active_tasks.get(logical_task_id) 952 953 if not task_context: 954 log.warning( 955 "%s TaskExecutionContext not found for task %s. Ignoring timeout event.", 956 log_retrigger, 957 logical_task_id, 958 ) 959 return 960 961 timeout_value = self.inter_agent_communication_config.get( 962 "request_timeout_seconds", DEFAULT_COMMUNICATION_TIMEOUT 963 ) 964 all_sub_tasks_completed = task_context.handle_peer_timeout( 965 sub_task_id, correlation_data, timeout_value, invocation_id 966 ) 967 968 if not all_sub_tasks_completed: 969 log.info( 970 "%s Waiting for more peer responses for invocation %s after timeout of sub-task %s.", 971 log_retrigger, 972 invocation_id, 973 sub_task_id, 974 ) 975 return 976 977 log.info( 978 "%s All peer responses/timeouts received for invocation %s. Retriggering agent.", 979 log_retrigger, 980 invocation_id, 981 ) 982 results_to_inject = task_context.parallel_tool_calls[invocation_id].get( 983 "results", [] 984 ) 985 986 await self._retrigger_agent_with_peer_responses( 987 results_to_inject, correlation_data, task_context 988 ) 989 990 def _inject_peer_tools_callback( 991 self, callback_context: CallbackContext, llm_request: LlmRequest 992 ) -> Optional[LlmResponse]: 993 """ 994 ADK before_model_callback to dynamically add PeerAgentTools to the LLM request 995 and generate the corresponding instruction text for the LLM. 996 """ 997 log.debug("%s Running _inject_peer_tools_callback...", self.log_identifier) 998 if not self.peer_agents: 999 log.debug("%s No peer agents currently discovered.", self.log_identifier) 1000 return None 1001 1002 a2a_context = callback_context.state.get("a2a_context", {}) 1003 user_config = ( 1004 a2a_context.get("a2a_user_config", {}) 1005 if isinstance(a2a_context, dict) 1006 else {} 1007 ) 1008 1009 inter_agent_config = self.get_config("inter_agent_communication", {}) 1010 allow_list = inter_agent_config.get("allow_list", ["*"]) 1011 deny_list = set(self.get_config("deny_list", [])) 1012 self_name = self.get_config("agent_name") 1013 1014 peer_tools_to_add = [] 1015 allowed_peer_descriptions = [] 1016 1017 # Sort peer agents alphabetically to ensure consistent tool ordering for prompt caching 1018 for peer_name, agent_card in sorted(self.peer_agents.items()): 1019 if not isinstance(agent_card, AgentCard) or peer_name == self_name: 1020 continue 1021 1022 is_allowed = any( 1023 fnmatch.fnmatch(peer_name, p) for p in allow_list 1024 ) and not any(fnmatch.fnmatch(peer_name, p) for p in deny_list) 1025 1026 if is_allowed: 1027 config_resolver = MiddlewareRegistry.get_config_resolver() 1028 operation_spec = { 1029 "operation_type": "peer_delegation", 1030 "target_agent": peer_name, 1031 "delegation_context": "peer_discovery", 1032 } 1033 validation_context = { 1034 "discovery_phase": "peer_enumeration", 1035 "agent_context": {"component_type": "peer_discovery"}, 1036 } 1037 validation_result = config_resolver.validate_operation_config( 1038 user_config, operation_spec, validation_context 1039 ) 1040 if not validation_result.get("valid", True): 1041 log.debug( 1042 "%s Peer agent '%s' filtered out by user configuration.", 1043 self.log_identifier, 1044 peer_name, 1045 ) 1046 is_allowed = False 1047 1048 if not is_allowed: 1049 continue 1050 1051 try: 1052 # Determine agent type and schemas 1053 agent_type = "standard" 1054 input_schema = None 1055 1056 if agent_card.capabilities and agent_card.capabilities.extensions: 1057 for ext in agent_card.capabilities.extensions: 1058 if ext.uri == EXTENSION_URI_AGENT_TYPE: 1059 agent_type = ext.params.get("type", "standard") 1060 elif ext.uri == EXTENSION_URI_SCHEMAS: 1061 input_schema = ext.params.get("input_schema") 1062 1063 tool_instance = None 1064 tool_description_line = "" 1065 1066 if agent_type == "workflow": 1067 # Default schema if none provided 1068 if not input_schema: 1069 input_schema = { 1070 "type": "object", 1071 "properties": {"text": {"type": "string"}}, 1072 "required": ["text"], 1073 } 1074 1075 tool_instance = WorkflowAgentTool( 1076 target_agent_name=peer_name, 1077 input_schema=input_schema, 1078 host_component=self, 1079 ) 1080 1081 desc = ( 1082 getattr(agent_card, "description", "No description") 1083 or "No description" 1084 ) 1085 tool_description_line = f"- `{tool_instance.name}`: {desc}" 1086 1087 else: 1088 # Standard Peer Agent 1089 tool_instance = PeerAgentTool( 1090 target_agent_name=peer_name, host_component=self 1091 ) 1092 # Get enhanced description from the tool instance 1093 # which includes capabilities, skills, and tools 1094 enhanced_desc = tool_instance._build_enhanced_description( 1095 agent_card 1096 ) 1097 tool_description_line = f"\n### `peer_{peer_name}`\n{enhanced_desc}" 1098 1099 if tool_instance.name not in llm_request.tools_dict: 1100 peer_tools_to_add.append(tool_instance) 1101 allowed_peer_descriptions.append(tool_description_line) 1102 1103 except Exception as e: 1104 log.error( 1105 "%s Failed to create tool for '%s': %s", 1106 self.log_identifier, 1107 peer_name, 1108 e, 1109 ) 1110 1111 if allowed_peer_descriptions: 1112 peer_list_str = "\n".join(allowed_peer_descriptions) 1113 instruction_text = ( 1114 "## Peer Agent and Workflow Delegation\n\n" 1115 "You can delegate tasks to other specialized agents or workflows if they are better suited.\n\n" 1116 "**How to delegate to peer agents:**\n" 1117 "- Use the `peer_<agent_name>(task_description: str)` tool for delegation\n" 1118 "- Replace `<agent_name>` with the actual name of the target agent\n" 1119 "- Provide a clear and detailed `task_description` for the peer agent\n" 1120 "- **Important:** The peer agent does not have access to your session history, " 1121 "so you must provide all required context necessary to fulfill the request\n\n" 1122 "**How to delegate to workflows:**\n" 1123 "- Use the `workflow_<agent_name>` tool for workflow delegation\n" 1124 "- Follow the specific parameter requirements defined in the tool schema\n" 1125 "- Workflows also do not have access to your session history\n\n" 1126 "IMPORTANT: When a peer agent's response contains citation markers like [[cite:search0]], [[cite:file1]], etc., " 1127 "you MUST preserve these markers in your response to the user. These markers link to source references and are " 1128 "essential for proper attribution. Include them exactly as they appear in the peer's response. DO NOT repeat them without markers.\n\n" 1129 "## Available Peer Agents and Workflows\n" 1130 f"{peer_list_str}" 1131 ) 1132 callback_context.state["peer_tool_instructions"] = instruction_text 1133 log.debug( 1134 "%s Stored peer tool instructions in callback_context.state.", 1135 self.log_identifier, 1136 ) 1137 1138 if peer_tools_to_add: 1139 try: 1140 if llm_request.config.tools is None: 1141 llm_request.config.tools = [] 1142 if len(llm_request.config.tools) > 0: 1143 for tool in peer_tools_to_add: 1144 llm_request.tools_dict[tool.name] = tool 1145 declaration = tool._get_declaration() 1146 llm_request.config.tools[0].function_declarations.append( 1147 declaration 1148 ) 1149 else: 1150 llm_request.append_tools(peer_tools_to_add) 1151 log.debug( 1152 "%s Dynamically added %d PeerAgentTool(s) to LLM request.", 1153 self.log_identifier, 1154 len(peer_tools_to_add), 1155 ) 1156 except Exception as e: 1157 log.error( 1158 "%s Failed to append dynamic peer tools to LLM request: %s", 1159 self.log_identifier, 1160 e, 1161 exc_info=True, 1162 ) 1163 return None 1164 1165 @staticmethod 1166 def _remove_tool(llm_request: LlmRequest, tool_name: str) -> bool: 1167 """Remove a tool's FunctionDeclaration from config.tools so the LLM 1168 does not see it for this request. 1169 1170 The tool is intentionally kept in tools_dict so the ADK runtime can 1171 still dispatch to it if the LLM calls it from conversation history. 1172 1173 Returns True if a declaration was found and removed, False otherwise. 1174 """ 1175 removed = False 1176 1177 if llm_request.config and llm_request.config.tools: 1178 for tool_obj in llm_request.config.tools: 1179 if tool_obj.function_declarations: 1180 original_count = len(tool_obj.function_declarations) 1181 tool_obj.function_declarations = [ 1182 fd for fd in tool_obj.function_declarations 1183 if fd.name != tool_name 1184 ] 1185 if len(tool_obj.function_declarations) < original_count: 1186 removed = True 1187 # ADK requires Tool objects to have at least one declaration 1188 llm_request.config.tools = [ 1189 t for t in llm_request.config.tools 1190 if t.function_declarations 1191 ] 1192 if not llm_request.config.tools: 1193 llm_request.config.tools = None 1194 1195 return removed 1196 1197 @staticmethod 1198 def _has_declaration(llm_request: LlmRequest, tool_name: str) -> bool: 1199 """Check whether a FunctionDeclaration with the given name exists in config.tools.""" 1200 if llm_request.config and llm_request.config.tools: 1201 for tool_obj in llm_request.config.tools: 1202 if tool_obj.function_declarations: 1203 if any(fd.name == tool_name for fd in tool_obj.function_declarations): 1204 return True 1205 return False 1206 1207 def _ensure_tool_in_tools_dict( 1208 self, llm_request: LlmRequest, tool_name: str 1209 ) -> bool: 1210 """Ensure a tool exists in tools_dict for ADK dispatch safety. 1211 1212 tools_dict is rebuilt from agent.tools on every LLM call, so 1213 dynamically-injected tools from a previous call won't be present. 1214 This method ensures the tool is always dispatchable — preventing 1215 ValueError from the ADK runtime if the LLM calls it from 1216 conversation history after it has been hidden. 1217 1218 If the tool is already in tools_dict (e.g., from YAML static 1219 loading), it is left untouched to preserve any tool_config. 1220 1221 Returns True if the tool is (now) in tools_dict. 1222 """ 1223 if tool_name in llm_request.tools_dict: 1224 return True 1225 1226 tool_def = tool_registry.get_tool_by_name(tool_name) 1227 if not tool_def: 1228 return False 1229 1230 try: 1231 tool_callable = ADKToolWrapper( 1232 tool_def.implementation, 1233 None, 1234 tool_def.name, 1235 origin="builtin", 1236 raw_string_args=tool_def.raw_string_args, 1237 artifact_args=tool_def.artifact_args, 1238 ) 1239 tool_callable.__doc__ = tool_def.description 1240 function_tool = FunctionTool(tool_callable) 1241 function_tool.origin = "builtin" 1242 llm_request.tools_dict[tool_def.name] = function_tool 1243 return True 1244 except Exception as e: 1245 log.error( 1246 "%s Failed to create FunctionTool for %s: %s", 1247 self.log_identifier, 1248 tool_name, 1249 e, 1250 exc_info=True, 1251 ) 1252 return False 1253 1254 def _sync_tools_callback( 1255 self, callback_context: CallbackContext, llm_request: LlmRequest 1256 ) -> Optional[LlmResponse]: 1257 """Sync the LLM tool list with the current request context. 1258 1259 Ensures tools that require specific context (e.g., a project index) 1260 are only present when that context is available. Adds missing tools 1261 when prerequisites are met, removes them when they are not. 1262 1263 The tool is always kept in tools_dict so the ADK runtime can dispatch 1264 it if the LLM calls it from conversation history. Visibility to the 1265 LLM is controlled exclusively via config.tools declarations. 1266 """ 1267 log.debug("%s Running _sync_tools_callback...", self.log_identifier) 1268 1269 # Currently, the gateway only passes project_id when indexing is enabled 1270 # AND the BM25 index exists and only index_search is added or removed 1271 a2a_context = callback_context.state.get("a2a_context", {}) 1272 if not isinstance(a2a_context, dict): 1273 a2a_context = {} 1274 1275 original_metadata = a2a_context.get("original_message_metadata", {}) 1276 if not isinstance(original_metadata, dict): 1277 original_metadata = {} 1278 1279 has_project_id = bool(original_metadata.get("project_id")) 1280 1281 # Always ensure tools_dict has index_search so the ADK runtime can 1282 # dispatch if the LLM calls it from conversation history. tools_dict 1283 # is rebuilt from agent.tools each request, so dynamically-injected 1284 # tools from a previous request won't be present. 1285 self._ensure_tool_in_tools_dict(llm_request, "index_search") 1286 1287 if not has_project_id: 1288 # No project — remove declaration so LLM doesn't see the tool 1289 if SamAgentComponent._remove_tool(llm_request, "index_search"): 1290 log.info( 1291 "%s Removed index_search declaration (no project or index available).", 1292 self.log_identifier, 1293 ) 1294 # Set stale instructions to None for any previous project context so the 1295 # LLM doesn't receive index_search related prompts. 1296 callback_context.state["project_tool_instructions"] = None 1297 return None 1298 1299 # Project is present — ensure the declaration exists 1300 if SamAgentComponent._has_declaration(llm_request, "index_search"): 1301 log.debug( 1302 "%s index_search already declared, skipping injection.", 1303 self.log_identifier, 1304 ) 1305 return None 1306 1307 tool_def = tool_registry.get_tool_by_name("index_search") 1308 if not tool_def: 1309 log.debug( 1310 "%s index_search not found in tool registry, skipping.", 1311 self.log_identifier, 1312 ) 1313 return None 1314 1315 try: 1316 declaration = adk_types.FunctionDeclaration( 1317 name=tool_def.name, 1318 description=tool_def.description, 1319 parameters=tool_def.parameters, 1320 ) 1321 1322 if not llm_request.config.tools: 1323 llm_request.config.tools = [ 1324 adk_types.Tool(function_declarations=[]) 1325 ] 1326 llm_request.config.tools[0].function_declarations.append(declaration) 1327 1328 instructions = _generate_tool_instructions_from_registry( 1329 [tool_def], self.log_identifier 1330 ) 1331 if instructions: 1332 callback_context.state["project_tool_instructions"] = instructions 1333 1334 log.debug( 1335 "%s Dynamically injected index_search declaration.", 1336 self.log_identifier, 1337 ) 1338 except Exception as e: 1339 log.error( 1340 "%s Failed to inject index_search: %s", 1341 self.log_identifier, 1342 e, 1343 exc_info=True, 1344 ) 1345 1346 return None 1347 1348 def _filter_tools_by_capability_callback( 1349 self, callback_context: CallbackContext, llm_request: LlmRequest 1350 ) -> Optional[LlmResponse]: 1351 """ 1352 ADK before_model_callback to filter tools in the LlmRequest based on user configuration. 1353 This callback modifies `llm_request.config.tools` in place by potentially 1354 removing individual FunctionDeclarations from genai.Tool objects or removing 1355 entire genai.Tool objects if all their declarations are filtered out. 1356 """ 1357 log_id_prefix = f"{self.log_identifier}[ToolCapabilityFilter]" 1358 log.debug("%s Running _filter_tools_by_capability_callback...", log_id_prefix) 1359 1360 a2a_context = callback_context.state.get("a2a_context", {}) 1361 if not isinstance(a2a_context, dict): 1362 log.warning( 1363 "%s 'a2a_context' in session state is not a dictionary. Using empty configuration.", 1364 log_id_prefix, 1365 ) 1366 a2a_context = {} 1367 user_config = a2a_context.get("a2a_user_config", {}) 1368 if not isinstance(user_config, dict): 1369 log.warning( 1370 "%s 'a2a_user_config' in a2a_context is not a dictionary. Using empty configuration.", 1371 log_id_prefix, 1372 ) 1373 user_config = {} 1374 1375 log.debug( 1376 "%s User configuration for filtering: %s", 1377 log_id_prefix, 1378 {k: v for k, v in user_config.items() if not k.startswith("_")}, 1379 ) 1380 1381 config_resolver = MiddlewareRegistry.get_config_resolver() 1382 1383 if not llm_request.config or not llm_request.config.tools: 1384 log.debug("%s No tools in request to filter.", log_id_prefix) 1385 return None 1386 1387 explicit_tools_config = self.get_config("tools", []) 1388 final_filtered_genai_tools: List[adk_types.Tool] = [] 1389 original_genai_tools_count = len(llm_request.config.tools) 1390 original_function_declarations_count = 0 1391 1392 for original_tool in llm_request.config.tools: 1393 if not original_tool.function_declarations: 1394 log.warning( 1395 "%s genai.Tool object has no function declarations. Keeping it.", 1396 log_id_prefix, 1397 ) 1398 final_filtered_genai_tools.append(original_tool) 1399 continue 1400 1401 original_function_declarations_count += len( 1402 original_tool.function_declarations 1403 ) 1404 permitted_declarations_for_this_tool: List[ 1405 adk_types.FunctionDeclaration 1406 ] = [] 1407 1408 for func_decl in original_tool.function_declarations: 1409 func_decl_name = func_decl.name 1410 tool_object = llm_request.tools_dict.get(func_decl_name) 1411 origin = SamAgentComponent._extract_tool_origin(tool_object) 1412 1413 feature_descriptor = { 1414 "feature_type": "tool_function", 1415 "function_name": func_decl_name, 1416 "tool_source": origin, 1417 "tool_metadata": {"function_name": func_decl_name}, 1418 } 1419 1420 if origin == "peer_agent": 1421 peer_name = func_decl_name.replace(PEER_TOOL_PREFIX, "", 1) 1422 feature_descriptor["tool_metadata"]["peer_agent_name"] = peer_name 1423 elif origin == "builtin": 1424 tool_def = tool_registry.get_tool_by_name(func_decl_name) 1425 if tool_def: 1426 feature_descriptor["tool_metadata"][ 1427 "tool_category" 1428 ] = tool_def.category 1429 feature_descriptor["tool_metadata"][ 1430 "required_scopes" 1431 ] = tool_def.required_scopes 1432 elif origin in ["python", "mcp", "adk_builtin"]: 1433 # Find the explicit config for this tool to pass to the resolver 1434 for tool_cfg in explicit_tools_config: 1435 cfg_tool_type = tool_cfg.get("tool_type") 1436 cfg_tool_name = tool_cfg.get("tool_name") 1437 cfg_func_name = tool_cfg.get("function_name") 1438 if ( 1439 cfg_tool_type == "python" 1440 and cfg_func_name == func_decl_name 1441 ) or ( 1442 cfg_tool_type in ["builtin", "mcp"] 1443 and cfg_tool_name == func_decl_name 1444 ): 1445 feature_descriptor["tool_metadata"][ 1446 "tool_config" 1447 ] = tool_cfg 1448 break 1449 1450 context = { 1451 "agent_context": self.get_agent_context(), 1452 "filter_phase": "pre_llm", 1453 "tool_configurations": { 1454 "explicit_tools": explicit_tools_config, 1455 }, 1456 } 1457 1458 if config_resolver.is_feature_enabled( 1459 user_config, feature_descriptor, context 1460 ): 1461 permitted_declarations_for_this_tool.append(func_decl) 1462 log.debug( 1463 "%s FunctionDeclaration '%s' (Source: %s) permitted.", 1464 log_id_prefix, 1465 func_decl_name, 1466 origin, 1467 ) 1468 else: 1469 log.info( 1470 "%s FunctionDeclaration '%s' (Source: %s) FILTERED OUT due to configuration restrictions.", 1471 log_id_prefix, 1472 func_decl_name, 1473 origin, 1474 ) 1475 1476 if permitted_declarations_for_this_tool: 1477 scoped_tool = original_tool.model_copy(deep=True) 1478 scoped_tool.function_declarations = permitted_declarations_for_this_tool 1479 1480 final_filtered_genai_tools.append(scoped_tool) 1481 log.debug( 1482 "%s Keeping genai.Tool as it has %d permitted FunctionDeclaration(s).", 1483 log_id_prefix, 1484 len(permitted_declarations_for_this_tool), 1485 ) 1486 else: 1487 log.info( 1488 "%s Entire genai.Tool (original declarations: %s) FILTERED OUT as all its FunctionDeclarations were denied by configuration.", 1489 log_id_prefix, 1490 [fd.name for fd in original_tool.function_declarations], 1491 ) 1492 1493 final_function_declarations_count = sum( 1494 len(t.function_declarations) 1495 for t in final_filtered_genai_tools 1496 if t.function_declarations 1497 ) 1498 1499 if final_function_declarations_count != original_function_declarations_count: 1500 log.info( 1501 "%s Tool list modified by capability filter. Original genai.Tools: %d (Total Declarations: %d). Filtered genai.Tools: %d (Total Declarations: %d).", 1502 log_id_prefix, 1503 original_genai_tools_count, 1504 original_function_declarations_count, 1505 len(final_filtered_genai_tools), 1506 final_function_declarations_count, 1507 ) 1508 llm_request.config.tools = ( 1509 final_filtered_genai_tools if final_filtered_genai_tools else None 1510 ) 1511 else: 1512 log.debug( 1513 "%s Tool list and FunctionDeclarations unchanged after capability filtering.", 1514 log_id_prefix, 1515 ) 1516 1517 return None 1518 1519 @staticmethod 1520 def _extract_tool_origin(tool) -> str: 1521 """ 1522 Helper method to extract the origin of a tool from various possible attributes. 1523 """ 1524 if hasattr(tool, "origin") and tool.origin is not None: 1525 return tool.origin 1526 elif ( 1527 hasattr(tool, "func") 1528 and hasattr(tool.func, "origin") 1529 and tool.func.origin is not None 1530 ): 1531 return tool.func.origin 1532 else: 1533 return getattr(tool, "origin", "unknown") 1534 1535 def get_agent_context(self) -> Dict[str, Any]: 1536 """Get agent context for middleware calls.""" 1537 return { 1538 "agent_name": getattr(self, "agent_name", "unknown"), 1539 "component_type": "sac_agent", 1540 } 1541 1542 def _inject_gateway_instructions_callback( 1543 self, callback_context: CallbackContext, llm_request: LlmRequest 1544 ) -> Optional[LlmResponse]: 1545 """ 1546 ADK before_model_callback to dynamically prepend gateway-defined system_purpose 1547 and response_format to the agent's llm_request.config.system_instruction. 1548 """ 1549 log_id_prefix = f"{self.log_identifier}[GatewayInstrInject]" 1550 log.debug( 1551 "%s Running _inject_gateway_instructions_callback to modify system_instruction...", 1552 log_id_prefix, 1553 ) 1554 1555 a2a_context = callback_context.state.get("a2a_context", {}) 1556 if not isinstance(a2a_context, dict): 1557 log.warning( 1558 "%s 'a2a_context' in session state is not a dictionary. Skipping instruction injection.", 1559 log_id_prefix, 1560 ) 1561 return None 1562 1563 system_purpose = a2a_context.get("system_purpose") 1564 response_format = a2a_context.get("response_format") 1565 user_profile = a2a_context.get("a2a_user_config", {}).get("user_profile") 1566 1567 inject_purpose = self.get_config("inject_system_purpose", False) 1568 inject_format = self.get_config("inject_response_format", False) 1569 inject_user_profile = self.get_config("inject_user_profile", False) 1570 1571 gateway_instructions_to_add = [] 1572 1573 if ( 1574 inject_purpose 1575 and system_purpose 1576 and isinstance(system_purpose, str) 1577 and system_purpose.strip() 1578 ): 1579 gateway_instructions_to_add.append( 1580 f"System Purpose:\n{system_purpose.strip()}" 1581 ) 1582 log.debug( 1583 "%s Prepared system_purpose for system_instruction.", log_id_prefix 1584 ) 1585 1586 if ( 1587 inject_format 1588 and response_format 1589 and isinstance(response_format, str) 1590 and response_format.strip() 1591 ): 1592 gateway_instructions_to_add.append( 1593 f"Desired Response Format:\n{response_format.strip()}" 1594 ) 1595 log.debug( 1596 "%s Prepared response_format for system_instruction.", log_id_prefix 1597 ) 1598 1599 if ( 1600 inject_user_profile 1601 and user_profile 1602 and (isinstance(user_profile, str) or isinstance(user_profile, dict)) 1603 ): 1604 if isinstance(user_profile, dict): 1605 user_profile = json.dumps(user_profile, indent=2, default=str) 1606 gateway_instructions_to_add.append( 1607 f"Inquiring User Profile:\n{user_profile.strip()}\n" 1608 ) 1609 log.debug("%s Prepared user_profile for system_instruction.", log_id_prefix) 1610 1611 if not gateway_instructions_to_add: 1612 log.debug( 1613 "%s No gateway instructions to inject into system_instruction.", 1614 log_id_prefix, 1615 ) 1616 return None 1617 1618 if llm_request.config is None: 1619 log.warning( 1620 "%s llm_request.config is None, cannot append gateway instructions to system_instruction.", 1621 log_id_prefix, 1622 ) 1623 return None 1624 1625 if llm_request.config.system_instruction is None: 1626 llm_request.config.system_instruction = "" 1627 1628 combined_new_instructions = "\n\n".join(gateway_instructions_to_add) 1629 1630 if llm_request.config.system_instruction: 1631 llm_request.config.system_instruction += ( 1632 f"\n\n---\n\n{combined_new_instructions}" 1633 ) 1634 else: 1635 llm_request.config.system_instruction = combined_new_instructions 1636 1637 log.info( 1638 "%s Injected %d gateway instruction block(s) into llm_request.config.system_instruction.", 1639 log_id_prefix, 1640 len(gateway_instructions_to_add), 1641 ) 1642 1643 return None 1644 1645 async def _publish_text_as_partial_a2a_status_update( 1646 self, 1647 text_content: str, 1648 a2a_context: Dict, 1649 is_stream_terminating_content: bool = False, 1650 ): 1651 """ 1652 Constructs and publishes a TaskStatusUpdateEvent for the given text. 1653 The 'final' flag is determined by is_stream_terminating_content. 1654 This method skips buffer flushing since it's used for LLM streaming text. 1655 """ 1656 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 1657 log_identifier_helper = ( 1658 f"{self.log_identifier}[PublishPartialText:{logical_task_id}]" 1659 ) 1660 1661 if not text_content: 1662 log.debug( 1663 "%s No text content to publish as update (final=%s).", 1664 log_identifier_helper, 1665 is_stream_terminating_content, 1666 ) 1667 return 1668 1669 try: 1670 a2a_message = a2a.create_agent_text_message( 1671 text=text_content, 1672 task_id=logical_task_id, 1673 context_id=a2a_context.get("contextId"), 1674 ) 1675 event_metadata = {"agent_name": self.agent_name} 1676 status_update_event = a2a.create_status_update( 1677 task_id=logical_task_id, 1678 context_id=a2a_context.get("contextId"), 1679 message=a2a_message, 1680 is_final=is_stream_terminating_content, 1681 metadata=event_metadata, 1682 ) 1683 1684 await self._publish_status_update_with_buffer_flush( 1685 status_update_event, 1686 a2a_context, 1687 skip_buffer_flush=True, 1688 ) 1689 1690 log.debug( 1691 "%s Published LLM streaming text (length: %d bytes, final: %s).", 1692 log_identifier_helper, 1693 len(text_content.encode("utf-8")), 1694 is_stream_terminating_content, 1695 ) 1696 1697 except Exception as e: 1698 log.exception( 1699 "%s Error in _publish_text_as_partial_a2a_status_update: %s", 1700 log_identifier_helper, 1701 e, 1702 ) 1703 1704 async def _publish_agent_status_signal_update( 1705 self, status_text: str, a2a_context: Dict 1706 ): 1707 """ 1708 Constructs and publishes a TaskStatusUpdateEvent specifically for agent_status_message signals. 1709 This method will flush the buffer before publishing to maintain proper message ordering. 1710 """ 1711 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 1712 log_identifier_helper = ( 1713 f"{self.log_identifier}[PublishAgentSignal:{logical_task_id}]" 1714 ) 1715 1716 if not status_text: 1717 log.debug( 1718 "%s No text content for agent status signal.", log_identifier_helper 1719 ) 1720 return 1721 1722 try: 1723 progress_data = AgentProgressUpdateData(status_text=status_text) 1724 status_update_event = a2a.create_data_signal_event( 1725 task_id=logical_task_id, 1726 context_id=a2a_context.get("contextId"), 1727 signal_data=progress_data, 1728 agent_name=self.agent_name, 1729 part_metadata={"source_embed_type": "status_update"}, 1730 ) 1731 1732 await self._publish_status_update_with_buffer_flush( 1733 status_update_event, 1734 a2a_context, 1735 skip_buffer_flush=False, 1736 ) 1737 1738 log.debug( 1739 "%s Published agent_status_message signal ('%s').", 1740 log_identifier_helper, 1741 status_text, 1742 ) 1743 1744 except Exception as e: 1745 log.exception( 1746 "%s Error in _publish_agent_status_signal_update: %s", 1747 log_identifier_helper, 1748 e, 1749 ) 1750 1751 async def _flush_buffer_if_needed( 1752 self, a2a_context: Dict, reason: str = "status_update" 1753 ) -> bool: 1754 """ 1755 Flushes streaming buffer if it contains content. 1756 1757 Args: 1758 a2a_context: The A2A context dictionary for the current task 1759 reason: The reason for flushing (for logging purposes) 1760 1761 Returns: 1762 bool: True if buffer was flushed, False if no content to flush 1763 """ 1764 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 1765 log_identifier = f"{self.log_identifier}[BufferFlush:{logical_task_id}]" 1766 1767 with self.active_tasks_lock: 1768 task_context = self.active_tasks.get(logical_task_id) 1769 1770 if not task_context: 1771 log.warning( 1772 "%s TaskExecutionContext not found for task %s. Cannot flush buffer.", 1773 log_identifier, 1774 logical_task_id, 1775 ) 1776 return False 1777 1778 buffer_content = task_context.get_streaming_buffer_content() 1779 if not buffer_content: 1780 log.debug( 1781 "%s No buffer content to flush (reason: %s).", 1782 log_identifier, 1783 reason, 1784 ) 1785 return False 1786 1787 buffer_size = len(buffer_content.encode("utf-8")) 1788 log.info( 1789 "%s Flushing buffer content (size: %d bytes, reason: %s).", 1790 log_identifier, 1791 buffer_size, 1792 reason, 1793 ) 1794 1795 try: 1796 resolved_text, unprocessed_tail = await self._flush_and_resolve_buffer( 1797 a2a_context, is_final=False 1798 ) 1799 1800 if resolved_text: 1801 is_run_based = a2a_context.get("is_run_based_session", False) 1802 if is_run_based: 1803 with self.active_tasks_lock: 1804 tc = self.active_tasks.get(logical_task_id) 1805 if tc: 1806 tc.append_to_run_based_buffer(resolved_text) 1807 else: 1808 await self._publish_text_as_partial_a2a_status_update( 1809 resolved_text, 1810 a2a_context, 1811 is_stream_terminating_content=False, 1812 ) 1813 log.debug( 1814 "%s Successfully flushed and published buffer content (resolved: %d bytes).", 1815 log_identifier, 1816 len(resolved_text.encode("utf-8")), 1817 ) 1818 return True 1819 else: 1820 log.debug( 1821 "%s Buffer flush completed but no resolved text to publish.", 1822 log_identifier, 1823 ) 1824 return False 1825 1826 except Exception as e: 1827 log.exception( 1828 "%s Error during buffer flush (reason: %s): %s", 1829 log_identifier, 1830 reason, 1831 e, 1832 ) 1833 return False 1834 1835 async def notify_artifact_saved( 1836 self, 1837 artifact_info: ArtifactInfo, 1838 a2a_context: Dict[str, Any], 1839 function_call_id: Optional[str] = None, 1840 ) -> None: 1841 """ 1842 Publishes an artifact saved notification signal. 1843 1844 This is a separate event from ArtifactCreationProgressData and does not 1845 follow the start->updates->end protocol. It's a single notification that 1846 an artifact has been successfully saved to storage. 1847 1848 Args: 1849 artifact_info: Information about the saved artifact 1850 a2a_context: The A2A context dictionary for the current task 1851 function_call_id: Optional function call ID if artifact was created by a tool 1852 """ 1853 log_identifier = ( 1854 f"{self.log_identifier}[ArtifactSaved:{artifact_info.filename}]" 1855 ) 1856 1857 try: 1858 # Create artifact saved signal 1859 artifact_signal = ArtifactSavedData( 1860 type="artifact_saved", 1861 filename=artifact_info.filename, 1862 version=artifact_info.version, 1863 mime_type=artifact_info.mime_type or "application/octet-stream", 1864 size_bytes=artifact_info.size, 1865 description=artifact_info.description, 1866 function_call_id=function_call_id, 1867 ) 1868 1869 # Create and publish status update event 1870 logical_task_id = a2a_context.get("logical_task_id") 1871 context_id = a2a_context.get("contextId") 1872 1873 status_update_event = a2a.create_data_signal_event( 1874 task_id=logical_task_id, 1875 context_id=context_id, 1876 signal_data=artifact_signal, 1877 agent_name=self.agent_name, 1878 ) 1879 1880 await self._publish_status_update_with_buffer_flush( 1881 status_update_event, 1882 a2a_context, 1883 skip_buffer_flush=False, 1884 ) 1885 1886 log.debug( 1887 "%s Published artifact saved notification for '%s' v%s.", 1888 log_identifier, 1889 artifact_info.filename, 1890 artifact_info.version, 1891 ) 1892 except Exception as e: 1893 log.error( 1894 "%s Failed to publish artifact saved notification: %s", 1895 log_identifier, 1896 e, 1897 ) 1898 1899 async def _publish_status_update_with_buffer_flush( 1900 self, 1901 status_update_event: TaskStatusUpdateEvent, 1902 a2a_context: Dict, 1903 skip_buffer_flush: bool = False, 1904 ) -> None: 1905 """ 1906 Central method for publishing status updates with automatic buffer flushing. 1907 1908 Args: 1909 status_update_event: The status update event to publish 1910 a2a_context: The A2A context dictionary for the current task 1911 skip_buffer_flush: If True, skip buffer flushing (used for LLM streaming text) 1912 """ 1913 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 1914 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 1915 log_identifier = f"{self.log_identifier}[StatusUpdate:{logical_task_id}]" 1916 1917 status_type = "unknown" 1918 if status_update_event.metadata: 1919 if status_update_event.metadata.get("type") == "tool_invocation_start": 1920 status_type = "tool_invocation_start" 1921 elif "agent_name" in status_update_event.metadata: 1922 status_type = "agent_status" 1923 1924 if ( 1925 status_update_event.status 1926 and status_update_event.status.message 1927 and status_update_event.status.message.parts 1928 ): 1929 for part in status_update_event.status.message.parts: 1930 if hasattr(part, "data") and part.data: 1931 if part.data.get("a2a_signal_type") == "agent_status_message": 1932 status_type = "agent_status_signal" 1933 break 1934 elif "tool_error" in part.data: 1935 status_type = "tool_failure" 1936 break 1937 1938 log.debug( 1939 "%s Publishing status update (type: %s, skip_buffer_flush: %s).", 1940 log_identifier, 1941 status_type, 1942 skip_buffer_flush, 1943 ) 1944 1945 if not skip_buffer_flush: 1946 buffer_was_flushed = await self._flush_buffer_if_needed( 1947 a2a_context, reason=f"before_{status_type}_status" 1948 ) 1949 if buffer_was_flushed: 1950 log.info( 1951 "%s Buffer flushed before %s status update.", 1952 log_identifier, 1953 status_type, 1954 ) 1955 1956 try: 1957 rpc_response = a2a.create_success_response( 1958 result=status_update_event, request_id=jsonrpc_request_id 1959 ) 1960 payload_to_publish = rpc_response.model_dump(exclude_none=True) 1961 1962 target_topic = a2a_context.get( 1963 "statusTopic" 1964 ) or a2a.get_gateway_status_topic( 1965 self.namespace, self.get_gateway_id(), logical_task_id 1966 ) 1967 1968 # Construct user_properties to ensure ownership can be determined by gateways 1969 user_properties = { 1970 "a2aUserConfig": a2a_context.get("a2a_user_config"), 1971 "clientId": a2a_context.get("client_id"), 1972 "delegating_agent_name": self.get_config("agent_name"), 1973 } 1974 1975 self._publish_a2a_event( 1976 payload_to_publish, target_topic, a2a_context, user_properties 1977 ) 1978 1979 log.debug( 1980 "%s Published %s status update to %s.", 1981 log_identifier, 1982 status_type, 1983 target_topic, 1984 ) 1985 1986 except Exception as e: 1987 log.exception( 1988 "%s Error publishing %s status update: %s", 1989 log_identifier, 1990 status_type, 1991 e, 1992 ) 1993 raise 1994 1995 async def _filter_text_from_final_streaming_event( 1996 self, adk_event: ADKEvent, a2a_context: Dict 1997 ) -> ADKEvent: 1998 """ 1999 Filters out text parts from the final ADKEvent of a turn for PERSISTENT streaming sessions. 2000 This prevents sending redundant, aggregated text that was already streamed. 2001 Non-text parts like function calls are preserved. 2002 """ 2003 is_run_based_session = a2a_context.get("is_run_based_session", False) 2004 is_streaming = a2a_context.get("is_streaming", False) 2005 is_final_turn_event = not adk_event.partial 2006 has_content_parts = adk_event.content and adk_event.content.parts 2007 2008 # Only filter for PERSISTENT (not run-based) streaming sessions. 2009 if ( 2010 not is_run_based_session 2011 and is_streaming 2012 and is_final_turn_event 2013 and has_content_parts 2014 ): 2015 log_id = f"{self.log_identifier}[FilterFinalStreamEvent:{a2a_context.get('logical_task_id', 'unknown')}]" 2016 log.debug( 2017 "%s Filtering final streaming event to remove redundant text.", log_id 2018 ) 2019 2020 non_text_parts = [ 2021 part for part in adk_event.content.parts if part.text is None 2022 ] 2023 2024 if len(non_text_parts) < len(adk_event.content.parts): 2025 event_copy = adk_event.model_copy(deep=True) 2026 event_copy.content = ( 2027 adk_types.Content(parts=non_text_parts) if non_text_parts else None 2028 ) 2029 log.info( 2030 "%s Removed text from final streaming event. Kept %d non-text part(s).", 2031 log_id, 2032 len(non_text_parts), 2033 ) 2034 return event_copy 2035 2036 return adk_event 2037 2038 async def process_and_publish_adk_event( 2039 self, adk_event: ADKEvent, a2a_context: Dict 2040 ): 2041 """ 2042 Main orchestrator for processing ADK events. 2043 Handles text buffering, embed resolution, and event routing based on 2044 whether the event is partial or the final event of a turn. 2045 """ 2046 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 2047 log_id_main = ( 2048 f"{self.log_identifier}[ProcessADKEvent:{logical_task_id}:{adk_event.id}]" 2049 ) 2050 log.debug( 2051 "%s Received ADKEvent (Partial: %s, Final Turn: %s).", 2052 log_id_main, 2053 adk_event.partial, 2054 not adk_event.partial, 2055 ) 2056 2057 if adk_event.content and adk_event.content.parts: 2058 if any( 2059 p.function_response 2060 and p.function_response.name == "_continue_generation" 2061 for p in adk_event.content.parts 2062 ): 2063 log.debug( 2064 "%s Discarding _continue_generation tool response event.", 2065 log_id_main, 2066 ) 2067 return 2068 2069 if adk_event.custom_metadata and adk_event.custom_metadata.get( 2070 "was_interrupted" 2071 ): 2072 log.debug( 2073 "%s Found 'was_interrupted' signal. Skipping event.", 2074 log_id_main, 2075 ) 2076 return 2077 2078 with self.active_tasks_lock: 2079 task_context = self.active_tasks.get(logical_task_id) 2080 2081 if not task_context: 2082 log.error( 2083 "%s TaskExecutionContext not found for task %s. Cannot process ADK event.", 2084 log_id_main, 2085 logical_task_id, 2086 ) 2087 return 2088 2089 is_run_based_session = a2a_context.get("is_run_based_session", False) 2090 is_final_turn_event = not adk_event.partial 2091 2092 try: 2093 from solace_agent_mesh_enterprise.auth.tool_auth import ( 2094 handle_tool_auth_event, 2095 ) 2096 2097 auth_status_update = await handle_tool_auth_event( 2098 adk_event, self, a2a_context 2099 ) 2100 if auth_status_update: 2101 await self._publish_status_update_with_buffer_flush( 2102 auth_status_update, 2103 a2a_context, 2104 skip_buffer_flush=False, 2105 ) 2106 return 2107 except ImportError: 2108 pass 2109 2110 if not is_final_turn_event: 2111 if adk_event.content and adk_event.content.parts: 2112 for part in adk_event.content.parts: 2113 if part.text is not None: 2114 # Check if this is a new turn by comparing invocation_id 2115 if adk_event.invocation_id: 2116 task_context.check_and_update_invocation( 2117 adk_event.invocation_id 2118 ) 2119 is_first_text = task_context.is_first_text_in_turn() 2120 should_add_spacing = task_context.should_add_turn_spacing() 2121 2122 # Add spacing if this is the first text of a new turn 2123 # We add it BEFORE the text, regardless of current buffer content 2124 if should_add_spacing and is_first_text: 2125 # Add double newline to separate turns (new paragraph) 2126 task_context.append_to_streaming_buffer("\n\n") 2127 log.debug( 2128 "%s Added turn spacing before new invocation %s", 2129 log_id_main, 2130 adk_event.invocation_id, 2131 ) 2132 2133 task_context.append_to_streaming_buffer(part.text) 2134 log.debug( 2135 "%s Appended text to buffer. New buffer size: %d bytes", 2136 log_id_main, 2137 len( 2138 task_context.get_streaming_buffer_content().encode( 2139 "utf-8" 2140 ) 2141 ), 2142 ) 2143 2144 buffer_content = task_context.get_streaming_buffer_content() 2145 batching_disabled = self.stream_batching_threshold_bytes <= 0 2146 buffer_has_content = bool(buffer_content) 2147 threshold_met = ( 2148 buffer_has_content 2149 and not batching_disabled 2150 and ( 2151 len(buffer_content.encode("utf-8")) 2152 >= self.stream_batching_threshold_bytes 2153 ) 2154 ) 2155 2156 if buffer_has_content and (batching_disabled or threshold_met): 2157 log.debug( 2158 "%s Partial event triggered buffer flush due to size/batching config.", 2159 log_id_main, 2160 ) 2161 resolved_text, _ = await self._flush_and_resolve_buffer( 2162 a2a_context, is_final=False 2163 ) 2164 2165 if resolved_text: 2166 if is_run_based_session: 2167 task_context.append_to_run_based_buffer(resolved_text) 2168 log.debug( 2169 "%s [RUN_BASED] Appended %d bytes to run_based_response_buffer.", 2170 log_id_main, 2171 len(resolved_text.encode("utf-8")), 2172 ) 2173 else: 2174 await self._publish_text_as_partial_a2a_status_update( 2175 resolved_text, a2a_context 2176 ) 2177 else: 2178 buffer_content = task_context.get_streaming_buffer_content() 2179 if buffer_content: 2180 log.debug( 2181 "%s Final event triggered flush of remaining buffer content.", 2182 log_id_main, 2183 ) 2184 resolved_text, _ = await self._flush_and_resolve_buffer( 2185 a2a_context, is_final=True 2186 ) 2187 if resolved_text: 2188 if is_run_based_session: 2189 task_context.append_to_run_based_buffer(resolved_text) 2190 log.debug( 2191 "%s [RUN_BASED] Appended final %d bytes to run_based_response_buffer.", 2192 log_id_main, 2193 len(resolved_text.encode("utf-8")), 2194 ) 2195 else: 2196 await self._publish_text_as_partial_a2a_status_update( 2197 resolved_text, a2a_context 2198 ) 2199 2200 # Prepare and publish the final event for observability 2201 event_to_publish = await self._filter_text_from_final_streaming_event( 2202 adk_event, a2a_context 2203 ) 2204 2205 ( 2206 a2a_payload, 2207 target_topic, 2208 user_properties, 2209 _, 2210 ) = await format_and_route_adk_event(event_to_publish, a2a_context, self) 2211 2212 if a2a_payload and target_topic: 2213 self._publish_a2a_event(a2a_payload, target_topic, a2a_context) 2214 log.debug( 2215 "%s Published final turn event (e.g., tool call) to %s.", 2216 log_id_main, 2217 target_topic, 2218 ) 2219 else: 2220 log.debug( 2221 "%s Final turn event did not result in a publishable A2A message.", 2222 log_id_main, 2223 ) 2224 2225 await self._handle_artifact_return_signals(adk_event, a2a_context) 2226 2227 async def _flush_and_resolve_buffer( 2228 self, a2a_context: Dict, is_final: bool 2229 ) -> Tuple[str, str]: 2230 """Flushes buffer, resolves embeds, handles signals, returns (resolved_text, unprocessed_tail).""" 2231 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 2232 log_id = f"{self.log_identifier}[FlushBuffer:{logical_task_id}]" 2233 2234 with self.active_tasks_lock: 2235 task_context = self.active_tasks.get(logical_task_id) 2236 2237 if not task_context: 2238 log.error( 2239 "%s TaskExecutionContext not found for task %s. Cannot flush/resolve buffer.", 2240 log_id, 2241 logical_task_id, 2242 ) 2243 return "", "" 2244 2245 text_to_process = task_context.flush_streaming_buffer() 2246 2247 resolved_text, signals_found, unprocessed_tail = ( 2248 await self._resolve_early_embeds_and_handle_signals( 2249 text_to_process, a2a_context 2250 ) 2251 ) 2252 2253 if not is_final: 2254 if unprocessed_tail: 2255 task_context.append_to_streaming_buffer(unprocessed_tail) 2256 log.debug( 2257 "%s Placed unprocessed tail (length %d) back into buffer.", 2258 log_id, 2259 len(unprocessed_tail.encode("utf-8")), 2260 ) 2261 else: 2262 if unprocessed_tail is not None and unprocessed_tail != "": 2263 resolved_text = resolved_text + unprocessed_tail 2264 2265 if signals_found: 2266 log.info( 2267 "%s Handling %d signals from buffer resolution.", 2268 log_id, 2269 len(signals_found), 2270 ) 2271 for _signal_index, signal_data_tuple, _placeholder in signals_found: 2272 if ( 2273 isinstance(signal_data_tuple, tuple) 2274 and len(signal_data_tuple) == 3 2275 and signal_data_tuple[0] is None 2276 and signal_data_tuple[1] == "SIGNAL_STATUS_UPDATE" 2277 ): 2278 status_text = signal_data_tuple[2] 2279 log.info( 2280 "%s Publishing SIGNAL_STATUS_UPDATE from buffer: '%s'", 2281 log_id, 2282 status_text, 2283 ) 2284 await self._publish_agent_status_signal_update( 2285 status_text, a2a_context 2286 ) 2287 resolved_text = resolved_text.replace(_placeholder, "") 2288 2289 return resolved_text, unprocessed_tail 2290 2291 async def _handle_artifact_return_signals( 2292 self, adk_event: ADKEvent, a2a_context: Dict 2293 ): 2294 """ 2295 Processes artifact return signals. 2296 This method is triggered by a placeholder in state_delta, but reads the 2297 actual list of signals from the TaskExecutionContext. 2298 """ 2299 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 2300 log_id = f"{self.log_identifier}[ArtifactSignals:{logical_task_id}]" 2301 2302 # Check for the trigger in state_delta. The presence of any key is enough. 2303 has_signal_trigger = ( 2304 adk_event.actions 2305 and adk_event.actions.state_delta 2306 and any( 2307 k.startswith("temp:a2a_return_artifact:") 2308 for k in adk_event.actions.state_delta 2309 ) 2310 ) 2311 2312 if not has_signal_trigger: 2313 return 2314 2315 with self.active_tasks_lock: 2316 task_context = self.active_tasks.get(logical_task_id) 2317 2318 if not task_context: 2319 log.warning( 2320 "%s No TaskExecutionContext found for task %s. Cannot process artifact signals.", 2321 log_id, 2322 logical_task_id, 2323 ) 2324 return 2325 2326 all_signals = task_context.get_and_clear_artifact_signals() 2327 2328 if not all_signals: 2329 log.info( 2330 "%s Triggered for artifact signals, but none were found in the execution context.", 2331 log_id, 2332 ) 2333 return 2334 2335 log.info( 2336 "%s Found %d artifact return signal(s) in the execution context.", 2337 log_id, 2338 len(all_signals), 2339 ) 2340 2341 original_session_id = a2a_context.get("session_id") 2342 user_id = a2a_context.get("user_id") 2343 adk_app_name = self.get_config("agent_name") 2344 2345 peer_status_topic = a2a_context.get("statusTopic") 2346 namespace = self.get_config("namespace") 2347 gateway_id = self.get_gateway_id() 2348 2349 artifact_topic = peer_status_topic or a2a.get_gateway_status_topic( 2350 namespace, gateway_id, logical_task_id 2351 ) 2352 2353 if not self.artifact_service: 2354 log.error("%s Artifact service not available.", log_id) 2355 return 2356 if not artifact_topic: 2357 log.error("%s Could not determine artifact topic.", log_id) 2358 return 2359 2360 for item in all_signals: 2361 try: 2362 filename = item["filename"] 2363 version = item["version"] 2364 2365 log.info( 2366 "%s Processing artifact return signal for '%s' v%d from context.", 2367 log_id, 2368 filename, 2369 version, 2370 ) 2371 2372 loaded_adk_part = await self.artifact_service.load_artifact( 2373 app_name=adk_app_name, 2374 user_id=user_id, 2375 session_id=original_session_id, 2376 filename=filename, 2377 version=version, 2378 ) 2379 2380 if not loaded_adk_part: 2381 log.warning( 2382 "%s Failed to load artifact '%s' v%d.", 2383 log_id, 2384 filename, 2385 version, 2386 ) 2387 continue 2388 2389 a2a_file_part = await a2a.translate_adk_part_to_a2a_filepart( 2390 adk_part=loaded_adk_part, 2391 filename=filename, 2392 a2a_context=a2a_context, 2393 artifact_service=self.artifact_service, 2394 artifact_handling_mode=self.artifact_handling_mode, 2395 adk_app_name=self.get_config("agent_name"), 2396 log_identifier=self.log_identifier, 2397 version=version, 2398 ) 2399 2400 if a2a_file_part: 2401 a2a_message = a2a.create_agent_parts_message( 2402 parts=[a2a_file_part], 2403 task_id=logical_task_id, 2404 context_id=original_session_id, 2405 ) 2406 task_status = a2a.create_task_status( 2407 state=TaskState.working, message=a2a_message 2408 ) 2409 status_update_event = TaskStatusUpdateEvent( 2410 task_id=logical_task_id, 2411 context_id=original_session_id, 2412 status=task_status, 2413 final=False, 2414 kind="status-update", 2415 ) 2416 artifact_payload = a2a.create_success_response( 2417 result=status_update_event, 2418 request_id=a2a_context.get("jsonrpc_request_id"), 2419 ).model_dump(exclude_none=True) 2420 2421 self._publish_a2a_event( 2422 artifact_payload, artifact_topic, a2a_context 2423 ) 2424 2425 log.info( 2426 "%s Published TaskStatusUpdateEvent with FilePart for '%s' to %s", 2427 log_id, 2428 filename, 2429 artifact_topic, 2430 ) 2431 else: 2432 log.warning( 2433 "%s Failed to translate artifact '%s' v%d to A2A FilePart.", 2434 log_id, 2435 filename, 2436 version, 2437 ) 2438 2439 except Exception as e: 2440 log.exception( 2441 "%s Error processing artifact signal item %s from context: %s", 2442 log_id, 2443 item, 2444 e, 2445 ) 2446 2447 def _format_final_task_status( 2448 self, last_event: Optional[ADKEvent], override_text: Optional[str] = None 2449 ) -> TaskStatus: 2450 """Helper to format the final TaskStatus based on the last ADK event.""" 2451 log.debug( 2452 "%s Formatting final task status from last ADK event %s", 2453 self.log_identifier, 2454 last_event.id if last_event else "None", 2455 ) 2456 a2a_state = TaskState.completed 2457 a2a_parts = [] 2458 2459 if override_text is not None: 2460 a2a_parts.append(a2a.create_text_part(text=override_text)) 2461 # Add non-text parts from the last event 2462 if last_event and last_event.content and last_event.content.parts: 2463 for part in last_event.content.parts: 2464 if part.text is None: 2465 if part.function_response: 2466 a2a_parts.extend( 2467 a2a.translate_adk_function_response_to_a2a_parts(part) 2468 ) 2469 else: 2470 # Original logic 2471 if last_event and last_event.content and last_event.content.parts: 2472 for part in last_event.content.parts: 2473 if part.text: 2474 a2a_parts.append(a2a.create_text_part(text=part.text)) 2475 elif part.function_response: 2476 a2a_parts.extend( 2477 a2a.translate_adk_function_response_to_a2a_parts(part) 2478 ) 2479 2480 if last_event and last_event.actions: 2481 if last_event.actions.requested_auth_configs: 2482 a2a_state = TaskState.input_required 2483 a2a_parts.append( 2484 a2a.create_text_part(text="[Agent requires input/authentication]") 2485 ) 2486 2487 if not a2a_parts: 2488 a2a_message = a2a.create_agent_text_message(text="") 2489 else: 2490 a2a_message = a2a.create_agent_parts_message(parts=a2a_parts) 2491 return a2a.create_task_status(state=a2a_state, message=a2a_message) 2492 2493 async def finalize_task_success(self, a2a_context: Dict): 2494 """ 2495 Finalizes a task successfully. Fetches final state, publishes final A2A response, 2496 and ACKs the original message. 2497 For RUN_BASED tasks, it uses the aggregated response buffer. 2498 For STREAMING tasks, it uses the content of the last ADK event. 2499 """ 2500 logical_task_id = a2a_context.get("logical_task_id") 2501 2502 # Retrieve the original Solace message from TaskExecutionContext 2503 original_message: Optional[SolaceMessage] = None 2504 with self.active_tasks_lock: 2505 task_context = self.active_tasks.get(logical_task_id) 2506 if task_context: 2507 original_message = task_context.get_original_solace_message() 2508 2509 log.info( 2510 "%s Finalizing task %s successfully.", self.log_identifier, logical_task_id 2511 ) 2512 try: 2513 session_id_to_retrieve = a2a_context.get( 2514 "effective_session_id", a2a_context.get("session_id") 2515 ) 2516 original_session_id = a2a_context.get("session_id") 2517 user_id = a2a_context.get("user_id") 2518 client_id = a2a_context.get("client_id") 2519 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 2520 peer_reply_topic = a2a_context.get("replyToTopic") 2521 namespace = self.get_config("namespace") 2522 agent_name = self.get_config("agent_name") 2523 is_run_based_session = a2a_context.get("is_run_based_session", False) 2524 2525 final_status: TaskStatus 2526 2527 with self.active_tasks_lock: 2528 task_context = self.active_tasks.get(logical_task_id) 2529 2530 final_adk_session = await self.session_service.get_session( 2531 app_name=agent_name, 2532 user_id=user_id, 2533 session_id=session_id_to_retrieve, 2534 ) 2535 if not final_adk_session: 2536 raise RuntimeError( 2537 f"Could not retrieve final session state for {session_id_to_retrieve}" 2538 ) 2539 2540 last_event = ( 2541 final_adk_session.events[-1] if final_adk_session.events else None 2542 ) 2543 2544 if is_run_based_session: 2545 aggregated_text = "" 2546 if task_context: 2547 aggregated_text = task_context.run_based_response_buffer 2548 log.info( 2549 "%s Using aggregated response buffer for RUN_BASED task %s (length: %d bytes).", 2550 self.log_identifier, 2551 logical_task_id, 2552 len(aggregated_text.encode("utf-8")), 2553 ) 2554 final_status = self._format_final_task_status( 2555 last_event, override_text=aggregated_text 2556 ) 2557 else: 2558 if last_event: 2559 final_status = self._format_final_task_status(last_event) 2560 else: 2561 final_status = a2a.create_task_status( 2562 state=TaskState.completed, 2563 message=a2a.create_agent_text_message(text="Task completed."), 2564 ) 2565 2566 final_a2a_artifacts: List[A2AArtifact] = [] 2567 log.debug( 2568 "%s Final artifact bundling is removed. Artifacts sent via TaskArtifactUpdateEvent.", 2569 self.log_identifier, 2570 ) 2571 2572 final_task_metadata = {"agent_name": agent_name} 2573 if task_context and task_context.produced_artifacts: 2574 final_task_metadata["produced_artifacts"] = ( 2575 task_context.produced_artifacts 2576 ) 2577 log.info( 2578 "%s Attaching manifest of %d produced artifacts to final task metadata.", 2579 self.log_identifier, 2580 len(task_context.produced_artifacts), 2581 ) 2582 else: 2583 if not task_context: 2584 log.warning( 2585 "%s TaskExecutionContext not found for task %s during finalization, cannot attach produced artifacts.", 2586 self.log_identifier, 2587 logical_task_id, 2588 ) 2589 else: 2590 log.debug( 2591 "%s No produced artifacts to attach for task %s.", 2592 self.log_identifier, 2593 logical_task_id, 2594 ) 2595 2596 # Add token usage summary 2597 if task_context: 2598 token_summary = task_context.get_token_usage_summary() 2599 if token_summary["total_tokens"] > 0: 2600 final_task_metadata["token_usage"] = token_summary 2601 log.info( 2602 "%s Task %s used %d total tokens (input: %d, output: %d, cached: %d)", 2603 self.log_identifier, 2604 logical_task_id, 2605 token_summary["total_tokens"], 2606 token_summary["total_input_tokens"], 2607 token_summary["total_output_tokens"], 2608 token_summary["total_cached_input_tokens"], 2609 ) 2610 2611 final_task = a2a.create_final_task( 2612 task_id=logical_task_id, 2613 context_id=original_session_id, 2614 final_status=final_status, 2615 artifacts=(final_a2a_artifacts if final_a2a_artifacts else None), 2616 metadata=final_task_metadata, 2617 ) 2618 final_response = a2a.create_success_response( 2619 result=final_task, request_id=jsonrpc_request_id 2620 ) 2621 a2a_payload = final_response.model_dump(exclude_none=True) 2622 target_topic = peer_reply_topic or a2a.get_client_response_topic( 2623 namespace, client_id 2624 ) 2625 2626 self._publish_a2a_event(a2a_payload, target_topic, a2a_context) 2627 log.info( 2628 "%s Published final successful response for task %s to %s (Artifacts NOT bundled).", 2629 self.log_identifier, 2630 logical_task_id, 2631 target_topic, 2632 ) 2633 if original_message: 2634 try: 2635 original_message.call_acknowledgements() 2636 log.info( 2637 "%s Called ACK for original message of task %s.", 2638 self.log_identifier, 2639 logical_task_id, 2640 ) 2641 except Exception as ack_e: 2642 log.error( 2643 "%s Failed to call ACK for task %s: %s", 2644 self.log_identifier, 2645 logical_task_id, 2646 ack_e, 2647 ) 2648 else: 2649 log.warning( 2650 "%s Original Solace message not found in context for task %s. Cannot ACK.", 2651 self.log_identifier, 2652 logical_task_id, 2653 ) 2654 2655 except Exception as e: 2656 log.exception( 2657 "%s Error during successful finalization of task %s: %s", 2658 self.log_identifier, 2659 logical_task_id, 2660 e, 2661 ) 2662 if original_message: 2663 try: 2664 original_message.call_negative_acknowledgements() 2665 log.warning( 2666 "%s Called NACK for original message of task %s due to finalization error.", 2667 self.log_identifier, 2668 logical_task_id, 2669 ) 2670 except Exception as nack_e: 2671 log.error( 2672 "%s Failed to call NACK for task %s after finalization error: %s", 2673 self.log_identifier, 2674 logical_task_id, 2675 nack_e, 2676 ) 2677 else: 2678 log.warning( 2679 "%s Original Solace message not found in context for task %s during finalization error. Cannot NACK.", 2680 self.log_identifier, 2681 logical_task_id, 2682 ) 2683 2684 try: 2685 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 2686 client_id = a2a_context.get("client_id") 2687 peer_reply_topic = a2a_context.get("replyToTopic") 2688 namespace = self.get_config("namespace") 2689 error_response = a2a.create_internal_error_response( 2690 message=f"Failed to finalize successful task: {e}", 2691 request_id=jsonrpc_request_id, 2692 data={"taskId": logical_task_id}, 2693 ) 2694 target_topic = peer_reply_topic or a2a.get_client_response_topic( 2695 namespace, client_id 2696 ) 2697 self.publish_a2a_message( 2698 error_response.model_dump(exclude_none=True), target_topic 2699 ) 2700 except Exception as report_err: 2701 log.error( 2702 "%s Failed to report finalization error for task %s: %s", 2703 self.log_identifier, 2704 logical_task_id, 2705 report_err, 2706 ) 2707 2708 def finalize_task_canceled(self, a2a_context: Dict): 2709 """ 2710 Finalizes a task as CANCELED. Publishes A2A Task response with CANCELED state 2711 and ACKs the original message if available. 2712 Called by the background ADK thread wrapper when a task is cancelled. 2713 """ 2714 logical_task_id = a2a_context.get("logical_task_id") 2715 2716 # Retrieve the original Solace message from TaskExecutionContext 2717 original_message: Optional[SolaceMessage] = None 2718 with self.active_tasks_lock: 2719 task_context = self.active_tasks.get(logical_task_id) 2720 if task_context: 2721 original_message = task_context.get_original_solace_message() 2722 2723 log.info( 2724 "%s Finalizing task %s as CANCELED.", self.log_identifier, logical_task_id 2725 ) 2726 try: 2727 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 2728 client_id = a2a_context.get("client_id") 2729 peer_reply_topic = a2a_context.get("replyToTopic") 2730 namespace = self.get_config("namespace") 2731 2732 canceled_status = a2a.create_task_status( 2733 state=TaskState.canceled, 2734 message=a2a.create_agent_text_message( 2735 text="Task cancelled by request." 2736 ), 2737 ) 2738 agent_name = self.get_config("agent_name") 2739 final_task = a2a.create_final_task( 2740 task_id=logical_task_id, 2741 context_id=a2a_context.get("contextId"), 2742 final_status=canceled_status, 2743 metadata={"agent_name": agent_name}, 2744 ) 2745 final_response = a2a.create_success_response( 2746 result=final_task, request_id=jsonrpc_request_id 2747 ) 2748 a2a_payload = final_response.model_dump(exclude_none=True) 2749 target_topic = peer_reply_topic or a2a.get_client_response_topic( 2750 namespace, client_id 2751 ) 2752 2753 self._publish_a2a_event(a2a_payload, target_topic, a2a_context) 2754 log.info( 2755 "%s Published final CANCELED response for task %s to %s.", 2756 self.log_identifier, 2757 logical_task_id, 2758 target_topic, 2759 ) 2760 2761 if original_message: 2762 try: 2763 original_message.call_acknowledgements() 2764 log.info( 2765 "%s Called ACK for original message of cancelled task %s.", 2766 self.log_identifier, 2767 logical_task_id, 2768 ) 2769 except Exception as ack_e: 2770 log.error( 2771 "%s Failed to call ACK for cancelled task %s: %s", 2772 self.log_identifier, 2773 logical_task_id, 2774 ack_e, 2775 ) 2776 else: 2777 log.warning( 2778 "%s Original Solace message not found in context for cancelled task %s. Cannot ACK.", 2779 self.log_identifier, 2780 logical_task_id, 2781 ) 2782 2783 except Exception as e: 2784 log.exception( 2785 "%s Error during CANCELED finalization of task %s: %s", 2786 self.log_identifier, 2787 logical_task_id, 2788 e, 2789 ) 2790 if original_message: 2791 try: 2792 original_message.call_negative_acknowledgements() 2793 except Exception: 2794 pass 2795 2796 async def _publish_tool_failure_status( 2797 self, exception: Exception, a2a_context: Dict 2798 ): 2799 """ 2800 Publishes an intermediate status update indicating a tool execution has failed. 2801 This method will flush the buffer before publishing to maintain proper message ordering. 2802 """ 2803 logical_task_id = a2a_context.get("logical_task_id") 2804 log_identifier_helper = ( 2805 f"{self.log_identifier}[ToolFailureStatus:{logical_task_id}]" 2806 ) 2807 try: 2808 # Create the status update event 2809 tool_error_data_part = a2a.create_data_part( 2810 data={ 2811 "a2a_signal_type": "tool_execution_error", 2812 "error_message": str(exception), 2813 "details": "An unhandled exception occurred during tool execution.", 2814 } 2815 ) 2816 2817 status_message = a2a.create_agent_parts_message( 2818 parts=[tool_error_data_part], 2819 task_id=logical_task_id, 2820 context_id=a2a_context.get("contextId"), 2821 ) 2822 status_update_event = a2a.create_status_update( 2823 task_id=logical_task_id, 2824 context_id=a2a_context.get("contextId"), 2825 message=status_message, 2826 is_final=False, 2827 metadata={"agent_name": self.get_config("agent_name")}, 2828 ) 2829 2830 await self._publish_status_update_with_buffer_flush( 2831 status_update_event, 2832 a2a_context, 2833 skip_buffer_flush=False, 2834 ) 2835 2836 log.debug( 2837 "%s Published tool failure status update.", 2838 log_identifier_helper, 2839 ) 2840 2841 except Exception as e: 2842 log.error( 2843 "%s Failed to publish intermediate tool failure status: %s", 2844 log_identifier_helper, 2845 e, 2846 ) 2847 2848 async def _repair_session_history_on_error( 2849 self, exception: Exception, a2a_context: Dict 2850 ): 2851 """ 2852 Reactively repairs the session history if the last event was a tool call. 2853 This is "the belt" in the belt-and-suspenders strategy. 2854 """ 2855 log_identifier = f"{self.log_identifier}[HistoryRepair]" 2856 try: 2857 from ...agent.adk.callbacks import create_dangling_tool_call_repair_content 2858 from ...agent.adk.services import append_event_with_retry 2859 2860 session_id = a2a_context.get("effective_session_id") 2861 user_id = a2a_context.get("user_id") 2862 agent_name = self.get_config("agent_name") 2863 2864 if not all([session_id, user_id, agent_name, self.session_service]): 2865 log.warning( 2866 "%s Skipping history repair due to missing context.", log_identifier 2867 ) 2868 return 2869 2870 session = await self.session_service.get_session( 2871 app_name=agent_name, user_id=user_id, session_id=session_id 2872 ) 2873 2874 if not session or not session.events: 2875 log.debug( 2876 "%s No session or events found for history repair.", log_identifier 2877 ) 2878 return 2879 2880 last_event = session.events[-1] 2881 function_calls = last_event.get_function_calls() 2882 2883 if not function_calls: 2884 log.debug( 2885 "%s Last event was not a function call. No repair needed.", 2886 log_identifier, 2887 ) 2888 return 2889 2890 log.info( 2891 "%s Last event contained function_call(s). Repairing session history.", 2892 log_identifier, 2893 ) 2894 2895 repair_content = create_dangling_tool_call_repair_content( 2896 dangling_calls=function_calls, 2897 error_message=f"Tool execution failed with an unhandled exception: {str(exception)}", 2898 ) 2899 2900 repair_event = ADKEvent( 2901 invocation_id=last_event.invocation_id, 2902 author=agent_name, 2903 content=repair_content, 2904 ) 2905 2906 # Use retry helper to handle stale session race conditions 2907 await append_event_with_retry( 2908 session_service=self.session_service, 2909 session=session, 2910 event=repair_event, 2911 app_name=agent_name, 2912 user_id=user_id, 2913 session_id=session_id, 2914 log_identifier=log_identifier, 2915 ) 2916 log.info( 2917 "%s Session history repaired successfully with an error function_response.", 2918 log_identifier, 2919 ) 2920 2921 except Exception as e: 2922 log.exception( 2923 "%s Critical error during session history repair: %s", log_identifier, e 2924 ) 2925 2926 def finalize_task_limit_reached( 2927 self, a2a_context: Dict, exception: LlmCallsLimitExceededError 2928 ): 2929 """ 2930 Finalizes a task when the LLM call limit is reached, prompting the user to continue. 2931 Sends a COMPLETED status with an informative message. 2932 """ 2933 logical_task_id = a2a_context.get("logical_task_id") 2934 2935 # Retrieve the original Solace message from TaskExecutionContext 2936 original_message: Optional[SolaceMessage] = None 2937 with self.active_tasks_lock: 2938 task_context = self.active_tasks.get(logical_task_id) 2939 if task_context: 2940 original_message = task_context.get_original_solace_message() 2941 2942 log.info( 2943 "%s Finalizing task %s as COMPLETED (LLM call limit reached).", 2944 self.log_identifier, 2945 logical_task_id, 2946 ) 2947 try: 2948 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 2949 client_id = a2a_context.get("client_id") 2950 peer_reply_topic = a2a_context.get("replyToTopic") 2951 namespace = self.get_config("namespace") 2952 agent_name = self.get_config("agent_name") 2953 original_session_id = a2a_context.get("session_id") 2954 2955 limit_message_text = ( 2956 f"This interaction has reached its processing limit. " 2957 "If you'd like to continue this conversation, please type 'continue'. " 2958 "Otherwise, you can start a new topic." 2959 ) 2960 2961 final_response = a2a.create_internal_error_response( 2962 message=limit_message_text, 2963 request_id=jsonrpc_request_id, 2964 data={"taskId": logical_task_id, "reason": "llm_call_limit_reached"}, 2965 ) 2966 a2a_payload = final_response.model_dump(exclude_none=True) 2967 2968 target_topic = peer_reply_topic or a2a.get_client_response_topic( 2969 namespace, client_id 2970 ) 2971 2972 self._publish_a2a_event(a2a_payload, target_topic, a2a_context) 2973 log.info( 2974 "%s Published ERROR response for task %s to %s (LLM limit reached, user guided to continue).", 2975 self.log_identifier, 2976 logical_task_id, 2977 target_topic, 2978 ) 2979 2980 if original_message: 2981 try: 2982 original_message.call_acknowledgements() 2983 log.info( 2984 "%s Called ACK for original message of task %s (LLM limit reached).", 2985 self.log_identifier, 2986 logical_task_id, 2987 ) 2988 except Exception as ack_e: 2989 log.error( 2990 "%s Failed to call ACK for task %s (LLM limit reached): %s", 2991 self.log_identifier, 2992 logical_task_id, 2993 ack_e, 2994 ) 2995 else: 2996 log.warning( 2997 "%s Original Solace message not found in context for task %s (LLM limit reached). Cannot ACK.", 2998 self.log_identifier, 2999 logical_task_id, 3000 ) 3001 3002 except Exception as e: 3003 log.exception( 3004 "%s Error during COMPLETED (LLM limit) finalization of task %s: %s", 3005 self.log_identifier, 3006 logical_task_id, 3007 e, 3008 ) 3009 self.finalize_task_error(e, a2a_context) 3010 3011 async def finalize_task_error(self, exception: Exception, a2a_context: Dict): 3012 """ 3013 Finalizes a task with an error. Publishes a final A2A Task with a FAILED 3014 status and NACKs the original message. 3015 Called by the background ADK thread wrapper. 3016 """ 3017 logical_task_id = a2a_context.get("logical_task_id") 3018 3019 # Retrieve the original Solace message from TaskExecutionContext 3020 original_message: Optional[SolaceMessage] = None 3021 with self.active_tasks_lock: 3022 task_context = self.active_tasks.get(logical_task_id) 3023 if task_context: 3024 original_message = task_context.get_original_solace_message() 3025 3026 log.error( 3027 "%s Finalizing task %s with error: %s", 3028 self.log_identifier, 3029 logical_task_id, 3030 exception, 3031 ) 3032 try: 3033 await self._repair_session_history_on_error(exception, a2a_context) 3034 3035 await self._publish_tool_failure_status(exception, a2a_context) 3036 3037 client_id = a2a_context.get("client_id") 3038 jsonrpc_request_id = a2a_context.get("jsonrpc_request_id") 3039 peer_reply_topic = a2a_context.get("replyToTopic") 3040 namespace = self.get_config("namespace") 3041 3042 # Use centralized error handler for all LLM-related exceptions 3043 if is_llm_exception(exception): 3044 error_message, is_context_limit = get_error_message(exception) 3045 3046 if is_context_limit: 3047 log.error( 3048 "%s Context limit exceeded for task %s. Error: %s", 3049 self.log_identifier, 3050 logical_task_id, 3051 exception, 3052 ) 3053 else: 3054 error_message = ( 3055 "An unexpected error occurred while processing your request. " 3056 "Please try again. If the problem persists, contact an administrator." 3057 ) 3058 3059 failed_status = a2a.create_task_status( 3060 state=TaskState.failed, 3061 message=a2a.create_agent_text_message(text=error_message), 3062 ) 3063 3064 final_task = a2a.create_final_task( 3065 task_id=logical_task_id, 3066 context_id=a2a_context.get("contextId"), 3067 final_status=failed_status, 3068 metadata={"agent_name": self.get_config("agent_name")}, 3069 ) 3070 3071 final_response = a2a.create_success_response( 3072 result=final_task, request_id=jsonrpc_request_id 3073 ) 3074 a2a_payload = final_response.model_dump(exclude_none=True) 3075 target_topic = peer_reply_topic or a2a.get_client_response_topic( 3076 namespace, client_id 3077 ) 3078 3079 self._publish_a2a_event(a2a_payload, target_topic, a2a_context) 3080 log.info( 3081 "%s Published final FAILED Task response for task %s to %s", 3082 self.log_identifier, 3083 logical_task_id, 3084 target_topic, 3085 ) 3086 3087 if original_message: 3088 try: 3089 original_message.call_negative_acknowledgements() 3090 log.info( 3091 "%s Called NACK for original message of failed task %s.", 3092 self.log_identifier, 3093 logical_task_id, 3094 ) 3095 except Exception as nack_e: 3096 log.error( 3097 "%s Failed to call NACK for failed task %s: %s", 3098 self.log_identifier, 3099 logical_task_id, 3100 nack_e, 3101 ) 3102 else: 3103 log.warning( 3104 "%s Original Solace message not found in context for failed task %s. Cannot NACK.", 3105 self.log_identifier, 3106 logical_task_id, 3107 ) 3108 3109 except Exception as e: 3110 log.exception( 3111 "%s Error during error finalization of task %s: %s", 3112 self.log_identifier, 3113 logical_task_id, 3114 e, 3115 ) 3116 if original_message: 3117 try: 3118 original_message.call_negative_acknowledgements() 3119 log.warning( 3120 "%s Called NACK for task %s during error finalization fallback.", 3121 self.log_identifier, 3122 logical_task_id, 3123 ) 3124 except Exception as nack_e: 3125 log.error( 3126 "%s Failed to call NACK for task %s during error finalization fallback: %s", 3127 self.log_identifier, 3128 logical_task_id, 3129 nack_e, 3130 ) 3131 else: 3132 log.warning( 3133 "%s Original Solace message not found for task %s during error finalization fallback. Cannot NACK.", 3134 self.log_identifier, 3135 logical_task_id, 3136 ) 3137 3138 async def finalize_task_with_cleanup( 3139 self, a2a_context: Dict, is_paused: bool, exception: Optional[Exception] = None 3140 ): 3141 """ 3142 Centralized async method to finalize a task and perform all necessary cleanup. 3143 This is scheduled on the component's event loop to ensure it runs after 3144 any pending status updates. 3145 3146 Args: 3147 a2a_context: The context dictionary for the task. 3148 is_paused: Boolean indicating if the task is paused for a long-running tool. 3149 exception: The exception that occurred, if any. 3150 """ 3151 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 3152 log_id = f"{self.log_identifier}[FinalizeTask:{logical_task_id}]" 3153 log.info( 3154 "%s Starting finalization and cleanup. Paused: %s, Exception: %s", 3155 log_id, 3156 is_paused, 3157 type(exception).__name__ if exception else "None", 3158 ) 3159 3160 try: 3161 if is_paused: 3162 log.info( 3163 "%s Task is paused for a long-running tool. Skipping finalization logic.", 3164 log_id, 3165 ) 3166 else: 3167 try: 3168 if exception: 3169 if isinstance(exception, TaskCancelledError): 3170 self.finalize_task_canceled(a2a_context) 3171 elif isinstance(exception, LlmCallsLimitExceededError): 3172 self.finalize_task_limit_reached(a2a_context, exception) 3173 else: 3174 await self.finalize_task_error(exception, a2a_context) 3175 else: 3176 await self.finalize_task_success(a2a_context) 3177 except Exception as e: 3178 log.exception( 3179 "%s An unexpected error occurred during the finalization logic itself: %s", 3180 log_id, 3181 e, 3182 ) 3183 # Retrieve the original Solace message from TaskExecutionContext for fallback NACK 3184 original_message: Optional[SolaceMessage] = None 3185 with self.active_tasks_lock: 3186 task_context = self.active_tasks.get(logical_task_id) 3187 if task_context: 3188 original_message = ( 3189 task_context.get_original_solace_message() 3190 ) 3191 3192 if original_message: 3193 try: 3194 original_message.call_negative_acknowledgements() 3195 except Exception as nack_err: 3196 log.error( 3197 "%s Fallback NACK failed during finalization error: %s", 3198 log_id, 3199 nack_err, 3200 ) 3201 finally: 3202 if not is_paused: 3203 # Cleanup for RUN_BASED sessions remains, as it's a service-level concern 3204 if a2a_context.get("is_run_based_session"): 3205 temp_session_id_to_delete = a2a_context.get( 3206 "temporary_run_session_id_for_cleanup" 3207 ) 3208 agent_name_for_session = a2a_context.get("agent_name_for_session") 3209 user_id_for_session = a2a_context.get("user_id_for_session") 3210 3211 if ( 3212 temp_session_id_to_delete 3213 and agent_name_for_session 3214 and user_id_for_session 3215 ): 3216 log.info( 3217 "%s Cleaning up RUN_BASED session (app: %s, user: %s, id: %s) from shared service for task_id='%s'", 3218 log_id, 3219 agent_name_for_session, 3220 user_id_for_session, 3221 temp_session_id_to_delete, 3222 logical_task_id, 3223 ) 3224 try: 3225 if self.session_service: 3226 await self.session_service.delete_session( 3227 app_name=agent_name_for_session, 3228 user_id=user_id_for_session, 3229 session_id=temp_session_id_to_delete, 3230 ) 3231 else: 3232 log.error( 3233 "%s self.session_service is None, cannot delete RUN_BASED session %s.", 3234 log_id, 3235 temp_session_id_to_delete, 3236 ) 3237 except AttributeError: 3238 log.error( 3239 "%s self.session_service does not support 'delete_session'. Cleanup for RUN_BASED session (app: %s, user: %s, id: %s) skipped.", 3240 log_id, 3241 agent_name_for_session, 3242 user_id_for_session, 3243 temp_session_id_to_delete, 3244 ) 3245 except Exception as e_cleanup: 3246 log.error( 3247 "%s Error cleaning up RUN_BASED session (app: %s, user: %s, id: %s) from shared service: %s", 3248 log_id, 3249 agent_name_for_session, 3250 user_id_for_session, 3251 temp_session_id_to_delete, 3252 e_cleanup, 3253 exc_info=True, 3254 ) 3255 else: 3256 log.warning( 3257 "%s Could not clean up RUN_BASED session for task %s due to missing context (id_to_delete: %s, agent_name: %s, user_id: %s).", 3258 log_id, 3259 logical_task_id, 3260 temp_session_id_to_delete, 3261 agent_name_for_session, 3262 user_id_for_session, 3263 ) 3264 3265 with self.active_tasks_lock: 3266 removed_task_context = self.active_tasks.pop(logical_task_id, None) 3267 if removed_task_context: 3268 log.debug( 3269 "%s Removed TaskExecutionContext for task %s.", 3270 log_id, 3271 logical_task_id, 3272 ) 3273 else: 3274 log.warning( 3275 "%s TaskExecutionContext for task %s was already removed.", 3276 log_id, 3277 logical_task_id, 3278 ) 3279 else: 3280 log.info( 3281 "%s Task %s is paused for a long-running tool. Skipping all cleanup.", 3282 log_id, 3283 logical_task_id, 3284 ) 3285 3286 log.info( 3287 "%s Finalization and cleanup complete for task %s.", 3288 log_id, 3289 logical_task_id, 3290 ) 3291 3292 def _resolve_instruction_provider( 3293 self, config_value: Any 3294 ) -> Union[str, InstructionProvider]: 3295 """Resolves instruction config using helper.""" 3296 return resolve_instruction_provider(self, config_value) 3297 3298 def _get_a2a_base_topic(self) -> str: 3299 """Returns the base topic prefix using helper.""" 3300 return a2a.get_a2a_base_topic(self.namespace) 3301 3302 def _get_discovery_topic(self) -> str: 3303 """Returns the agent discovery topic for publishing.""" 3304 return a2a.get_agent_discovery_topic(self.namespace) 3305 3306 def _get_agent_request_topic(self, agent_id: str) -> str: 3307 """Returns the agent request topic using helper.""" 3308 return a2a.get_agent_request_topic(self.namespace, agent_id) 3309 3310 def _get_agent_response_topic( 3311 self, delegating_agent_name: str, sub_task_id: str 3312 ) -> str: 3313 """Returns the agent response topic using helper.""" 3314 return a2a.get_agent_response_topic( 3315 self.namespace, delegating_agent_name, sub_task_id 3316 ) 3317 3318 def _get_peer_agent_status_topic( 3319 self, delegating_agent_name: str, sub_task_id: str 3320 ) -> str: 3321 """Returns the peer agent status topic using helper.""" 3322 return a2a.get_peer_agent_status_topic( 3323 self.namespace, delegating_agent_name, sub_task_id 3324 ) 3325 3326 def _get_client_response_topic(self, client_id: str) -> str: 3327 """Returns the client response topic using helper.""" 3328 return a2a.get_client_response_topic(self.namespace, client_id) 3329 3330 def _publish_a2a_event( 3331 self, 3332 payload: Dict, 3333 topic: str, 3334 a2a_context: Dict, 3335 user_properties_override: Optional[Dict] = None, 3336 ): 3337 """ 3338 Centralized helper to publish an A2A event, ensuring user properties 3339 are consistently attached from the a2a_context or an override. 3340 """ 3341 if user_properties_override is not None: 3342 user_properties = user_properties_override 3343 else: 3344 user_properties = {} 3345 if a2a_context.get("a2a_user_config"): 3346 user_properties["a2aUserConfig"] = a2a_context["a2a_user_config"] 3347 3348 self.publish_a2a_message(payload, topic, user_properties) 3349 3350 def submit_a2a_task( 3351 self, 3352 target_agent_name: str, 3353 a2a_message: A2AMessage, 3354 user_id: str, 3355 user_config: Dict[str, Any], 3356 sub_task_id: str, 3357 ) -> str: 3358 """ 3359 Submits a task to a peer agent in a non-blocking way. 3360 Returns the sub_task_id for correlation. 3361 """ 3362 log_identifier_helper = ( 3363 f"{self.log_identifier}[SubmitA2ATask:{target_agent_name}]" 3364 ) 3365 main_task_id = a2a_message.metadata.get("parentTaskId", "unknown_parent") 3366 3367 log.debug( 3368 "%s Submitting non-blocking task for main task %s", 3369 log_identifier_helper, 3370 main_task_id, 3371 ) 3372 3373 # Validate agent access is allowed 3374 validate_agent_access( 3375 user_config=user_config, 3376 target_agent_name=target_agent_name, 3377 validation_context={ 3378 "delegating_agent": self.get_config("agent_name"), 3379 "source": "agent_delegation", 3380 }, 3381 log_identifier=log_identifier_helper, 3382 ) 3383 3384 peer_request_topic = self._get_agent_request_topic(target_agent_name) 3385 3386 # Create a compliant SendMessageRequest 3387 send_params = MessageSendParams(message=a2a_message) 3388 a2a_request = SendMessageRequest(id=sub_task_id, params=send_params) 3389 3390 delegating_agent_name = self.get_config("agent_name") 3391 reply_to_topic = self._get_agent_response_topic( 3392 delegating_agent_name=delegating_agent_name, 3393 sub_task_id=sub_task_id, 3394 ) 3395 status_topic = self._get_peer_agent_status_topic( 3396 delegating_agent_name=delegating_agent_name, 3397 sub_task_id=sub_task_id, 3398 ) 3399 3400 user_properties = { 3401 "replyTo": reply_to_topic, 3402 "a2aStatusTopic": status_topic, 3403 "userId": user_id, 3404 "delegating_agent_name": delegating_agent_name, 3405 } 3406 if isinstance(user_config, dict): 3407 user_properties["a2aUserConfig"] = user_config 3408 3409 # Retrieve call depth and auth token from parent task context 3410 parent_task_id = a2a_message.metadata.get("parentTaskId") 3411 current_depth = 0 3412 if parent_task_id: 3413 with self.active_tasks_lock: 3414 parent_task_context = self.active_tasks.get(parent_task_id) 3415 3416 if parent_task_context: 3417 # Get current call depth from parent context 3418 current_depth = parent_task_context.a2a_context.get("call_depth", 0) 3419 3420 auth_token = parent_task_context.get_security_data("auth_token") 3421 if auth_token: 3422 user_properties["authToken"] = auth_token 3423 log.debug( 3424 "%s Propagating authentication token to peer agent %s for sub-task %s", 3425 log_identifier_helper, 3426 target_agent_name, 3427 sub_task_id, 3428 ) 3429 else: 3430 log.debug( 3431 "%s No authentication token found in parent task context for sub-task %s", 3432 log_identifier_helper, 3433 sub_task_id, 3434 ) 3435 else: 3436 log.warning( 3437 "%s Parent task context not found for task %s, cannot propagate authentication token", 3438 log_identifier_helper, 3439 parent_task_id, 3440 ) 3441 3442 # Add call depth to user properties (increment for outgoing call) 3443 user_properties["callDepth"] = current_depth + 1 3444 3445 self.publish_a2a_message( 3446 payload=a2a_request.model_dump(by_alias=True, exclude_none=True), 3447 topic=peer_request_topic, 3448 user_properties=user_properties, 3449 ) 3450 log.info( 3451 "%s Published delegation request to %s (Sub-Task ID: %s, ReplyTo: %s, StatusTo: %s)", 3452 log_identifier_helper, 3453 peer_request_topic, 3454 sub_task_id, 3455 reply_to_topic, 3456 status_topic, 3457 ) 3458 3459 return sub_task_id 3460 3461 def _handle_scheduled_task_completion( 3462 self, future: concurrent.futures.Future, event_type_for_log: EventType 3463 ): 3464 """Callback to handle completion of futures from run_coroutine_threadsafe.""" 3465 try: 3466 if future.cancelled(): 3467 log.warning( 3468 "%s Coroutine for event type %s (scheduled via run_coroutine_threadsafe) was cancelled.", 3469 self.log_identifier, 3470 event_type_for_log, 3471 ) 3472 elif future.done() and future.exception() is not None: 3473 exception = future.exception() 3474 log.error( 3475 "%s Coroutine for event type %s (scheduled via run_coroutine_threadsafe) failed with exception: %s", 3476 self.log_identifier, 3477 event_type_for_log, 3478 exception, 3479 exc_info=exception, 3480 ) 3481 else: 3482 pass 3483 except Exception as e: 3484 log.error( 3485 "%s Error during _handle_scheduled_task_completion (for run_coroutine_threadsafe future) for event type %s: %s", 3486 self.log_identifier, 3487 event_type_for_log, 3488 e, 3489 exc_info=e, 3490 ) 3491 3492 async def _get_toolset_manifest_entries(self, toolset, toolset_type_name): 3493 """Retrieve manifest entries from an MCPToolset or OpenAPIToolset.""" 3494 no_description = "No description available." 3495 try: 3496 log.debug( 3497 "%s Retrieving tools from %s for Agent %s...", 3498 self.log_identifier, 3499 toolset_type_name, 3500 self.agent_name, 3501 ) 3502 tools = await toolset.get_tools() 3503 except Exception as e: 3504 log.error( 3505 "%s Error retrieving tools from %s for Agent Card %s: %s", 3506 self.log_identifier, 3507 toolset_type_name, 3508 self.agent_name, 3509 e, 3510 ) 3511 return [] 3512 toolset_scopes = getattr(toolset, "required_scopes", []) 3513 return [ 3514 { 3515 "id": t.name, 3516 "name": t.name, 3517 "description": t.description or no_description, 3518 "required_scopes": toolset_scopes, 3519 } 3520 for t in tools 3521 ] 3522 3523 def _get_single_tool_manifest_entry(self, tool): 3524 """Build a manifest entry for a single non-toolset tool, or return None.""" 3525 no_description = "No description available." 3526 # For DynamicTool subclasses, use tool_name/tool_description properties 3527 # (the inherited 'name'/'description' attrs may still be placeholders) 3528 if hasattr(tool, "tool_name"): 3529 tool_name = tool.tool_name 3530 tool_description = tool.tool_description 3531 else: 3532 tool_name = getattr(tool, "name", getattr(tool, "__name__", None)) 3533 tool_description = getattr( 3534 tool, "description", getattr(tool, "__doc__", None) 3535 ) 3536 if tool_name is None: 3537 return None 3538 return { 3539 "id": tool_name, 3540 "name": tool_name, 3541 "description": tool_description or no_description, 3542 "required_scopes": self.tool_scopes_map.get(tool_name, []), 3543 } 3544 3545 def _signal_async_init_future(self, *, success, error=None): 3546 """Signal the main thread with the result of async initialization.""" 3547 if not self._async_init_future or self._async_init_future.done(): 3548 log.warning( 3549 "%s _perform_async_init: _async_init_future is None or already done before signaling.", 3550 self.log_identifier, 3551 ) 3552 return 3553 if success: 3554 log.info( 3555 "%s _perform_async_init: Signaling success to main thread.", 3556 self.log_identifier, 3557 ) 3558 self._async_loop.call_soon_threadsafe( 3559 self._async_init_future.set_result, True 3560 ) 3561 else: 3562 log.error( 3563 "%s _perform_async_init: Signaling failure to main thread.", 3564 self.log_identifier, 3565 ) 3566 self._async_loop.call_soon_threadsafe( 3567 self._async_init_future.set_exception, error 3568 ) 3569 3570 async def _build_tool_manifest(self, loaded_tools): 3571 """Build the agent card tool manifest from loaded tools.""" 3572 tool_manifest = [] 3573 for tool in loaded_tools: 3574 if isinstance(tool, (MCPToolset, OpenAPIToolset)): 3575 toolset_type_name = type(tool).__name__ 3576 entries = await self._get_toolset_manifest_entries(tool, toolset_type_name) 3577 tool_manifest.extend(entries) 3578 else: 3579 entry = self._get_single_tool_manifest_entry(tool) 3580 if entry: 3581 tool_manifest.append(entry) 3582 return tool_manifest 3583 3584 async def _perform_async_init(self): 3585 """Coroutine executed on the dedicated loop to perform async initialization.""" 3586 try: 3587 log.info( 3588 "%s Loading tools asynchronously in dedicated thread...", 3589 self.log_identifier, 3590 ) 3591 ( 3592 loaded_tools, 3593 enabled_builtin_tools, 3594 self._tool_cleanup_hooks, 3595 self.tool_scopes_map, 3596 ) = await load_adk_tools(self) 3597 log.info( 3598 "%s Initializing ADK Agent/Runner asynchronously in dedicated thread...", 3599 self.log_identifier, 3600 ) 3601 self.adk_agent = initialize_adk_agent( 3602 self, loaded_tools, enabled_builtin_tools 3603 ) 3604 self.runner = initialize_adk_runner(self) 3605 3606 log.info("%s Populating agent card tool manifest...", self.log_identifier) 3607 self.agent_card_tool_manifest = await self._build_tool_manifest(loaded_tools) 3608 log.info( 3609 "%s Agent card tool manifest populated with %d tools.", 3610 self.log_identifier, 3611 len(self.agent_card_tool_manifest), 3612 ) 3613 3614 log.info( 3615 "%s Async initialization steps complete in dedicated thread.", 3616 self.log_identifier, 3617 ) 3618 self._signal_async_init_future(success=True) 3619 except Exception as e: 3620 log.exception( 3621 "%s _perform_async_init: Error during async initialization in dedicated thread: %s", 3622 self.log_identifier, 3623 e, 3624 ) 3625 self._signal_async_init_future(success=False, error=e) 3626 raise e 3627 3628 def cleanup(self): 3629 """Clean up resources on component shutdown.""" 3630 log.info("%s Cleaning up A2A ADK Host Component.", self.log_identifier) 3631 self.cancel_timer(self._card_publish_timer_id) 3632 self.cancel_timer(self.HEALTH_CHECK_TIMER_ID) 3633 3634 cleanup_func_details = self.get_config("agent_cleanup_function") 3635 3636 from .app import AgentInitCleanupConfig # Avoid circular import 3637 3638 if cleanup_func_details and isinstance( 3639 cleanup_func_details, AgentInitCleanupConfig 3640 ): 3641 module_name = cleanup_func_details.get("module") 3642 func_name = cleanup_func_details.get("name") 3643 base_path = cleanup_func_details.get("base_path") 3644 3645 if module_name and func_name: 3646 log.info( 3647 "%s Attempting to load and execute cleanup_function: %s.%s", 3648 self.log_identifier, 3649 module_name, 3650 func_name, 3651 ) 3652 try: 3653 module = import_module(module_name, base_path=base_path) 3654 cleanup_function = getattr(module, func_name) 3655 3656 if not callable(cleanup_function): 3657 log.error( 3658 "%s Cleanup function '%s' in module '%s' is not callable. Skipping.", 3659 self.log_identifier, 3660 func_name, 3661 module_name, 3662 ) 3663 else: 3664 cleanup_function(self) 3665 log.info( 3666 "%s Successfully executed cleanup_function: %s.%s", 3667 self.log_identifier, 3668 module_name, 3669 func_name, 3670 ) 3671 except Exception as e: 3672 log.exception( 3673 "%s Error during agent cleanup via cleanup_function '%s.%s': %s", 3674 self.log_identifier, 3675 module_name, 3676 func_name, 3677 e, 3678 ) 3679 if self._tool_cleanup_hooks: 3680 log.info( 3681 "%s Executing %d tool cleanup hooks...", 3682 self.log_identifier, 3683 len(self._tool_cleanup_hooks), 3684 ) 3685 if self._async_loop and self._async_loop.is_running(): 3686 3687 async def run_tool_cleanup(): 3688 results = await asyncio.gather( 3689 *[hook() for hook in self._tool_cleanup_hooks], 3690 return_exceptions=True, 3691 ) 3692 for i, result in enumerate(results): 3693 if isinstance(result, Exception): 3694 log.error( 3695 "%s Error during tool cleanup hook #%d: %s", 3696 self.log_identifier, 3697 i, 3698 result, 3699 exc_info=result, 3700 ) 3701 3702 future = asyncio.run_coroutine_threadsafe( 3703 run_tool_cleanup(), self._async_loop 3704 ) 3705 try: 3706 future.result(timeout=15) # Wait for cleanup to complete 3707 log.info("%s All tool cleanup hooks executed.", self.log_identifier) 3708 except Exception as e: 3709 log.error( 3710 "%s Exception while waiting for tool cleanup hooks to finish: %s", 3711 self.log_identifier, 3712 e, 3713 ) 3714 else: 3715 log.warning( 3716 "%s Cannot execute tool cleanup hooks because the async loop is not running.", 3717 self.log_identifier, 3718 ) 3719 3720 # The base class cleanup() will handle stopping the async loop and joining the thread. 3721 # We just need to cancel any active tasks before that happens. 3722 with self.active_tasks_lock: 3723 if self._async_loop and self._async_loop.is_running(): 3724 for task_context in self.active_tasks.values(): 3725 task_context.cancel() 3726 self.active_tasks.clear() 3727 log.debug("%s Cleared all active tasks.", self.log_identifier) 3728 3729 super().cleanup() 3730 log.info("%s Component cleanup finished.", self.log_identifier) 3731 3732 def set_agent_specific_state(self, key: str, value: Any): 3733 """ 3734 Sets a key-value pair in the agent-specific state. 3735 Intended to be used by the custom init_function. 3736 """ 3737 if not hasattr(self, "agent_specific_state"): 3738 self.agent_specific_state = {} 3739 self.agent_specific_state[key] = value 3740 log.debug("%s Set agent_specific_state['%s']", self.log_identifier, key) 3741 3742 def get_agent_specific_state(self, key: str, default: Optional[Any] = None) -> Any: 3743 """ 3744 Gets a value from the agent-specific state. 3745 Intended to be used by tools and the custom cleanup_function. 3746 """ 3747 if not hasattr(self, "agent_specific_state"): 3748 return default 3749 return self.agent_specific_state.get(key, default) 3750 3751 def get_async_loop(self) -> Optional[asyncio.AbstractEventLoop]: 3752 """Returns the dedicated asyncio event loop for this component's async tasks.""" 3753 return self._async_loop 3754 3755 def publish_data_signal_from_thread( 3756 self, 3757 a2a_context: Dict[str, Any], 3758 signal_data: BaseModel, 3759 skip_buffer_flush: bool = False, 3760 log_identifier: Optional[str] = None, 3761 ) -> bool: 3762 """ 3763 Publishes a data signal status update from any thread by scheduling it on the async loop. 3764 3765 This is a convenience method for tools and callbacks that need to publish status updates 3766 but are not running in an async context. It handles: 3767 1. Extracting task_id and context_id from a2a_context 3768 2. Creating the status update event 3769 3. Checking if the async loop is available and running 3770 4. Scheduling the publish operation on the async loop 3771 3772 Args: 3773 a2a_context: The A2A context dictionary containing logical_task_id and contextId 3774 signal_data: A Pydantic BaseModel instance (e.g., AgentProgressUpdateData, 3775 DeepResearchProgressData, ArtifactCreationProgressData) 3776 skip_buffer_flush: If True, skip buffer flushing before publishing 3777 log_identifier: Optional log identifier for debugging 3778 3779 Returns: 3780 bool: True if the publish was successfully scheduled, False otherwise 3781 """ 3782 from ...common import a2a 3783 3784 log_id = log_identifier or f"{self.log_identifier}[PublishDataSignal]" 3785 3786 if not a2a_context: 3787 log.error("%s No a2a_context provided. Cannot publish data signal.", log_id) 3788 return False 3789 3790 logical_task_id = a2a_context.get("logical_task_id") 3791 context_id = a2a_context.get("contextId") 3792 3793 if not logical_task_id: 3794 log.error("%s No logical_task_id in a2a_context. Cannot publish data signal.", log_id) 3795 return False 3796 3797 # Create status update event using the standard data signal pattern 3798 status_update_event = a2a.create_data_signal_event( 3799 task_id=logical_task_id, 3800 context_id=context_id, 3801 signal_data=signal_data, 3802 agent_name=self.agent_name, 3803 ) 3804 3805 # Get the async loop and schedule the publish 3806 loop = self.get_async_loop() 3807 if loop and loop.is_running(): 3808 asyncio.run_coroutine_threadsafe( 3809 self._publish_status_update_with_buffer_flush( 3810 status_update_event, 3811 a2a_context, 3812 skip_buffer_flush=skip_buffer_flush, 3813 ), 3814 loop, 3815 ) 3816 log.debug( 3817 "%s Scheduled data signal status update (type: %s).", 3818 log_id, 3819 type(signal_data).__name__, 3820 ) 3821 return True 3822 else: 3823 log.error( 3824 "%s Async loop not available or not running. Cannot publish data signal.", 3825 log_id, 3826 ) 3827 return False 3828 3829 def set_agent_system_instruction_string(self, instruction_string: str) -> None: 3830 """ 3831 Sets a static string to be injected into the LLM system prompt. 3832 Called by the agent's init_function. 3833 """ 3834 if not isinstance(instruction_string, str): 3835 log.error( 3836 "%s Invalid type for instruction_string: %s. Must be a string.", 3837 self.log_identifier, 3838 type(instruction_string), 3839 ) 3840 return 3841 self._agent_system_instruction_string = instruction_string 3842 self._agent_system_instruction_callback = None 3843 log.info("%s Static agent system instruction string set.", self.log_identifier) 3844 3845 def set_agent_system_instruction_callback( 3846 self, 3847 callback_function: Optional[ 3848 Callable[[CallbackContext, LlmRequest], Optional[str]] 3849 ], 3850 ) -> None: 3851 """ 3852 Sets a callback function to dynamically generate system prompt injections. 3853 Called by the agent's init_function. 3854 """ 3855 if callback_function is not None and not callable(callback_function): 3856 log.error( 3857 "%s Invalid type for callback_function: %s. Must be callable.", 3858 self.log_identifier, 3859 type(callback_function), 3860 ) 3861 return 3862 self._agent_system_instruction_callback = callback_function 3863 self._agent_system_instruction_string = None 3864 log.info("%s Agent system instruction callback set.", self.log_identifier) 3865 3866 def get_gateway_id(self) -> str: 3867 """ 3868 Returns a unique identifier for this specific gateway/host instance. 3869 For now, using the agent name, but could be made more robust (e.g., hostname + agent name). 3870 """ 3871 return self.agent_name 3872 3873 def _check_agent_health(self): 3874 """ 3875 Checks the health of peer agents and de-registers unresponsive ones. 3876 This is called periodically by the health check timer. 3877 Uses TTL-based expiration to determine if an agent is unresponsive. 3878 """ 3879 3880 log.debug("%s Performing agent health check...", self.log_identifier) 3881 3882 ttl_seconds = self.agent_discovery_config.get( 3883 "health_check_ttl_seconds", HEALTH_CHECK_TTL_SECONDS 3884 ) 3885 health_check_interval = self.agent_discovery_config.get( 3886 "health_check_interval_seconds", HEALTH_CHECK_INTERVAL_SECONDS 3887 ) 3888 3889 log.debug( 3890 "%s Health check configuration: interval=%d seconds, TTL=%d seconds", 3891 self.log_identifier, 3892 health_check_interval, 3893 ttl_seconds, 3894 ) 3895 3896 # Validate configuration values 3897 if ( 3898 ttl_seconds <= 0 3899 or health_check_interval <= 0 3900 or ttl_seconds < health_check_interval 3901 ): 3902 log.error( 3903 "%s agent_health_check_ttl_seconds (%d) and agent_health_check_interval_seconds (%d) must be positive and TTL must be greater than interval.", 3904 self.log_identifier, 3905 ttl_seconds, 3906 health_check_interval, 3907 ) 3908 raise ValueError( 3909 f"Invalid health check configuration. agent_health_check_ttl_seconds ({ttl_seconds}) and agent_health_check_interval_seconds ({health_check_interval}) must be positive and TTL must be greater than interval." 3910 ) 3911 3912 # Get all agent names from the registry 3913 agent_names = self.agent_registry.get_agent_names() 3914 total_agents = len(agent_names) 3915 agents_to_deregister = [] 3916 3917 log.debug( 3918 "%s Checking health of %d peer agents", self.log_identifier, total_agents 3919 ) 3920 3921 for agent_name in agent_names: 3922 # Skip our own agent 3923 if agent_name == self.agent_name: 3924 continue 3925 3926 # Check if the agent's TTL has expired 3927 is_expired, time_since_last_seen = self.agent_registry.check_ttl_expired( 3928 agent_name, ttl_seconds 3929 ) 3930 3931 if is_expired: 3932 log.warning( 3933 "%s Agent '%s' TTL has expired. De-registering. Time since last seen: %d seconds (TTL: %d seconds)", 3934 self.log_identifier, 3935 agent_name, 3936 time_since_last_seen, 3937 ttl_seconds, 3938 ) 3939 agents_to_deregister.append(agent_name) 3940 3941 # De-register unresponsive agents 3942 for agent_name in agents_to_deregister: 3943 self._deregister_agent(agent_name) 3944 3945 log.debug( 3946 "%s Agent health check completed. Total agents: %d, De-registered: %d", 3947 self.log_identifier, 3948 total_agents, 3949 len(agents_to_deregister), 3950 ) 3951 3952 def _deregister_agent(self, agent_name: str): 3953 """ 3954 De-registers an agent from the registry and publishes a de-registration event. 3955 """ 3956 # Remove from registry 3957 registry_removed = self.agent_registry.remove_agent(agent_name) 3958 3959 # Always remove from peer_agents regardless of registry result 3960 peer_removed = False 3961 if agent_name in self.peer_agents: 3962 del self.peer_agents[agent_name] 3963 peer_removed = True 3964 log.info( 3965 "%s Removed agent '%s' from peer_agents dictionary", 3966 self.log_identifier, 3967 agent_name, 3968 ) 3969 3970 # Publish de-registration event if agent was in either data structure 3971 if registry_removed or peer_removed: 3972 try: 3973 # Create a de-registration event topic 3974 namespace = self.get_config("namespace") 3975 deregistration_topic = f"{namespace}/a2a/events/agent/deregistered" 3976 3977 current_time = time.time() 3978 3979 # Create the payload 3980 deregistration_payload = { 3981 "event_type": "agent.deregistered", 3982 "agent_name": agent_name, 3983 "reason": "health_check_failure", 3984 "metadata": { 3985 "timestamp": current_time, 3986 "deregistered_by": self.agent_name, 3987 }, 3988 } 3989 3990 # Publish the event 3991 self.publish_a2a_message( 3992 payload=deregistration_payload, topic=deregistration_topic 3993 ) 3994 3995 log.info( 3996 "%s Published de-registration event for agent '%s' to topic '%s'", 3997 self.log_identifier, 3998 agent_name, 3999 deregistration_topic, 4000 ) 4001 except Exception as e: 4002 log.error( 4003 "%s Failed to publish de-registration event for agent '%s': %s", 4004 self.log_identifier, 4005 agent_name, 4006 e, 4007 ) 4008 4009 async def _resolve_early_embeds_and_handle_signals( 4010 self, raw_text: str, a2a_context: Dict 4011 ) -> Tuple[str, List[Tuple[int, Any]], str]: 4012 """ 4013 Resolves early-stage embeds in raw text and extracts signals. 4014 Returns the resolved text, a list of signals, and any unprocessed tail. 4015 This is called by process_and_publish_adk_event. 4016 """ 4017 logical_task_id = a2a_context.get("logical_task_id", "unknown_task") 4018 method_context_log_identifier = ( 4019 f"{self.log_identifier}[ResolveEmbeds:{logical_task_id}]" 4020 ) 4021 log.debug( 4022 "%s Resolving early embeds for text (length: %d).", 4023 method_context_log_identifier, 4024 len(raw_text), 4025 ) 4026 4027 original_session_id = a2a_context.get("session_id") 4028 user_id = a2a_context.get("user_id") 4029 adk_app_name = self.get_config("agent_name") 4030 4031 if not all([self.artifact_service, original_session_id, user_id, adk_app_name]): 4032 log.error( 4033 "%s Missing necessary context for embed resolution (artifact_service, session_id, user_id, or adk_app_name). Skipping.", 4034 method_context_log_identifier, 4035 ) 4036 return ( 4037 raw_text, 4038 [], 4039 "", 4040 ) 4041 context_for_embeds = { 4042 "artifact_service": self.artifact_service, 4043 "session_context": { 4044 "app_name": adk_app_name, 4045 "user_id": user_id, 4046 "session_id": original_session_id, 4047 }, 4048 "config": { 4049 "gateway_max_artifact_resolve_size_bytes": self.get_config( 4050 "tool_output_llm_return_max_bytes", 4096 4051 ), 4052 "gateway_recursive_embed_depth": self.get_config( 4053 "gateway_recursive_embed_depth", 12 4054 ), 4055 }, 4056 } 4057 4058 resolver_config = context_for_embeds["config"] 4059 4060 try: 4061 from ...common.utils.embeds.constants import EARLY_EMBED_TYPES 4062 from ...common.utils.embeds.types import ResolutionMode 4063 from ...common.utils.embeds.resolver import ( 4064 evaluate_embed, 4065 resolve_embeds_in_string, 4066 ) 4067 4068 resolved_text, processed_until_index, signals_found = ( 4069 await resolve_embeds_in_string( 4070 text=raw_text, 4071 context=context_for_embeds, 4072 resolver_func=evaluate_embed, 4073 types_to_resolve=EARLY_EMBED_TYPES, 4074 resolution_mode=ResolutionMode.TOOL_PARAMETER, 4075 log_identifier=method_context_log_identifier, 4076 config=resolver_config, 4077 ) 4078 ) 4079 unprocessed_tail = raw_text[processed_until_index:] 4080 log.debug( 4081 "%s Embed resolution complete. Resolved text: '%s...', Signals found: %d, Unprocessed tail: '%s...'", 4082 method_context_log_identifier, 4083 resolved_text[:100], 4084 len(signals_found), 4085 unprocessed_tail[:100], 4086 ) 4087 return resolved_text, signals_found, unprocessed_tail 4088 except Exception as e: 4089 log.exception( 4090 "%s Error during embed resolution: %s", method_context_log_identifier, e 4091 ) 4092 return raw_text, [], "" 4093 4094 def _publish_agent_card(self) -> None: 4095 """ 4096 Schedules periodic publishing of the agent card based on configuration. 4097 """ 4098 try: 4099 publish_interval_sec = self.agent_card_publishing_config.get( 4100 "interval_seconds" 4101 ) 4102 if publish_interval_sec and publish_interval_sec > 0: 4103 log.info( 4104 "%s Scheduling agent card publishing every %d seconds.", 4105 self.log_identifier, 4106 publish_interval_sec, 4107 ) 4108 # Register timer with callback 4109 self.add_timer( 4110 delay_ms=1000, 4111 timer_id=self._card_publish_timer_id, 4112 interval_ms=publish_interval_sec * 1000, 4113 callback=lambda timer_data: publish_agent_card(self), 4114 ) 4115 else: 4116 log.warning( 4117 "%s Agent card publishing interval not configured or invalid, card will not be published periodically.", 4118 self.log_identifier, 4119 ) 4120 except Exception as e: 4121 log.exception( 4122 "%s Error during _publish_agent_card setup: %s", 4123 self.log_identifier, 4124 e, 4125 ) 4126 raise e 4127 4128 async def _async_setup_and_run(self) -> None: 4129 """ 4130 Main async logic for the agent component. 4131 This is called by the base class's `_run_async_operations`. 4132 """ 4133 try: 4134 # Call base class to initialize Trust Manager 4135 await super()._async_setup_and_run() 4136 4137 # Perform agent-specific async initialization 4138 await self._perform_async_init() 4139 4140 self._publish_agent_card() 4141 4142 except Exception as e: 4143 log.exception( 4144 "%s Error during _async_setup_and_run: %s", 4145 self.log_identifier, 4146 e, 4147 ) 4148 self.cleanup() 4149 raise e 4150 4151 def _pre_async_cleanup(self) -> None: 4152 """ 4153 Pre-cleanup actions for the agent component. 4154 Called by the base class before stopping the async loop. 4155 """ 4156 # Cleanup Trust Manager if present (ENTERPRISE FEATURE) 4157 if self.trust_manager: 4158 try: 4159 self.trust_manager.cleanup(self.cancel_timer) 4160 except Exception as e: 4161 log.error( 4162 "%s Error during Trust Manager cleanup: %s", self.log_identifier, e 4163 )