# agent_caller.py
"""
AgentCaller component for invoking agents via A2A.
"""

import logging
import uuid
import re
import json
from datetime import datetime, timezone
from typing import Any, Dict, Optional, TYPE_CHECKING

from a2a.types import MessageSendParams, SendMessageRequest, Message as A2AMessage

from ..common import a2a
from ..common.constants import ARTIFACT_TAG_WORKING
from ..common.data_parts import StructuredInvocationRequest
from ..common.agent_card_utils import get_schemas_from_agent_card
from ..agent.utils.artifact_helpers import (
    save_artifact_with_metadata,
    format_artifact_uri,
)
from .app import WorkflowNode, WorkflowInvokeNode
from .workflow_execution_context import WorkflowExecutionContext, WorkflowExecutionState

if TYPE_CHECKING:
    from .component import WorkflowExecutorComponent

log = logging.getLogger(__name__)

# Matches one ``{{ expression }}`` occurrence. Compiled once at import time
# rather than on every template-resolution call.
_TEMPLATE_PATTERN = re.compile(r"\{\{\s*(.+?)\s*\}\}")

# Reminder appended to outgoing requests when an output schema is in effect.
_RESULT_EMBED_REMINDER_WITH_SCHEMA = """
REMINDER: When you complete this task, you MUST end your response with:
«result:artifact=<your_artifact_name>:<version> status=success»

For example: «result:artifact=analysis_results.json:0 status=success»

This is required for the workflow to continue. Without this result embed, the workflow will fail.
"""

# Reminder appended to outgoing requests when no output schema is in effect.
_RESULT_EMBED_REMINDER_DEFAULT = """
REMINDER: When you complete this task, you MUST end your response with:
«result:artifact=<your_artifact_name>:<version> status=success»

This is MANDATORY for the workflow to continue.
"""


class _WorkflowNodeAdapter:
    """Adapter to make WorkflowInvokeNode work with agent calling infrastructure.

    Copies the attributes that ``_resolve_node_input`` and
    ``_construct_agent_message`` read from a node, exposing the workflow name
    under ``agent_name``.
    """

    def __init__(self, wf_node: WorkflowInvokeNode):
        self.id = wf_node.id
        self.type = "workflow"
        self.agent_name = wf_node.workflow_name  # Map workflow_name to agent_name
        self.input = wf_node.input
        self.instruction = wf_node.instruction
        self.input_schema_override = wf_node.input_schema_override
        self.output_schema_override = wf_node.output_schema_override
        self.depends_on = wf_node.depends_on


class AgentCaller:
    """Manages A2A calls to agents from workflow."""

    def __init__(self, host_component: "WorkflowExecutorComponent"):
        """Store the hosting component used for registry, publishing and caching."""
        self.host = host_component

    def _resolve_string_with_templates(
        self, template_string: str, workflow_state: WorkflowExecutionState
    ) -> Optional[str]:
        """
        Resolve a string that may contain embedded template expressions.

        Unlike dag_executor.resolve_value which only handles strings that ARE templates,
        this method handles strings that CONTAIN templates (e.g., "Hello {{name}}!").

        Args:
            template_string: A string that may contain {{...}} template expressions
            workflow_state: Current workflow state for resolving variables

        Returns:
            The string with every resolvable template expression replaced, or
            None when ``template_string`` is empty. Expressions that resolve to
            None or raise during resolution are left in place unchanged.
        """
        if not template_string:
            return None

        def replace_template(match: re.Match) -> str:
            """Replace a single template match with its resolved value."""
            full_match = match.group(0)  # The full {{...}} string
            try:
                # Use dag_executor to resolve the full template
                resolved = self.host.dag_executor.resolve_value(
                    full_match, workflow_state
                )
            except Exception as e:
                # Best effort: a broken expression must not abort the whole string.
                log.warning(
                    f"{self.host.log_identifier} Failed to resolve template "
                    f"'{full_match}': {e}"
                )
                return full_match
            # Keep the original template if resolution yields nothing
            return full_match if resolved is None else str(resolved)

        return _TEMPLATE_PATTERN.sub(replace_template, template_string)

    async def call_agent(
        self,
        node: WorkflowNode,
        workflow_state: WorkflowExecutionState,
        workflow_context: WorkflowExecutionContext,
        sub_task_id: Optional[str] = None,
    ) -> str:
        """
        Invoke an agent for a workflow node.

        Args:
            node: The workflow node describing the agent to call.
            workflow_state: Current workflow state used to resolve inputs/templates.
            workflow_context: Execution context that tracks outstanding calls.
            sub_task_id: Optional pre-assigned sub-task ID; generated when falsy.

        Returns:
            The sub-task ID for correlation.

        Raises:
            ValueError: If the agent is not in the registry or the node input
                cannot be resolved.
        """
        return await self._invoke_target(
            node=node,
            target_name=node.agent_name,
            not_found_message=(
                f"Agent '{node.agent_name}' not found in registry. "
                f"Ensure the agent is running and has published its agent card before "
                f"starting the workflow."
            ),
            workflow_state=workflow_state,
            workflow_context=workflow_context,
            sub_task_id=sub_task_id,
        )

    async def call_workflow(
        self,
        node: WorkflowInvokeNode,
        workflow_state: WorkflowExecutionState,
        workflow_context: WorkflowExecutionContext,
        sub_task_id: Optional[str] = None,
    ) -> str:
        """
        Invoke a sub-workflow.

        Workflows register as agents, so this method adapts the workflow node
        to use the agent calling mechanism.

        Args:
            node: The workflow-invoke node describing the sub-workflow to call.
            workflow_state: Current workflow state used to resolve inputs/templates.
            workflow_context: Execution context that tracks outstanding calls.
            sub_task_id: Optional pre-assigned sub-task ID; generated when falsy.

        Returns:
            The sub-task ID for correlation.

        Raises:
            ValueError: If the sub-workflow is not in the registry or the node
                input cannot be resolved.
        """
        log_id = f"{self.host.log_identifier}[CallWorkflow:{node.workflow_name}]"

        # Workflows publish their schemas in their agent cards, so the adapted
        # node can go through exactly the same path as a regular agent node.
        sub_task_id = await self._invoke_target(
            node=_WorkflowNodeAdapter(node),
            target_name=node.workflow_name,
            not_found_message=(
                f"Workflow '{node.workflow_name}' not found in registry. "
                f"Ensure the sub-workflow is running and has published its agent card before "
                f"starting the parent workflow."
            ),
            workflow_state=workflow_state,
            workflow_context=workflow_context,
            sub_task_id=sub_task_id,
        )

        log.info(
            f"{log_id} Invoked sub-workflow '{node.workflow_name}' (sub_task_id: {sub_task_id})"
        )

        return sub_task_id

    async def _invoke_target(
        self,
        node: Any,
        target_name: str,
        not_found_message: str,
        workflow_state: WorkflowExecutionState,
        workflow_context: WorkflowExecutionContext,
        sub_task_id: Optional[str],
    ) -> str:
        """
        Shared invocation pipeline behind ``call_agent`` and ``call_workflow``.

        Resolves the node input and instruction, looks up the target's agent
        card, builds the A2A message, publishes it and records the call.

        Args:
            node: A workflow node or adapter exposing ``id``, ``input``,
                ``depends_on``, the schema overrides and optionally ``instruction``.
            target_name: Registry name of the agent (or sub-workflow) to call.
            not_found_message: Error text used when the target has no agent card.
            workflow_state: Current workflow state.
            workflow_context: Execution context that tracks outstanding calls.
            sub_task_id: Pre-assigned sub-task ID, or a falsy value to generate one.

        Returns:
            The sub-task ID used for the request.

        Raises:
            ValueError: If the target is not registered or input resolution fails.
        """
        # Generate sub-task ID if not provided
        if not sub_task_id:
            sub_task_id = (
                f"wf_{workflow_state.execution_id}_{node.id}_{uuid.uuid4().hex[:8]}"
            )

        # Resolve input data
        input_data = await self._resolve_node_input(node, workflow_state)

        # Resolve instruction template if present
        # Handles both full templates ({{...}}) and embedded templates within strings
        resolved_instruction = None
        instruction = getattr(node, "instruction", None)
        if instruction:
            resolved_instruction = self._resolve_string_with_templates(
                instruction, workflow_state
            )

        # Get agent card - required for proper structured invocation
        agent_card = self.host.agent_registry.get_agent(target_name)
        if not agent_card:
            raise ValueError(not_found_message)

        # Get schemas from agent card extensions
        card_input_schema, card_output_schema = get_schemas_from_agent_card(agent_card)

        # Use override schemas if provided, otherwise use schemas from agent card
        input_schema = node.input_schema_override or card_input_schema
        output_schema = node.output_schema_override or card_output_schema

        # Construct A2A message
        message = await self._construct_agent_message(
            node,
            input_data,
            input_schema,
            output_schema,
            workflow_state,
            sub_task_id,
            workflow_context,
            resolved_instruction,
        )

        # Publish request
        await self._publish_agent_request(
            target_name, message, sub_task_id, workflow_context
        )

        # Track in workflow context
        workflow_context.track_agent_call(node.id, sub_task_id)

        return sub_task_id

    @staticmethod
    def _get_workflow_input(workflow_state: WorkflowExecutionState) -> Dict[str, Any]:
        """Return the stored workflow input, raising ValueError if it is absent."""
        if "workflow_input" not in workflow_state.node_outputs:
            raise ValueError("Workflow input has not been initialized")
        return workflow_state.node_outputs["workflow_input"]["output"]

    async def _resolve_node_input(
        self, node: WorkflowNode, workflow_state: WorkflowExecutionState
    ) -> Dict[str, Any]:
        """
        Resolve input mapping for a node.
        If input is not provided, infer it from dependencies.

        Args:
            node: The node whose input should be resolved.
            workflow_state: Current workflow state holding previous node outputs.

        Returns:
            The resolved input for the node.

        Raises:
            ValueError: If the workflow input is missing, the single dependency
                has not produced output yet, or the node has several
                dependencies without an explicit ``input`` mapping.
        """
        # Case 1: Explicit Input Mapping
        if node.input is not None:
            # Use DAGExecutor's resolve_value to handle templates and operators
            return {
                key: self.host.dag_executor.resolve_value(value, workflow_state)
                for key, value in node.input.items()
            }

        # Case 2: Implicit Input Inference
        log.debug(
            f"{self.host.log_identifier} Node '{node.id}' has no explicit input. Inferring from dependencies."
        )

        # Case 2a: No dependencies (Initial Node) -> Use Workflow Input
        if not node.depends_on:
            return self._get_workflow_input(workflow_state)

        # Case 2b: Single Dependency -> Use Dependency Output
        if len(node.depends_on) == 1:
            dep_id = node.depends_on[0]

            # Check if dependency is a switch node - use workflow input instead of switch metadata
            dep_node = self.host.dag_executor.nodes.get(dep_id)
            if dep_node and dep_node.type == "switch":
                log.debug(
                    f"{self.host.log_identifier} Node '{node.id}' depends on switch '{dep_id}'. Using workflow input."
                )
                return self._get_workflow_input(workflow_state)

            if dep_id not in workflow_state.node_outputs:
                raise ValueError(f"Dependency '{dep_id}' has not completed")
            return workflow_state.node_outputs[dep_id]["output"]

        # Case 2c: Multiple Dependencies -> Ambiguous
        raise ValueError(
            f"Node '{node.id}' has multiple dependencies {node.depends_on} but no explicit 'input' mapping. "
            "Implicit input inference is only supported for nodes with 0 or 1 dependency. "
            "Please provide an explicit 'input' mapping."
        )

    def _generate_result_embed_reminder(
        self, output_schema: Optional[Dict[str, Any]]
    ) -> str:
        """Generate user-facing reminder about result embed requirement.

        Args:
            output_schema: The output schema in effect for the call, if any.

        Returns:
            The longer reminder (with an example) when ``output_schema`` is
            truthy, otherwise the short reminder.
        """
        if output_schema:
            return _RESULT_EMBED_REMINDER_WITH_SCHEMA
        return _RESULT_EMBED_REMINDER_DEFAULT

    @staticmethod
    def _is_single_text_schema(input_schema: Optional[Dict[str, Any]]) -> bool:
        """Return True only for an object schema whose sole property is a string ``text``.

        Tolerates a missing/None ``properties`` entry and a non-dict ``text``
        sub-schema instead of raising.
        """
        if not input_schema:
            return False
        properties = input_schema.get("properties") or {}
        text_schema = properties.get("text") if isinstance(properties, dict) else None
        return (
            input_schema.get("type") == "object"
            and len(properties) == 1
            and isinstance(text_schema, dict)
            and text_schema.get("type") == "string"
        )

    async def _create_input_file_part(
        self,
        node: WorkflowNode,
        input_data: Dict[str, Any],
        sub_task_id: str,
        workflow_context: WorkflowExecutionContext,
    ) -> Any:
        """Save ``input_data`` as a JSON artifact and return a FilePart pointing at it.

        Raises:
            RuntimeError: If the artifact save does not report success.
            Exception: Any error raised while saving the artifact or building
                the part is logged and re-raised unchanged.
        """
        filename = f"input_{node.id}_{sub_task_id}.json"
        content_bytes = json.dumps(input_data).encode("utf-8")
        user_id = workflow_context.a2a_context["user_id"]
        session_id = workflow_context.a2a_context["session_id"]

        try:
            save_result = await save_artifact_with_metadata(
                artifact_service=self.host.artifact_service,
                app_name=self.host.workflow_name,
                user_id=user_id,
                session_id=session_id,
                filename=filename,
                content_bytes=content_bytes,
                mime_type="application/json",
                metadata_dict={
                    "description": f"Input for node {node.id}",
                    "source": "workflow_execution",
                },
                timestamp=datetime.now(timezone.utc),
                tags=[ARTIFACT_TAG_WORKING],
            )

            if save_result["status"] != "success":
                raise RuntimeError(
                    f"Failed to save input artifact: {save_result.get('message')}"
                )

            uri = format_artifact_uri(
                app_name=self.host.workflow_name,
                user_id=user_id,
                session_id=session_id,
                filename=filename,
                version=save_result["data_version"],
            )
            file_part = a2a.create_file_part_from_uri(
                uri=uri, name=filename, mime_type="application/json"
            )
        except Exception as e:
            log.exception(
                f"{self.host.log_identifier} Error saving input artifact for node "
                f"{node.id}: {e}"
            )
            # Bare raise keeps the original traceback intact.
            raise

        log.info(
            f"{self.host.log_identifier} Created input artifact for node "
            f"{node.id}: {filename}"
        )
        return file_part

    @staticmethod
    def _build_chat_parts(input_data: Dict[str, Any]) -> list:
        """Build the text parts used when input is sent inline (Chat Mode).

        Prefers the ``query`` key, then ``text``; otherwise renders every
        entry as ``key: value`` lines. Returns an empty list for empty input.
        """
        if "query" in input_data:
            return [a2a.create_text_part(text=input_data["query"])]
        if "text" in input_data:
            return [a2a.create_text_part(text=input_data["text"])]

        # Fallback for unstructured data without 'query'/'text' keys
        lines = [f"{key}: {value}" for key, value in input_data.items()]
        if not lines:
            return []
        return [a2a.create_text_part(text="\n".join(lines))]

    async def _construct_agent_message(
        self,
        node: WorkflowNode,
        input_data: Dict[str, Any],
        input_schema: Optional[Dict[str, Any]],
        output_schema: Optional[Dict[str, Any]],
        workflow_state: WorkflowExecutionState,
        sub_task_id: str,
        workflow_context: WorkflowExecutionContext,
        resolved_instruction: Optional[str] = None,
    ) -> A2AMessage:
        """Construct A2A message for agent.

        The message parts are, in order: the structured invocation request,
        the optional instruction text, the input (as a saved JSON artifact
        FilePart, or inline text for single-text schemas) and the result-embed
        reminder.

        Args:
            node: Node being invoked; its ``id`` is used in names and metadata.
            input_data: Resolved input for the node.
            input_schema: Input schema in effect, if any.
            output_schema: Output schema in effect, if any.
            workflow_state: Current workflow state (provides the workflow name).
            sub_task_id: Sub-task ID used as the message task ID.
            workflow_context: Execution context providing the A2A context.
            resolved_instruction: Instruction text with templates already resolved.

        Returns:
            The constructed user message.

        Raises:
            RuntimeError: If the input artifact cannot be saved.
        """
        # Generate unique output filename for this workflow node
        # Use last 8 chars of sub_task_id for uniqueness (contains UUID);
        # slicing already yields the whole ID when it is shorter than 8 chars.
        unique_suffix = sub_task_id[-8:]
        # Sanitize workflow name (replace spaces/special chars with underscore)
        safe_workflow_name = re.sub(
            r"[^a-zA-Z0-9_-]", "_", workflow_state.workflow_name
        )
        # node.id already includes iteration index for map nodes (e.g., "generate_data_0")
        suggested_output_filename = f"{safe_workflow_name}_{node.id}_{unique_suffix}.json"

        # 1. Structured invocation request (must be first)
        invocation_request = StructuredInvocationRequest(
            type="structured_invocation_request",
            workflow_name=workflow_state.workflow_name,
            node_id=node.id,
            input_schema=input_schema,
            output_schema=output_schema,
            suggested_output_filename=suggested_output_filename,
        )
        parts = [a2a.create_data_part(data=invocation_request.model_dump())]

        # 2. Add instruction text part if provided
        if resolved_instruction and resolved_instruction.strip():
            parts.append(a2a.create_text_part(text=resolved_instruction))

        # Determine if we should send as structured artifact or text
        # For structured invocations (workflow calls), we ALWAYS send input as FilePart
        # unless it's explicitly a single text schema. This ensures the receiver can
        # properly handle the structured input even if we don't have the agent's schema
        # yet (e.g., due to timing issues with agent card discovery).
        if self._is_single_text_schema(input_schema):
            # Send as text/data parts (Chat Mode)
            parts.extend(self._build_chat_parts(input_data))
        else:
            # Create and save input artifact, then add FilePart with URI
            parts.append(
                await self._create_input_file_part(
                    node, input_data, sub_task_id, workflow_context
                )
            )

        # Add reminder about result embed requirement
        reminder_text = self._generate_result_embed_reminder(output_schema)
        parts.append(a2a.create_text_part(text=reminder_text))

        # Construct message using helper function
        # Use the original workflow session ID as context_id so that RUN_BASED sessions
        # will be created as {workflow_session_id}:{sub_task_id}:run, allowing the workflow
        # to find artifacts saved by the node using get_original_session_id()
        return a2a.create_user_message(
            parts=parts,
            task_id=sub_task_id,
            context_id=workflow_context.a2a_context["session_id"],
            metadata={
                "workflow_name": workflow_state.workflow_name,
                "node_id": node.id,
                "sub_task_id": sub_task_id,
                "parentTaskId": workflow_context.workflow_task_id,
            },
        )

    async def _publish_agent_request(
        self,
        agent_name: str,
        message: A2AMessage,
        sub_task_id: str,
        workflow_context: WorkflowExecutionContext,
    ) -> None:
        """Publish A2A request to agent.

        Sends the message to the agent's request topic with reply/status
        topics and an incremented call depth, then registers the sub-task in
        the host cache with the configured node timeout as expiry.

        Args:
            agent_name: Name of the target agent.
            message: The A2A message to send.
            sub_task_id: Sub-task ID; used as request ID and cache key.
            workflow_context: Execution context providing the A2A context.
        """
        log_id = f"{self.host.log_identifier}[PublishAgentRequest:{agent_name}]"

        # Get agent request topic
        request_topic = a2a.get_agent_request_topic(self.host.namespace, agent_name)

        # Create SendMessageRequest
        send_params = MessageSendParams(message=message)
        a2a_request = SendMessageRequest(id=sub_task_id, params=send_params)

        # Construct reply-to and status topics
        reply_to_topic = a2a.get_agent_response_topic(
            self.host.namespace, self.host.workflow_name, sub_task_id
        )
        status_topic = a2a.get_peer_agent_status_topic(
            self.host.namespace, self.host.workflow_name, sub_task_id
        )

        # Get current call depth and increment for outgoing request
        current_depth = workflow_context.a2a_context.get("call_depth", 0)

        # User properties
        user_properties = {
            "replyTo": reply_to_topic,
            "a2aStatusTopic": status_topic,
            "userId": workflow_context.a2a_context["user_id"],
            "a2aUserConfig": workflow_context.a2a_context.get("a2a_user_config", {}),
            "callDepth": current_depth + 1,
        }

        # Publish request
        self.host.publish_a2a_message(
            payload=a2a_request.model_dump(by_alias=True, exclude_none=True),
            topic=request_topic,
            user_properties=user_properties,
        )

        log.debug(
            f"{log_id} Published agent request to {request_topic} (sub_task_id: {sub_task_id})"
        )

        # Set timeout tracking
        timeout_seconds = self.host.get_config("default_node_timeout_seconds", 300)
        self.host.cache_service.add_data(
            key=sub_task_id,
            value=workflow_context.workflow_task_id,
            expiry=timeout_seconds,
            component=self.host,
        )