workflow_tool.py
"""
ADK Tool implementation for invoking Workflow agents via A2A.
"""

import logging
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple

import jsonschema
from google.adk.tools import BaseTool, ToolContext
from google.genai import types as adk_types

from ...common import a2a
from ...common.constants import ARTIFACT_TAG_WORKING, DEFAULT_COMMUNICATION_TIMEOUT
from ...common.exceptions import MessageSizeExceededError
from ...common.data_parts import StructuredInvocationRequest
from ...agent.utils.artifact_helpers import (
    save_artifact_with_metadata,
    format_artifact_uri,
)

log = logging.getLogger(__name__)

WORKFLOW_TOOL_PREFIX = "workflow_"
CORRELATION_DATA_PREFIX = "a2a_subtask_"

# Extra tool parameter that selects Artifact Mode. It is added to the function
# declaration by this tool and is NOT part of the workflow's own input schema,
# so it must be removed before validating/serializing Parameter Mode input.
INPUT_ARTIFACT_PARAM = "input_artifact"

# JSON Schema primitive type name -> ADK schema type. Unknown names fall back
# to STRING (see _json_schema_to_adk_schema).
_JSON_TYPE_TO_ADK_TYPE = {
    "string": adk_types.Type.STRING,
    "integer": adk_types.Type.INTEGER,
    "number": adk_types.Type.NUMBER,
    "boolean": adk_types.Type.BOOLEAN,
    "array": adk_types.Type.ARRAY,
    "object": adk_types.Type.OBJECT,
}


class WorkflowAgentTool(BaseTool):
    """
    An ADK Tool that represents a discovered Workflow agent.
    Supports dual-mode invocation:
    1. Parameter Mode: Pass arguments directly (validated against schema).
    2. Artifact Mode: Pass an 'input_artifact' reference.

    The tool is long-running: ``run_async`` submits an A2A sub-task and
    returns ``None`` on success (fire-and-forget). Correlation data keyed by
    the sub-task id (including the ADK function call id) is registered on the
    parent task's execution context, presumably so the eventual workflow
    response can be routed back to this function call.
    """

    is_long_running = True

    def __init__(
        self,
        target_agent_name: str,
        input_schema: Dict[str, Any],
        host_component,
    ):
        """
        Initializes the WorkflowAgentTool.

        Args:
            target_agent_name: The name of the workflow agent.
            input_schema: The JSON schema defining the workflow's input parameters.
            host_component: A reference to the SamAgentComponent instance.
        """
        tool_name = f"{WORKFLOW_TOOL_PREFIX}{target_agent_name}"
        # Sanitize tool name if necessary (replace hyphens with underscores)
        tool_name = tool_name.replace("-", "_")

        super().__init__(
            name=tool_name,
            description=f"Invoke the '{target_agent_name}' workflow.",
            is_long_running=True,
        )
        self.target_agent_name = target_agent_name
        self.input_schema = input_schema
        self.host_component = host_component
        self.log_identifier = (
            f"{host_component.log_identifier}[WorkflowTool:{target_agent_name}]"
        )

    def _json_schema_to_adk_schema(
        self, json_schema: Dict[str, Any], nullable: bool = False
    ) -> adk_types.Schema:
        """
        Recursively converts a JSON schema to an ADK Schema, preserving nested structure.

        - A ``type`` given as a list (e.g. ``["string", "null"]``) uses its
          first non-"null" member; a "null" member marks the schema nullable.
        - Missing or unknown types fall back to STRING.
        - Non-dict sub-schemas (boolean schemas, tuple-form ``items``) carry
          no usable type information and are treated as an empty schema.
        - Object schemas keep their ``required`` list, restricted to names
          that appear in ``properties``.

        Args:
            json_schema: The JSON schema definition to convert.
            nullable: Whether the schema should be nullable.

        Returns:
            An ADK Schema object representing the JSON schema.
        """
        if not isinstance(json_schema, dict):
            json_schema = {}

        json_type = json_schema.get("type", "string")
        if isinstance(json_type, list):
            # A list is unhashable and would crash the mapping lookup below.
            concrete_types = [t for t in json_type if t != "null"]
            if len(concrete_types) != len(json_type):
                nullable = True
            json_type = concrete_types[0] if concrete_types else "string"

        schema_kwargs: Dict[str, Any] = {
            "type": _JSON_TYPE_TO_ADK_TYPE.get(json_type, adk_types.Type.STRING),
            "description": json_schema.get("description", ""),
            "nullable": nullable,
        }

        # Handle array items
        if json_type == "array" and "items" in json_schema:
            schema_kwargs["items"] = self._json_schema_to_adk_schema(
                json_schema["items"]
            )

        # Handle object properties (and their required list)
        if json_type == "object" and "properties" in json_schema:
            nested_properties = {
                prop_name: self._json_schema_to_adk_schema(prop_def)
                for prop_name, prop_def in json_schema["properties"].items()
            }
            schema_kwargs["properties"] = nested_properties
            nested_required = [
                name
                for name in json_schema.get("required", [])
                if name in nested_properties
            ]
            if nested_required:
                schema_kwargs["required"] = nested_required

        return adk_types.Schema(**schema_kwargs)

    def _get_declaration(self) -> adk_types.FunctionDeclaration:
        """
        Dynamically generates the FunctionDeclaration based on the workflow's input schema.
        Adds 'input_artifact' as an optional parameter and marks all parameters as optional
        (nullable, none required) to support dual-mode invocation.
        """
        properties = self.input_schema.get("properties", {})
        adk_properties = {}

        # Add input_artifact parameter
        adk_properties[INPUT_ARTIFACT_PARAM] = adk_types.Schema(
            type=adk_types.Type.STRING,
            description="Filename of an existing artifact containing the input JSON data. Use this OR individual parameters.",
            nullable=True,
        )

        for prop_name, prop_def in properties.items():
            adk_properties[prop_name] = self._json_schema_to_adk_schema(
                prop_def, nullable=True
            )

        parameters_schema = adk_types.Schema(
            type=adk_types.Type.OBJECT,
            properties=adk_properties,
            required=[],  # All optional
        )

        return adk_types.FunctionDeclaration(
            name=self.name,
            description=f"Invoke the '{self.target_agent_name}' workflow. Dual-mode: provide parameters directly OR 'input_artifact'.",
            parameters=parameters_schema,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> Any:
        """
        Handles the workflow invocation.

        Steps: prepare the input artifact, build the A2A message, then
        register correlation data and submit the sub-task.

        Returns:
            ``None`` once the sub-task has been submitted (long-running,
            fire-and-forget), or an error dict
            ``{"status": "error", "message": ...}`` on input validation
            failure, message-size overflow, or any other exception.
        """
        sub_task_id = f"{CORRELATION_DATA_PREFIX}{uuid.uuid4().hex}"
        log_identifier = f"{self.log_identifier}[SubTask:{sub_task_id}]"

        try:
            # 1. Prepare Input Artifact
            try:
                (
                    payload_artifact_name,
                    payload_artifact_version,
                ) = await self._prepare_input_artifact(
                    args, tool_context, log_identifier
                )
            except jsonschema.ValidationError as e:
                log.warning(
                    "%s Input validation failed | message=%s",
                    log_identifier,
                    e.message,
                )
                error_response = {
                    "status": "error",
                    "message": f"Input validation failed: {e.message}. Please provide required parameters or use input_artifact.",
                }
                return error_response

            # 2. Prepare Context
            # 'a2a_context' may be missing or explicitly None in state.
            original_task_context = tool_context.state.get("a2a_context") or {}
            main_logical_task_id = original_task_context.get(
                "logical_task_id", "unknown_task"
            )
            invocation_context = tool_context._invocation_context
            invocation_id = invocation_context.invocation_id
            user_id = invocation_context.user_id
            session_id = invocation_context.session.id
            user_config = original_task_context.get("a2a_user_config", {})

            # 3. Prepare Message
            a2a_message = self._prepare_a2a_message(
                payload_artifact_name,
                payload_artifact_version,
                user_id,
                session_id,
                main_logical_task_id,
                original_task_context,
            )

            # 4. Submit Task
            try:
                self._submit_workflow_task(
                    sub_task_id,
                    main_logical_task_id,
                    invocation_id,
                    tool_context,
                    original_task_context,
                    a2a_message,
                    user_id,
                    user_config,
                    log_identifier,
                )
            except MessageSizeExceededError as e:
                log.error("%s Message size exceeded: %s", log_identifier, e)
                return {
                    "status": "error",
                    "message": f"Error: {str(e)}. Message size exceeded.",
                }

            return None  # Fire-and-forget

        except Exception as e:
            log.exception("%s Error in WorkflowAgentTool: %s", log_identifier, e)
            return {
                "status": "error",
                "message": f"Failed to invoke workflow '{self.target_agent_name}': {e}",
            }

    async def _prepare_input_artifact(
        self, args: Dict[str, Any], tool_context: ToolContext, log_identifier: str
    ) -> Tuple[str, Optional[int]]:
        """
        Determines input mode, validates parameters, and creates implicit artifact if needed.

        Artifact Mode (``input_artifact`` given): the artifact's latest stored
        version is resolved so the A2A message can reference it by a concrete,
        versioned URI.

        Parameter Mode: the ``input_artifact`` key and ``None`` values are
        dropped (every parameter is declared nullable in ``_get_declaration``
        only to make it optional, so ``None`` means "not provided"). The
        remainder is validated against the workflow's input schema and saved
        as a new JSON artifact.

        Returns:
            ``(artifact_name, artifact_version)``.

        Raises:
            jsonschema.ValidationError: Parameter Mode input violates the schema.
            ValueError: The referenced input artifact has no stored versions.
            RuntimeError: The implicit input artifact could not be saved.
        """
        input_artifact_name = args.get(INPUT_ARTIFACT_PARAM)
        invocation_context = tool_context._invocation_context
        user_id = invocation_context.user_id
        session_id = invocation_context.session.id

        if input_artifact_name:
            log.info(
                "%s Invoking in Artifact Mode with '%s'",
                log_identifier,
                input_artifact_name,
            )
            # A concrete version is required: _prepare_a2a_message attaches the
            # input FilePart only when it has one, and without that part the
            # workflow would receive no input at all.
            artifact_version = await self._resolve_latest_artifact_version(
                input_artifact_name, user_id, session_id
            )
            return input_artifact_name, artifact_version

        # Parameter Mode - strip tool-only/null entries, then validate strictly
        payload_data = {
            key: value
            for key, value in args.items()
            if key != INPUT_ARTIFACT_PARAM and value is not None
        }
        try:
            jsonschema.validate(instance=payload_data, schema=self.input_schema)
        except jsonschema.ValidationError as ve:
            log.warning(
                "%s Schema validation failed | error=%s | path=%s",
                log_identifier,
                ve.message,
                list(ve.absolute_path),
            )
            raise

        # Create implicit artifact
        payload_bytes = json.dumps(payload_data).encode("utf-8")

        # Generate unique filename using UUID to avoid collisions in parallel invocations
        sanitized_wf_name = "".join(
            c for c in self.target_agent_name if c.isalnum() or c in "_-"
        )
        unique_suffix = uuid.uuid4().hex[:8]
        payload_artifact_name = f"wi_{sanitized_wf_name}_{unique_suffix}.json"

        save_result = await save_artifact_with_metadata(
            artifact_service=self.host_component.artifact_service,
            app_name=self.host_component.agent_name,
            user_id=user_id,
            session_id=session_id,
            filename=payload_artifact_name,
            content_bytes=payload_bytes,
            mime_type="application/json",
            metadata_dict={
                "description": f"Auto-generated input for workflow '{self.target_agent_name}'",
                "source": "workflow_tool_implicit_creation",
            },
            timestamp=datetime.now(timezone.utc),
            tags=[ARTIFACT_TAG_WORKING],
        )

        if save_result["status"] != "success":
            raise RuntimeError(
                f"Failed to save implicit input artifact: {save_result.get('message')}"
            )

        payload_artifact_version = save_result.get("data_version")

        log.info(
            "%s Created implicit input artifact: %s v%s",
            log_identifier,
            payload_artifact_name,
            payload_artifact_version,
        )

        return payload_artifact_name, payload_artifact_version

    async def _resolve_latest_artifact_version(
        self, filename: str, user_id: str, session_id: str
    ) -> int:
        """
        Returns the highest stored version of ``filename`` for this agent/session.

        Uses the same ``app_name`` (the host agent's name) that implicit input
        artifacts are saved under.

        Raises:
            ValueError: No versions of the artifact exist.
        """
        versions = await self.host_component.artifact_service.list_versions(
            app_name=self.host_component.agent_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
        )
        if not versions:
            raise ValueError(
                f"Input artifact '{filename}' not found for workflow '{self.target_agent_name}'."
            )
        return max(versions)

    def _prepare_a2a_message(
        self,
        payload_artifact_name: str,
        payload_artifact_version: Optional[int],
        user_id: str,
        session_id: str,
        main_logical_task_id: str,
        original_task_context: Dict[str, Any],
    ) -> Any:
        """
        Constructs the A2A message with StructuredInvocationRequest and FilePart.

        The FilePart (a versioned artifact URI) is attached only when both an
        artifact name and a version are available.
        """
        parts = []

        # 1. Add StructuredInvocationRequest DataPart (triggers structured invocation)
        invocation_request = StructuredInvocationRequest(
            type="structured_invocation_request",
            workflow_name=self.host_component.agent_name,
            node_id=f"workflow_tool_{self.target_agent_name}",
            input_schema=self.input_schema,
            output_schema=None,  # Workflow defines its own output schema
            suggested_output_filename=f"wf_{self.target_agent_name}_result.json",
        )
        parts.append(a2a.create_data_part(data=invocation_request.model_dump()))

        # 2. Add FilePart with artifact URI (contains the input data)
        if payload_artifact_name and payload_artifact_version is not None:
            uri = format_artifact_uri(
                app_name=self.host_component.agent_name,
                user_id=user_id,
                session_id=session_id,
                filename=payload_artifact_name,
                version=payload_artifact_version,
            )
            parts.append(
                a2a.create_file_part_from_uri(
                    uri=uri,
                    name=payload_artifact_name,
                    mime_type="application/json",
                )
            )

        a2a_metadata = {
            "sessionBehavior": "RUN_BASED",
            "parentTaskId": main_logical_task_id,
            "agent_name": self.target_agent_name,
        }

        return a2a.create_user_message(
            parts=parts,
            metadata=a2a_metadata,
            context_id=original_task_context.get("contextId"),
        )

    def _submit_workflow_task(
        self,
        sub_task_id: str,
        main_logical_task_id: str,
        invocation_id: str,
        tool_context: ToolContext,
        original_task_context: Dict[str, Any],
        a2a_message: Any,
        user_id: str,
        user_config: Dict[str, Any],
        log_identifier: str,
    ):
        """
        Handles task registration, correlation data, and submission.

        Raises:
            ValueError: No TaskExecutionContext exists for ``main_logical_task_id``.
        """
        with self.host_component.active_tasks_lock:
            task_context_obj = self.host_component.active_tasks.get(
                main_logical_task_id
            )
            # Snapshot keys under the same lock used for the lookup, so the
            # diagnostic below does not read the dict unguarded.
            active_task_ids = (
                None
                if task_context_obj
                else list(self.host_component.active_tasks.keys())
            )

        if not task_context_obj:
            log.error(
                "%s TaskExecutionContext NOT FOUND | main_task_id=%s | active_tasks_keys=%s",
                log_identifier,
                main_logical_task_id,
                active_task_ids,
            )
            raise ValueError(
                f"TaskExecutionContext not found for task '{main_logical_task_id}'"
            )

        # NOTE: register_parallel_call_sent is now called in
        # preregister_long_running_tools_callback (after_model_callback)
        # BEFORE tool execution begins. This prevents race conditions where
        # one tool completes before another registers.

        # Submit Task
        correlation_data = {
            "adk_function_call_id": tool_context.function_call_id,
            "original_task_context": original_task_context,
            "peer_tool_name": self.name,
            "peer_agent_name": self.target_agent_name,
            "logical_task_id": main_logical_task_id,
            "invocation_id": invocation_id,
        }

        task_context_obj.register_peer_sub_task(sub_task_id, correlation_data)

        timeout_sec = self.host_component.get_config(
            "inter_agent_communication", {}
        ).get("request_timeout_seconds", DEFAULT_COMMUNICATION_TIMEOUT)

        # Map sub-task id -> parent task id, expiring with the request timeout.
        self.host_component.cache_service.add_data(
            key=sub_task_id,
            value=main_logical_task_id,
            expiry=timeout_sec,
            component=self.host_component,
        )

        self.host_component.submit_a2a_task(
            target_agent_name=self.target_agent_name,
            a2a_message=a2a_message,
            user_id=user_id,
            user_config=user_config,
            sub_task_id=sub_task_id,
        )

        log.info(
            "%s Workflow task submitted for agent: %s", log_identifier, self.target_agent_name
        )