handler.py
"""
StructuredInvocationHandler implementation.

Enables agents to be invoked with schema-validated input/output,
functioning as a "structured function call" pattern. Used by workflows
and other programmatic callers that require predictable, validated responses.
"""

import logging
import json
import asyncio
import re
import yaml
import csv
import io
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from pydantic import ValidationError
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.agents.callback_context import CallbackContext
from google.adk.events import Event as ADKEvent
from google.genai import types as adk_types
from google.adk.agents import RunConfig
from google.adk.agents.run_config import StreamingMode

from a2a.types import (
    Message as A2AMessage,
    FilePart,
    FileWithBytes,
    FileWithUri,
    TaskState,
)

from ....common import a2a
from ....common.constants import ARTIFACT_TAG_WORKING
from ....common.data_parts import (
    ArtifactRef,
    StructuredInvocationRequest,
    StructuredInvocationResult,
)
from ....agent.adk.runner import run_adk_async_task_thread_wrapper
from ....common.utils.embeds.constants import EMBED_REGEX
from ....agent.utils.artifact_helpers import parse_artifact_uri

if TYPE_CHECKING:
    from ..component import SamAgentComponent

log = logging.getLogger(__name__)


class ResultEmbed:
    """Parsed result embed from agent output.

    Represents one «result:artifact=... status=... message=...» marker
    extracted from the agent's final text response.
    """

    def __init__(
        self,
        artifact_name: Optional[str] = None,
        version: Optional[int] = None,
        status: str = "success",
        message: Optional[str] = None,
    ) -> None:
        # Name of the output artifact the agent claims to have produced.
        self.artifact_name = artifact_name
        # Explicit artifact version, if the agent included one ("name:v2" / "name:2").
        self.version = version
        # "success" or "error" as reported by the agent (defaults to "success").
        self.status = status
        # Optional free-text message, typically a failure reason.
        self.message = message


class StructuredInvocationHandler:
    """
    Handles structured invocation logic for an agent.

    Enables agents to be invoked with schema-validated input and output,
    supporting retry on validation failure. Used by workflows and other
    programmatic callers that need predictable, validated responses.
    """

    def __init__(self, host_component: "SamAgentComponent") -> None:
        # The SamAgentComponent that owns the agent lifecycle and services.
        self.host = host_component
        # Agent-level default schemas; per-invocation schemas (if provided
        # in the request) take precedence over these in execute_structured_invocation().
        self.input_schema = host_component.get_config("input_schema")
        self.output_schema = host_component.get_config("output_schema")
        # Maximum number of retries after output-validation failure or a
        # missing result embed (default: 2).
        self.max_validation_retries = host_component.get_config(
            "validation_max_retries", 2
        )

    def extract_structured_invocation_context(
        self, message: A2AMessage
    ) -> Optional[StructuredInvocationRequest]:
        """
        Extract structured invocation context from message if present.
        Structured invocation messages contain StructuredInvocationRequest in a DataPart.

        Note: The DataPart may not be first in the message - the base gateway prepends
        a timestamp TextPart. We scan all DataParts to find the request.

        Returns:
            The parsed StructuredInvocationRequest, or None if the message has no
            matching DataPart or the matching DataPart fails validation.
        """
        if not message.parts:
            return None

        # Scan all DataParts for structured invocation request
        # The base gateway may prepend other parts (e.g., timestamp), so we can't assume position
        data_parts = a2a.get_data_parts_from_message(message)

        for data_part in data_parts:
            # Check if this DataPart contains a structured_invocation_request
            data = data_part.data if hasattr(data_part, "data") else None
            if not data or not isinstance(data, dict):
                continue

            if data.get("type") != "structured_invocation_request":
                continue

            # Found it - parse and return
            try:
                invocation_data = StructuredInvocationRequest.model_validate(data)
                return invocation_data
            except ValidationError as e:
                # A DataPart claimed to be a request but did not validate;
                # treat the whole message as not a structured invocation.
                log.error(f"{self.host.log_identifier} Invalid structured invocation request data: {e}")
                return None

        return None

    async def execute_structured_invocation(
        self,
        message: A2AMessage,
        invocation_data: StructuredInvocationRequest,
        a2a_context: Dict[str, Any],
        original_solace_message: Any = None,
    ):
        """Execute agent as a structured invocation with schema validation.

        Validates the input against the effective input schema, then delegates
        to _execute_with_output_validation(). Any failure (input validation or
        unhandled exception) is reported back to the workflow as an error-status
        StructuredInvocationResult rather than raised to the caller.
        """
        log_id = f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}]"

        log.debug(
            f"{log_id} Received structured invocation request. Context: {invocation_data.workflow_name}, "
            f"node_id: {invocation_data.node_id}, suggested_output_filename: {invocation_data.suggested_output_filename}"
        )

        try:
            # Determine effective schemas: per-invocation schemas override the
            # agent-level defaults from config.
            input_schema = invocation_data.input_schema or self.input_schema
            output_schema = invocation_data.output_schema or self.output_schema

            # Default input schema to single text field if not provided
            if not input_schema:
                input_schema = {
                    "type": "object",
                    "properties": {"text": {"type": "string"}},
                    "required": ["text"],
                }
                log.debug(
                    f"{log_id} No input schema provided, using default text schema"
                )

            # Validate input against schema
            validation_errors = await self._validate_input(
                message, input_schema, a2a_context, log_id
            )

            if validation_errors:
                log.error(f"{log_id} Input validation failed: {validation_errors}")

                # Return validation error immediately
                result_data = StructuredInvocationResult(
                    type="structured_invocation_result",
                    status="error",
                    error_message=f"Input validation failed: {validation_errors}",
                )
                return await self._return_structured_result(
                    invocation_data, result_data, a2a_context
                )

            # Input valid, proceed with execution
            return await self._execute_with_output_validation(
                message,
                invocation_data,
                output_schema,
                a2a_context,
                original_solace_message,
            )

        except Exception as e:
            # Catch any unhandled exceptions and return as structured invocation failure
            log.warning(f"{log_id} Structured invocation execution failed: {e}", exc_info=True)

            result_data = StructuredInvocationResult(
                type="structured_invocation_result",
                status="error",
                error_message=f"Node execution error: {str(e)}",
            )
            return await self._return_structured_result(
                invocation_data, result_data, a2a_context
            )

    async def _validate_input(
        self,
        message: A2AMessage,
        input_schema: Dict[str, Any],
        a2a_context: Dict[str, Any],
        log_id: str = "",
    ) -> Optional[List[str]]:
        """
        Validate message content against input schema.
        Returns list of validation errors or None if valid.
        """
        from .validator import validate_against_schema

        # Extract input data from message (may load/save artifacts as a side effect)
        input_data = await self._extract_input_data(message, input_schema, a2a_context)

        log.debug(
            f"{log_id} Resolved input data: {json.dumps(input_data, default=str)}"
        )

        # Validate against schema
        errors = validate_against_schema(input_data, input_schema)

        return errors if errors else None

    async def _extract_input_data(
        self,
        message: A2AMessage,
        input_schema: Dict[str, Any],
        a2a_context: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Extract structured input data from message parts.

        Handles two cases:
        1. Single text field schema: Aggregates all text parts into 'text' field
        2. Structured schema: Extracts from first FilePart (JSON/YAML/CSV)

        Returns:
            Validated input data dictionary
        """
        log_id = f"{self.host.log_identifier}[ExtractInput]"

        # Check if this is a single text field schema
        if self._is_single_text_schema(input_schema):
            log.debug(f"{log_id} Using single text field extraction")
            return await self._extract_text_input(message)

        # Otherwise, extract from FilePart
        log.debug(f"{log_id} Using structured FilePart extraction")
        return await self._extract_file_input(message, input_schema, a2a_context)

    def _is_single_text_schema(self, schema: Dict[str, Any]) -> bool:
        """
        Check if schema represents a single text field.
        Returns True if schema has exactly one property named 'text' of type 'string'.
        """
        if schema.get("type") != "object":
            return False

        properties = schema.get("properties", {})
        if len(properties) != 1:
            return False

        if "text" not in properties:
            return False

        return properties["text"].get("type") == "string"

    async def _extract_text_input(self, message: A2AMessage) -> Dict[str, Any]:
        """
        Extract text input by aggregating all text parts.
        Returns: {"text": "<aggregated_text>"}
        """
        # message.parts are wrapped (discriminated-union roots); unwrap first.
        unwrapped_parts = [p.root for p in message.parts]
        text_parts = []

        for part in unwrapped_parts:
            if hasattr(part, "text") and part.text:
                text_parts.append(part.text)

        # Join non-empty text parts with newlines; empty message yields "".
        aggregated_text = "\n".join(text_parts) if text_parts else ""
        return {"text": aggregated_text}

    async def _extract_file_input(
        self,
        message: A2AMessage,
        input_schema: Dict[str, Any],
        a2a_context: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Extract input data from first FilePart in message.
        Handles both inline bytes and URI references.

        Raises:
            ValueError: If no FilePart is present or the FilePart kind is unknown.
        """
        log_id = f"{self.host.log_identifier}[ExtractFile]"

        # Find first FilePart
        file_parts = a2a.get_file_parts_from_message(message)

        if not file_parts:
            raise ValueError("No FilePart found in message for structured schema")

        file_part = file_parts[0]

        # Determine if this is bytes or URI
        if a2a.is_file_part_bytes(file_part):
            log.debug(f"{log_id} Processing FileWithBytes")
            return await self._process_file_with_bytes(
                file_part, input_schema, a2a_context
            )
        elif a2a.is_file_part_uri(file_part):
            log.debug(f"{log_id} Processing FileWithUri")
            return await self._process_file_with_uri(file_part, a2a_context)
        else:
            raise ValueError(f"Unknown FilePart type: {type(file_part)}")

    async def _process_file_with_bytes(
        self,
        file_part: FilePart,
        input_schema: Dict[str, Any],
        a2a_context: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Process inline file bytes: decode, validate, and save to artifact store.

        The raw bytes are persisted as a working artifact (tagged
        ARTIFACT_TAG_WORKING) so the agent can reference the input later;
        the decoded Python object is returned for schema validation.
        """
        log_id = f"{self.host.log_identifier}[ProcessBytes]"

        # Decode bytes according to MIME type
        mime_type = a2a.get_mimetype_from_file_part(file_part)
        content_bytes = a2a.get_bytes_from_file_part(file_part)

        if content_bytes is None:
            raise ValueError("FilePart has no content bytes")

        data = self._decode_file_bytes(content_bytes, mime_type)

        log.debug(f"{log_id} Decoded {mime_type} file data")

        # Save to artifact store with appropriate name
        artifact_name = self._generate_input_artifact_name(mime_type)

        # Use helper to save artifact
        from ....agent.utils.artifact_helpers import save_artifact_with_metadata

        await save_artifact_with_metadata(
            artifact_service=self.host.artifact_service,
            app_name=self.host.agent_name,
            user_id=a2a_context["user_id"],
            session_id=a2a_context["effective_session_id"],
            filename=artifact_name,
            content_bytes=content_bytes,
            mime_type=mime_type,
            metadata_dict={"source": "workflow_input"},
            timestamp=datetime.now(timezone.utc),
            tags=[ARTIFACT_TAG_WORKING],
        )

        log.info(f"{log_id} Saved input data to artifact: {artifact_name}")

        return data

    async def _process_file_with_uri(
        self, file_part: FilePart, a2a_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Process file URI: load artifact and decode.

        Raises:
            ValueError: If the FilePart has no URI, the URI is malformed, or the
                referenced artifact is missing/empty.
        """
        log_id = f"{self.host.log_identifier}[ProcessURI]"

        # Parse URI to extract artifact name and version
        uri = a2a.get_uri_from_file_part(file_part)
        if not uri:
            raise ValueError("FilePart has no URI")

        try:
            uri_parts = parse_artifact_uri(uri)
        except ValueError as e:
            raise ValueError(f"Invalid artifact URI: {e}")

        log.debug(f"{log_id} Loading artifact from URI: {uri}")

        # Load artifact using the context from the URI (app_name, user_id, session_id)
        # This ensures we can read artifacts created by the workflow orchestrator
        artifact = await self.host.artifact_service.load_artifact(
            app_name=uri_parts["app_name"],
            user_id=uri_parts["user_id"],
            session_id=uri_parts["session_id"],
            filename=uri_parts["filename"],
            version=uri_parts["version"],
        )

        if not artifact or not artifact.inline_data:
            raise ValueError(
                f"Artifact not found or has no data: {uri_parts['filename']}"
            )

        # Decode artifact data
        mime_type = artifact.inline_data.mime_type
        data = self._decode_file_bytes(artifact.inline_data.data, mime_type)

        log.info(f"{log_id} Loaded and decoded artifact: {uri_parts['filename']}")

        return data

    def _decode_file_bytes(self, data: bytes, mime_type: str) -> Dict[str, Any]:
        """
        Decode file bytes according to MIME type.
        Supports: application/json, application/yaml, text/yaml, text/csv

        Raises:
            ValueError: For any MIME type outside the supported set.
        """
        # NOTE(review): log_id is currently unused in this method.
        log_id = f"{self.host.log_identifier}[Decode]"

        if mime_type in ["application/json", "text/json"]:
            return json.loads(data.decode("utf-8"))

        elif mime_type in ["application/yaml", "text/yaml", "application/x-yaml"]:
            return yaml.safe_load(data.decode("utf-8"))

        elif mime_type in ["text/csv", "application/csv"]:
            # CSV to dict list: each row becomes a dict keyed by header names,
            # wrapped as {"rows": [...]} so the result is a dict like the others.
            csv_text = data.decode("utf-8")
            reader = csv.DictReader(io.StringIO(csv_text))
            return {"rows": list(reader)}

        else:
            raise ValueError(f"Unsupported MIME type for input data: {mime_type}")

    def _generate_input_artifact_name(self, mime_type: str) -> str:
        """
        Generate artifact name for input data based on MIME type.
        Format: {agent-name}_input_data.{ext}
        """
        ext_map = {
            "application/json": "json",
            "text/json": "json",
            "application/yaml": "yaml",
            "text/yaml": "yaml",
            "application/x-yaml": "yaml",
            "text/csv": "csv",
            "application/csv": "csv",
        }

        # Unknown MIME types fall back to a generic ".dat" extension.
        extension = ext_map.get(mime_type, "dat")
        return f"{self.host.agent_name}_input_data.{extension}"

    async def _execute_with_output_validation(
        self,
        message: A2AMessage,
        invocation_data: StructuredInvocationRequest,
        output_schema: Optional[Dict[str, Any]],
        a2a_context: Dict[str, Any],
        original_solace_message: Any = None,
    ):
        """Execute agent with output validation and retry logic.

        Runs the agent via the ADK runner with workflow instructions injected
        into the system prompt. If the agent pauses for peer-agent calls,
        finalization is deferred to finalize_deferred_structured_invocation();
        otherwise validation/finalization runs inline.
        """
        log_id = f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}]"

        # Create callback for instruction injection
        workflow_callback = self._create_workflow_callback(invocation_data, output_schema)

        # We need to register this callback with the agent.
        # Since SamAgentComponent manages the agent lifecycle, we need a way to inject this.
        # SamAgentComponent supports `_agent_system_instruction_callback`.
        # We can temporarily override it or chain it.

        original_callback = self.host._agent_system_instruction_callback

        def chained_callback(context: CallbackContext, request: LlmRequest) -> Optional[str]:
            """Combine the host's original instruction callback (if any) with
            the workflow instruction injection, joined by a blank line."""
            # Call original if exists
            original_instr = (
                original_callback(context, request) if original_callback else None
            )
            # Call workflow callback
            workflow_instr = workflow_callback(context, request)

            parts = []
            if original_instr:
                parts.append(original_instr)
            if workflow_instr:
                parts.append(workflow_instr)
            return "\n\n".join(parts) if parts else None

        self.host.set_agent_system_instruction_callback(chained_callback)

        # Import TaskExecutionContext
        from ..task_execution_context import TaskExecutionContext

        logical_task_id = a2a_context.get("logical_task_id")

        # Create and register TaskExecutionContext for this structured invocation
        task_context = TaskExecutionContext(
            task_id=logical_task_id, a2a_context=a2a_context
        )
        # Mark this task as a structured invocation so artifacts are auto-tagged as internal
        task_context.set_flag("is_structured_invocation", True)

        # Store the original Solace message if provided
        # Note: original_solace_message is passed as a parameter, not stored in a2a_context,
        # to avoid serialization issues when a2a_context is stored in ADK session state
        if original_solace_message:
            task_context.set_original_solace_message(original_solace_message)

        # Register the task context
        with self.host.active_tasks_lock:
            self.host.active_tasks[logical_task_id] = task_context

        log.debug(
            f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] Created TaskExecutionContext for task {logical_task_id}"
        )

        try:
            # Execute agent (existing ADK execution path)
            # We need to trigger the standard handle_a2a_request logic but intercept the result.
            # However, handle_a2a_request is designed to run the agent and return.
            # It calls `run_adk_async_task_thread_wrapper`.
            # We can call that directly.

            # Prepare ADK content
            user_id = a2a_context.get("user_id")
            # For structured invocations, create a run-based session ID following the same pattern
            # as RUN_BASED A2A requests: {original_session_id}:{logical_task_id}:run
            # This ensures:
            # 1. Each invocation starts with a fresh session (RUN_BASED behavior)
            # 2. get_original_session_id() can extract the parent session for artifact sharing
            original_session_id = a2a_context.get("session_id")
            logical_task_id = a2a_context.get("logical_task_id")
            session_id = f"{original_session_id}:{logical_task_id}:run"

            adk_content = await a2a.translate_a2a_to_adk_content(
                a2a_message=message,
                component=self.host,
                user_id=user_id,
                session_id=session_id,
            )

            # Always create a new session for structured invocations (RUN_BASED behavior)
            adk_session = await self.host.session_service.create_session(
                app_name=self.host.agent_name,
                user_id=user_id,
                session_id=session_id,
            )

            # Update effective_session_id to the run-based session so that
            # retriggering after peer-agent responses can find the correct session.
            a2a_context["effective_session_id"] = session_id

            run_config = RunConfig(
                streaming_mode=StreamingMode.SSE,
                max_llm_calls=self.host.get_config("max_llm_calls_per_task", 20),
            )

            # Mark this task as a structured invocation so the runner knows
            # to run deferred SI finalization instead of normal finalization on retrigger.
            task_context.set_flag("structured_invocation", True)

            # Store state needed for deferred finalization if the agent pauses
            # for peer-agent calls. These are retrieved by
            # finalize_deferred_structured_invocation() when the retrigger completes.
            task_context.set_flag("si_invocation_data", invocation_data)
            task_context.set_flag("si_output_schema", output_schema)
            task_context.set_flag("si_original_callback", original_callback)

            # Execute
            is_paused = await run_adk_async_task_thread_wrapper(
                self.host,
                adk_session,
                adk_content,
                run_config,
                a2a_context,
                skip_finalization=True,  # Structured invocations do custom finalization
            )

            # If the agent is paused (waiting for peer-agent responses),
            # return immediately. The runner will call
            # finalize_deferred_structured_invocation() when the retrigger
            # completes with is_paused=False.
            if is_paused:
                log.info(
                    f"{log_id} Agent is paused waiting for peer-agent responses. "
                    "Deferring SI finalization until retrigger completes."
                )
                return

            # Agent completed immediately — run validation and finalization inline.
            await self._run_si_finalization(
                task_context, a2a_context, log_id
            )

        finally:
            # Only clean up if the task is NOT paused waiting for peer responses.
            # If paused, finalize_deferred_structured_invocation() handles cleanup.
            if not task_context.get_is_paused():
                self._cleanup_structured_invocation(task_context, logical_task_id, original_callback)

    async def _run_si_finalization(
        self,
        task_context,
        a2a_context: Dict[str, Any],
        log_id: str,
        retry_count: int = 0,
    ):
        """
        Run structured invocation finalization: fetch session, validate result,
        and return the structured result to the workflow.

        This is used both inline (when the agent completes immediately) and
        deferred (when the agent was paused for peer-agent calls and later
        completed via retrigger).

        Args:
            task_context: TaskExecutionContext carrying the si_* flags set at
                execution time.
            retry_count: Current validation-retry attempt, restored from the
                task context on deferred finalization.
        """
        invocation_data = task_context.get_flag("si_invocation_data")
        output_schema = task_context.get_flag("si_output_schema")
        user_id = a2a_context.get("user_id")
        session_id = a2a_context.get("effective_session_id")

        # Fetch the updated session with the agent's final response
        adk_session = await self.host.session_service.get_session(
            app_name=self.host.agent_name,
            user_id=user_id,
            session_id=session_id,
        )

        # Find the last model response event
        # The session might end with a tool response (e.g. _notify_artifact_save) if the model
        # outputs nothing in the final turn. We scan backwards for the text output.
        last_model_event = None
        if adk_session.events:
            for i, event in enumerate(reversed(adk_session.events)):
                if event.content and event.content.role == "model":
                    last_model_event = event
                    log.debug(f"{log_id} Found last model event at index -{i+1}: {event.id}")
                    break

        if not last_model_event:
            log.warning(f"{log_id} No model event found in session history.")

        result_data = await self._finalize_structured_invocation(
            adk_session, last_model_event, invocation_data, output_schema, retry_count
        )

        if result_data is None:
            # A retry paused for peer-agent calls — finalization is deferred again.
            # The runner will call finalize_deferred_structured_invocation() when
            # the retrigger completes.
            return

        log.debug(
            f"{log_id} Final result data: {result_data.model_dump_json()}"
        )

        # Send result back to workflow
        await self._return_structured_result(invocation_data, result_data, a2a_context)

    def _cleanup_structured_invocation(
        self,
        task_context,
        logical_task_id: str,
        original_callback,
    ):
        """Clean up task context and restore original callback after SI completion."""
        with self.host.active_tasks_lock:
            if logical_task_id in self.host.active_tasks:
                del self.host.active_tasks[logical_task_id]
                log.debug(
                    "%s Removed TaskExecutionContext for task %s",
                    self.host.log_identifier,
                    logical_task_id,
                )

        # Undo the chained-callback override installed by _execute_with_output_validation().
        self.host.set_agent_system_instruction_callback(original_callback)

    async def finalize_deferred_structured_invocation(
        self,
        task_context,
        a2a_context: Dict[str, Any],
        exception: Optional[Exception] = None,
    ):
        """
        Called by the runner when a structured invocation task completes after
        being paused for peer-agent responses. Runs SI validation/finalization
        and cleanup.

        Args:
            exception: If the retriggered run failed, the error to report back
                to the workflow instead of running finalization.
        """
        invocation_data = task_context.get_flag("si_invocation_data")
        original_callback = task_context.get_flag("si_original_callback")
        logical_task_id = a2a_context.get("logical_task_id")
        log_id = f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}]"

        try:
            if exception:
                log.error(
                    f"{log_id} Deferred SI finalization received error: {exception}"
                )
                result_data = StructuredInvocationResult(
                    type="structured_invocation_result",
                    status="error",
                    error_message=f"Error during execution: {exception}",
                )
                await self._return_structured_result(invocation_data, result_data, a2a_context)
                return

            # Restore the retry count that was saved when the task paused.
            # This prevents the count from resetting to 0 on deferred finalization.
            retry_count = task_context.get_flag("si_retry_count", 0)
            await self._run_si_finalization(
                task_context, a2a_context, log_id, retry_count=retry_count
            )

        except Exception as e:
            log.exception(
                f"{log_id} Error in deferred SI finalization: {e}"
            )
            try:
                result_data = StructuredInvocationResult(
                    type="structured_invocation_result",
                    status="error",
                    error_message=f"Internal error during finalization: {e}",
                )
                await self._return_structured_result(invocation_data, result_data, a2a_context)
            except Exception as e2:
                # Best-effort error reporting; log and fall through to cleanup.
                log.exception(f"{log_id} Failed to send error result: {e2}")

        finally:
            self._cleanup_structured_invocation(task_context, logical_task_id, original_callback)

    def _create_workflow_callback(
        self,
        invocation_data: StructuredInvocationRequest,
        output_schema: Optional[Dict[str, Any]],
    ) -> Callable:
        """Create callback for workflow instruction injection."""

        def inject_instructions(
            callback_context: CallbackContext, llm_request: LlmRequest
        ) -> Optional[str]:
            # Instructions depend only on the invocation, not on the request.
            return self._generate_workflow_instructions(invocation_data, output_schema)

        return inject_instructions

    def _generate_workflow_instructions(
        self,
        invocation_data: StructuredInvocationRequest,
        output_schema: Optional[Dict[str, Any]],
    ) -> str:
        """Generate workflow-specific instructions.

        Builds the system-prompt block that tells the agent its workflow/node
        context, the required output artifact filename (if any), and the
        mandatory «result:...» embed format — with the output JSON schema
        inlined when one is in effect.
        """

        workflow_instructions = f"""

=== WORKFLOW EXECUTION CONTEXT ===
You are executing as node '{invocation_data.node_id}' in workflow '{invocation_data.workflow_name}'.
"""

        # Add required output filename if provided
        if invocation_data.suggested_output_filename:
            workflow_instructions += f"""
=== REQUIRED OUTPUT ARTIFACT FILENAME ===
You MUST save your output artifact with this exact filename:
{invocation_data.suggested_output_filename}

When you complete this task, use: «result:artifact={invocation_data.suggested_output_filename} status=success»
"""

        # Add output schema requirement if present
        if output_schema:
            workflow_instructions += f"""

=== CRITICAL: REQUIRED OUTPUT FORMAT ===
You MUST follow these steps to complete this task:

1. Create an artifact containing your result data conforming to this JSON Schema:

{json.dumps(output_schema, indent=2)}

2. MANDATORY: End your response with the result embed marking your output artifact:
«result:artifact=<artifact_name> status=success»

Example: «result:artifact=customer_data.json status=success»

IMPORTANT: Do NOT include a version number if returning the latest version - the system will automatically provide the most recent version.

3. The artifact MUST strictly conform to the provided schema. Your output will be validated.
If validation fails, you will be asked to retry with error feedback.

IMPORTANT NOTES:
- Use the save_artifact tool OR inline fenced blocks to create the output artifact
- The result embed («result:artifact=...») is MANDATORY - the invocation will fail without it
- The artifact format (JSON, YAML, etc.) must be parseable
- Additional fields beyond the schema are allowed, but all required fields must be present

FAILURE TO INCLUDE THE RESULT EMBED WILL CAUSE THE INVOCATION TO FAIL.
"""
        else:
            # No output schema, just mark result
            workflow_instructions += """

=== CRITICAL: REQUIRED OUTPUT FORMAT ===
You MUST end your response with the result embed to mark your completion:

«result:artifact=<artifact_name> status=success»

This result embed is MANDATORY. The invocation cannot proceed without it.

IMPORTANT: Do NOT include a version number if returning the latest version - the system will automatically provide the most recent version.

If you cannot complete the task, use:
«result:artifact=<artifact_name> status=error message="<reason>"»
"""
        return workflow_instructions.strip()

    async def _finalize_structured_invocation(
        self,
        session,
        last_event: ADKEvent,
        invocation_data: StructuredInvocationRequest,
        output_schema: Optional[Dict[str, Any]],
        retry_count: int = 0,
    ) -> Optional[StructuredInvocationResult]:
        """
        Finalize structured invocation with output validation.
        Handles retry on validation failure or missing result embed.

        Returns:
            StructuredInvocationResult if finalization completed, or None if a retry
            paused for peer-agent calls (finalization will be deferred).
        """
        log_id = f"{self.host.log_identifier}[Node:{invocation_data.node_id}]"

        # 1. Parse result embed from agent output
        result_embed = self._parse_result_embed(last_event)

        if not result_embed:
            error_msg = "Agent did not output the mandatory result embed: «result:artifact=... status=success»"
            log.warning(f"{log_id} {error_msg}")

            if retry_count < self.max_validation_retries:
                log.info(f"{log_id} Retrying due to missing result embed (Attempt {retry_count + 1})")
                feedback_text = f"""
ERROR: You failed to provide the mandatory result embed in your response.
You MUST end your response with:
«result:artifact=<your_artifact_name>:<version> status=success»

Please retry and ensure you include this embed.
"""
                return await self._execute_retry_loop(
                    session,
                    invocation_data,
                    output_schema,
                    feedback_text,
                    retry_count + 1,
                )
            else:
                return StructuredInvocationResult(
                    type="structured_invocation_result",
                    status="error",
                    error_message=error_msg,
                    retry_count=retry_count,
                )

        # Handle explicit failure status
        if result_embed.status == "error":
            return StructuredInvocationResult(
                type="structured_invocation_result",
                status="error",
                error_message=result_embed.message or "Agent reported failure",
                output_artifact_ref=ArtifactRef(name=result_embed.artifact_name) if result_embed.artifact_name else None,
                retry_count=retry_count,
            )

        # 2. Load artifact from artifact service
        try:
            # If version is missing, query for latest version
            version = int(result_embed.version) if result_embed.version else None

            if version is None:
                # Use original session ID to query for versions (same as when artifacts were saved)
                from ....agent.utils.context_helpers import get_original_session_id
                original_session_id_for_versions = get_original_session_id(session.id)

                # Query for the latest version
                versions = await self.host.artifact_service.list_versions(
                    app_name=self.host.agent_name,
                    user_id=session.user_id,
                    session_id=original_session_id_for_versions,
                    filename=result_embed.artifact_name,
                )
                if versions:
                    version = max(versions)
                    log.debug(
                        f"{log_id} Resolved latest version for {result_embed.artifact_name}: v{version}"
                    )
                else:
                    log.error(
                        f"{log_id} No versions found for artifact {result_embed.artifact_name}"
                    )
                    return StructuredInvocationResult(
                        type="structured_invocation_result",
                        status="error",
                        error_message=f"Artifact {result_embed.artifact_name} not found (no versions available)",
                        retry_count=retry_count,
                    )

            # Use original session ID (without :run suffix) to load artifacts
            # This ensures we can access artifacts saved by the agent, which uses
            # get_original_session_id() to store them in the parent session scope
            from ....agent.utils.context_helpers import get_original_session_id
            original_session_id = get_original_session_id(session.id)

            artifact = await self.host.artifact_service.load_artifact(
                app_name=self.host.agent_name,
                user_id=session.user_id,
                session_id=original_session_id,
                filename=result_embed.artifact_name,
                version=version,
            )
        except Exception as e:
            log.error(f"{log_id} Failed to load artifact: {e}")
            return StructuredInvocationResult(
                type="structured_invocation_result",
                status="error",
                error_message=f"Failed to load result artifact: {e}",
                retry_count=retry_count,
            )

        # 3. Validate artifact against output schema
        if output_schema:
            validation_errors = self._validate_artifact(artifact, output_schema)

            if validation_errors:
                log.warning(f"{log_id} Output validation failed: {validation_errors}")

                # Check if we can retry
                if retry_count < self.max_validation_retries:
                    log.info(f"{log_id} Retrying with validation feedback (Attempt {retry_count + 1})")

                    error_text = "\n".join([f"- {err}" for err in validation_errors])
                    feedback_text = f"""
Your previous output artifact failed schema validation with the following errors:

{error_text}

Please review the required schema and create a corrected artifact that addresses these errors:

{json.dumps(output_schema, indent=2)}

Remember to end your response with the result embed:
«result:artifact=<corrected_artifact_name>:<version> status=success»
"""
                    return await self._execute_retry_loop(
                        session,
                        invocation_data,
                        output_schema,
                        feedback_text,
                        retry_count + 1,
                    )
                else:
                    # Max retries exceeded
                    return StructuredInvocationResult(
                        type="structured_invocation_result",
                        status="error",
                        error_message="Output validation failed after max retries",
                        validation_errors=validation_errors,
                        retry_count=retry_count,
                    )

        # 4. Validation succeeded
        return StructuredInvocationResult(
            type="structured_invocation_result",
            status="success",
            output_artifact_ref=ArtifactRef(name=result_embed.artifact_name, version=version),
            retry_count=retry_count,
        )

    def _parse_result_embed(self, adk_event: ADKEvent) -> Optional[ResultEmbed]:
        """
        Parse result embed from agent's final output.
        Format: «result:artifact=<name>:v<version> status=<success|error> message="<text>"»

        Returns:
            ResultEmbed for the LAST result embed in the event text, or None if
            the event is empty, not a model response, or has no valid embed.
        """
        if not adk_event or not adk_event.content or not adk_event.content.parts:
            log.debug("Result embed parse: Event is empty or has no content.")
            return None

        # Only parse result embeds from agent responses (role="model"), not instructions (role="user")
        # This prevents parsing example embeds from the workflow instructions
        if adk_event.content.role != "model":
            log.debug(f"Result embed parse: Event role is {adk_event.content.role}, skipping.")
            return None

        # Extract text from last event
        text_content = ""
        for part in adk_event.content.parts:
            if part.text:
                text_content += part.text

        log.debug(f"Result embed parse: Scanning text content (len={len(text_content)}): {text_content[:100]}...")

        # Parse embeds using EMBED_REGEX
        result_embeds = []
        for match in EMBED_REGEX.finditer(text_content):
            embed_type = match.group(1)
            if embed_type == "result":
                expression = match.group(2)
                result_embeds.append(expression)

        if not result_embeds:
            return None

        # Take last result embed and parse its parameters
        # Format: artifact=<name>:v<version> status=<success|error> message="<text>"
        expression = result_embeds[-1]

        # Parse parameters from expression
        params = {}

        # Match key=value patterns, handling quoted values
        param_pattern = r'(\w+)=(?:"([^"]*)"|([^\s]+))'
        for param_match in re.finditer(param_pattern, expression):
            key = param_match.group(1)
            # Use quoted value if present, otherwise use unquoted
            value = (
                param_match.group(2)
                if param_match.group(2) is not None
                else param_match.group(3)
            )
            params[key] = value

        # Extract artifact name and version
        artifact_spec = params.get("artifact", "")
        artifact_name = artifact_spec
        version = None

        # Check if version is in artifact spec (e.g., "filename:v1" or "filename:1")
        if ":" in artifact_spec:
            parts = artifact_spec.split(":", 1)
            artifact_name = parts[0]
            version_str = parts[1]

            # Handle both "v1" and "1" formats
            if version_str.startswith("v"):
                version_str = version_str[1:]

            try:
                version = int(version_str)
            except (ValueError, IndexError):
                # Non-numeric suffix: leave version as None (latest will be resolved).
                pass

        # Also check for standalone version parameter (less common)
        if version is None and "version" in params:
            try:
                version_str = params["version"]
                if version_str.startswith("v"):
                    version_str = version_str[1:]
                version = int(version_str)
            except (ValueError, TypeError):
                pass

        # Validate: must have artifact OR explicit error status
        status = params.get("status", "success")
        if not artifact_name and status != "error":
            log.debug(
                "Result embed parse: Malformed embed - no artifact and no explicit error status"
            )
            return None

        return ResultEmbed(
            artifact_name=artifact_name,
            version=version,
            status=status,
            message=params.get("message"),
        )

    def _validate_artifact(
        self, artifact_part: adk_types.Part, schema: Dict[str, Any]
    ) -> Optional[List[str]]:
"""Validate artifact content against schema.""" 1063 from .validator import validate_against_schema 1064 1065 if not artifact_part: 1066 return ["Artifact is None"] 1067 1068 if not artifact_part.inline_data: 1069 return ["Artifact has no inline data"] 1070 1071 try: 1072 data = json.loads(artifact_part.inline_data.data.decode("utf-8")) 1073 return validate_against_schema(data, schema) 1074 except json.JSONDecodeError: 1075 return ["Artifact content is not valid JSON"] 1076 except Exception as e: 1077 return [f"Error validating artifact: {e}"] 1078 1079 async def _execute_retry_loop( 1080 self, 1081 session, 1082 invocation_data: StructuredInvocationRequest, 1083 output_schema: Optional[Dict[str, Any]], 1084 feedback_text: str, 1085 retry_count: int, 1086 ) -> Optional[StructuredInvocationResult]: 1087 """ 1088 Execute a retry loop: append feedback, run agent, and validate result. 1089 1090 Returns: 1091 StructuredInvocationResult if completed, or None if the agent paused 1092 for peer-agent calls (finalization will be deferred). 1093 """ 1094 log_id = f"{self.host.log_identifier}[Node:{invocation_data.node_id}]" 1095 log.info(f"{log_id} Executing retry loop {retry_count}/{self.max_validation_retries}") 1096 1097 # 1. Prepare feedback content 1098 feedback_content = adk_types.Content( 1099 role="user", 1100 parts=[adk_types.Part(text=feedback_text)], 1101 ) 1102 1103 # 2. Re-run the agent 1104 # We need to reconstruct the context needed for execution. 1105 # We need the original a2a_context to pass through. 1106 # Since we don't have it passed in here, we need to retrieve it from the active task context. 1107 # The session ID contains the logical_task_id: {original}:{logical}:run 1108 1109 try: 1110 parts = session.id.split(":") 1111 if len(parts) >= 3 and parts[-1] == "run": 1112 logical_task_id = parts[-2] 1113 else: 1114 # Fallback or error 1115 log.error(f"{log_id} Could not extract logical_task_id from session ID {session.id}. 
Cannot retry.") 1116 return StructuredInvocationResult( 1117 type="structured_invocation_result", 1118 status="error", 1119 error_message="Internal error: Lost task context during retry", 1120 retry_count=retry_count 1121 ) 1122 1123 with self.host.active_tasks_lock: 1124 task_context = self.host.active_tasks.get(logical_task_id) 1125 1126 if not task_context: 1127 log.error(f"{log_id} TaskExecutionContext not found for {logical_task_id}. Cannot retry.") 1128 return StructuredInvocationResult( 1129 type="structured_invocation_result", 1130 status="error", 1131 error_message="Internal error: Task context lost during retry", 1132 retry_count=retry_count 1133 ) 1134 1135 a2a_context = task_context.a2a_context 1136 1137 # Prepare run config 1138 run_config = RunConfig( 1139 streaming_mode=StreamingMode.SSE, 1140 max_llm_calls=self.host.get_config("max_llm_calls_per_task", 20), 1141 ) 1142 1143 # Run the agent again with the feedback content 1144 # The runner will handle appending the event to the session 1145 is_paused = await run_adk_async_task_thread_wrapper( 1146 self.host, 1147 session, 1148 feedback_content, 1149 run_config, 1150 a2a_context, 1151 skip_finalization=True, 1152 append_context_event=False # Context already set 1153 ) 1154 1155 # If the agent is paused (waiting for peer-agent responses), 1156 # return None to signal that finalization is deferred. 1157 # The runner will call finalize_deferred_structured_invocation() 1158 # when the retrigger completes. 1159 if is_paused: 1160 log.info( 1161 f"{log_id} Agent paused during retry for peer-agent responses. " 1162 "Deferring SI finalization until retrigger completes." 1163 ) 1164 # Preserve the retry count so deferred finalization continues 1165 # from the correct retry position instead of resetting to 0. 1166 task_context.set_flag("si_retry_count", retry_count) 1167 return None 1168 1169 # 3. 
Fetch updated session and validate new result 1170 updated_session = await self.host.session_service.get_session( 1171 app_name=self.host.agent_name, 1172 user_id=session.user_id, 1173 session_id=session.id, 1174 ) 1175 1176 # Find the new last model event 1177 last_model_event = None 1178 if updated_session.events: 1179 for i, event in enumerate(reversed(updated_session.events)): 1180 if event.content and event.content.role == "model": 1181 last_model_event = event 1182 break 1183 1184 if not last_model_event: 1185 log.warning(f"{log_id} No model response in retry turn.") 1186 # This will trigger another retry if count allows, via _finalize... 1187 1188 # Recursively call finalize to validate the new output 1189 return await self._finalize_structured_invocation( 1190 updated_session, 1191 last_model_event, 1192 invocation_data, 1193 output_schema, 1194 retry_count 1195 ) 1196 1197 except Exception as e: 1198 log.exception(f"{log_id} Error during retry execution: {e}") 1199 return StructuredInvocationResult( 1200 type="structured_invocation_result", 1201 status="error", 1202 error_message=f"Retry execution failed: {e}", 1203 retry_count=retry_count 1204 ) 1205 1206 async def _return_structured_result( 1207 self, 1208 invocation_data: StructuredInvocationRequest, 1209 result_data: StructuredInvocationResult, 1210 a2a_context: Dict[str, Any], 1211 ): 1212 """Return structured invocation result to the caller.""" 1213 try: 1214 # Create message with result data part 1215 result_message = a2a.create_agent_parts_message( 1216 parts=[a2a.create_data_part(data=result_data.model_dump())], 1217 task_id=a2a_context["logical_task_id"], 1218 context_id=a2a_context["session_id"], 1219 ) 1220 1221 # Create task status 1222 task_state = ( 1223 TaskState.completed 1224 if result_data.status == "success" 1225 else TaskState.failed 1226 ) 1227 task_status = a2a.create_task_status( 1228 state=task_state, message=result_message 1229 ) 1230 1231 # Create final task 1232 final_task = 
a2a.create_final_task( 1233 task_id=a2a_context["logical_task_id"], 1234 context_id=a2a_context["session_id"], 1235 final_status=task_status, 1236 metadata={ 1237 "agent_name": self.host.agent_name, 1238 "workflow_node_id": invocation_data.node_id, 1239 "workflow_name": invocation_data.workflow_name, 1240 }, 1241 ) 1242 1243 # Create JSON-RPC response 1244 response = a2a.create_success_response( 1245 result=final_task, request_id=a2a_context["jsonrpc_request_id"] 1246 ) 1247 1248 # Publish to workflow's response topic 1249 response_topic = a2a_context.get("replyToTopic") 1250 1251 # DEBUG: Log task ID when agent returns result to caller 1252 log.debug( 1253 f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] " 1254 f"Returning structured invocation result to caller | " 1255 f"sub_task_id={a2a_context['logical_task_id']} | " 1256 f"jsonrpc_request_id={a2a_context['jsonrpc_request_id']} | " 1257 f"result_status={result_data.status} | " 1258 f"response_topic={response_topic} | " 1259 f"workflow_name={invocation_data.workflow_name} | " 1260 f"node_id={invocation_data.node_id}" 1261 ) 1262 1263 if not response_topic: 1264 log.error( 1265 f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] " 1266 f"No replyToTopic in a2a_context! Cannot send structured invocation result. 
" 1267 f"a2a_context keys: {list(a2a_context.keys())}" 1268 ) 1269 # Still ACK the message to avoid redelivery 1270 # Retrieve from TaskExecutionContext 1271 logical_task_id = a2a_context.get("logical_task_id") 1272 with self.host.active_tasks_lock: 1273 task_context = self.host.active_tasks.get(logical_task_id) 1274 if task_context: 1275 original_message = task_context.get_original_solace_message() 1276 if original_message: 1277 original_message.call_acknowledgements() 1278 return 1279 1280 log.info( 1281 f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] " 1282 f"Publishing structured invocation result (status={result_data.status}) to {response_topic}" 1283 ) 1284 1285 self.host.publish_a2a_message( 1286 payload=response.model_dump(exclude_none=True), 1287 topic=response_topic, 1288 user_properties={"a2aUserConfig": a2a_context.get("a2a_user_config")}, 1289 ) 1290 1291 # ACK original message 1292 # Retrieve from TaskExecutionContext 1293 logical_task_id = a2a_context.get("logical_task_id") 1294 with self.host.active_tasks_lock: 1295 task_context = self.host.active_tasks.get(logical_task_id) 1296 if task_context: 1297 original_message = task_context.get_original_solace_message() 1298 if original_message: 1299 original_message.call_acknowledgements() 1300 1301 except Exception as e: 1302 log.error( 1303 f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] " 1304 f"CRITICAL: Failed to return structured invocation result to caller: {e}", 1305 exc_info=True, 1306 ) 1307 # Try to ACK message even on error to avoid redelivery loop 1308 try: 1309 # Retrieve from TaskExecutionContext 1310 logical_task_id = a2a_context.get("logical_task_id") 1311 with self.host.active_tasks_lock: 1312 task_context = self.host.active_tasks.get(logical_task_id) 1313 if task_context: 1314 original_message = task_context.get_original_solace_message() 1315 if original_message: 1316 original_message.call_acknowledgements() 1317 except Exception 
as ack_e: 1318 log.error( 1319 f"{self.host.log_identifier}[StructuredInvocation:{invocation_data.node_id}] " 1320 f"Failed to ACK message after error: {ack_e}" 1321 )