dag_executor.py
1 """ 2 DAG Executor for Prescriptive Workflows. 3 Manages the execution order of workflow nodes based on their dependencies. 4 """ 5 6 import logging 7 import re 8 import asyncio 9 import json 10 import uuid 11 from datetime import datetime, timezone 12 from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING 13 14 from google.genai import types as adk_types 15 16 from .app import ( 17 WorkflowDefinition, 18 WorkflowNode, 19 AgentNode, 20 SwitchNode, 21 LoopNode, 22 MapNode, 23 WorkflowInvokeNode, 24 ) 25 from .workflow_execution_context import WorkflowExecutionContext, WorkflowExecutionState 26 from ..common.data_parts import ( 27 StructuredInvocationResult, 28 WorkflowNodeExecutionStartData, 29 WorkflowNodeExecutionResultData, 30 WorkflowMapProgressData, 31 ArtifactRef, 32 ) 33 from ..agent.utils.artifact_helpers import save_artifact_with_metadata 34 from ..common.constants import ARTIFACT_TAG_WORKING 35 36 if TYPE_CHECKING: 37 from .component import WorkflowExecutorComponent 38 39 log = logging.getLogger(__name__) 40 41 42 class WorkflowExecutionError(Exception): 43 """Raised when workflow execution fails.""" 44 45 pass 46 47 48 class WorkflowNodeFailureError(Exception): 49 """Raised when a workflow node fails.""" 50 51 def __init__(self, node_id: str, error_message: str): 52 self.node_id = node_id 53 self.error_message = error_message 54 super().__init__(f"Node '{node_id}' failed: {error_message}") 55 56 57 class DAGExecutor: 58 """Executes workflow DAG by coordinating node execution.""" 59 60 def __init__( 61 self, 62 workflow_definition: WorkflowDefinition, 63 host_component: "WorkflowExecutorComponent", 64 ): 65 self.workflow_def = workflow_definition 66 self.host = host_component 67 68 # Build dependency graph 69 self.nodes: Dict[str, WorkflowNode] = { 70 node.id: node for node in workflow_definition.nodes 71 } 72 73 # Identify inner nodes (targets of MapNodes/LoopNodes) that should not be executed directly 74 self.inner_nodes = set() 75 for node in workflow_definition.nodes: 76 if node.type == "map": 77 self.inner_nodes.add(node.node) 78 elif node.type == "loop": 79 self.inner_nodes.add(node.node) 80 81 self.dependencies = self._build_dependency_graph() 82 self.reverse_dependencies = self._build_reverse_dependencies() 83 84 def _build_dependency_graph(self) -> Dict[str, List[str]]: 85 """Build mapping of node_id -> list of node IDs it depends on.""" 86 dependencies = {} 87 88 for node in self.workflow_def.nodes: 89 deps = node.depends_on or [] 90 dependencies[node.id] = deps 91 92 return dependencies 93 94 def _build_reverse_dependencies(self) -> Dict[str, List[str]]: 95 """Build mapping of node_id -> list of nodes that depend on it.""" 96 reverse_deps = {node_id: [] for node_id in self.nodes} 97 98 for node_id, deps in self.dependencies.items(): 99 for dep in deps: 100 if dep in reverse_deps: 101 reverse_deps[dep].append(node_id) 102 103 return reverse_deps 104 105 def get_initial_nodes(self) -> List[str]: 106 """Get nodes with no dependencies (entry points).""" 107 return [ 108 node_id 109 for node_id, deps in self.dependencies.items() 110 if not deps and node_id not in self.inner_nodes 111 ] 112 113 def get_node_by_id(self, node_id: str) -> Optional[WorkflowNode]: 114 """Get a node by its ID.""" 115 return self.nodes.get(node_id) 116 117 def get_next_nodes( 118 self, workflow_state: WorkflowExecutionState 119 ) -> List[str]: 120 """ 121 Determine which nodes can execute next. 122 Returns nodes whose dependencies are all complete. 123 """ 124 completed = set(workflow_state.completed_nodes.keys()) 125 next_nodes = [] 126 127 for node_id, deps in self.dependencies.items(): 128 # Skip inner nodes (executed by MapNodes) 129 if node_id in self.inner_nodes: 130 continue 131 132 # Skip if already completed 133 if node_id in completed: 134 continue 135 136 # Skip if already pending 137 if node_id in workflow_state.pending_nodes: 138 continue 139 140 # Check if all dependencies are satisfied 141 if all(dep in completed for dep in deps): 142 next_nodes.append(node_id) 143 144 return next_nodes 145 146 def validate_dag(self) -> List[str]: 147 """ 148 Validate DAG structure. 149 Returns list of validation errors or empty list if valid. 150 """ 151 errors = [] 152 153 # Check for cycles 154 if self._has_cycles(): 155 errors.append("Workflow DAG contains cycles") 156 157 # Check for invalid dependencies 158 for node_id, deps in self.dependencies.items(): 159 for dep in deps: 160 if dep not in self.nodes: 161 errors.append( 162 f"Node '{node_id}' depends on non-existent node '{dep}'" 163 ) 164 165 # Check for unreachable nodes (excluding inner nodes which are reached via map execution) 166 reachable = self._get_reachable_nodes() 167 for node_id in self.nodes: 168 # Inner nodes (map targets) are reachable via their parent map node 169 if node_id not in reachable and node_id not in self.inner_nodes: 170 errors.append(f"Node '{node_id}' is unreachable") 171 172 return errors 173 174 def _has_cycles(self) -> bool: 175 """Detect cycles using depth-first search.""" 176 visited = set() 177 rec_stack = set() 178 179 def dfs(node_id: str) -> bool: 180 visited.add(node_id) 181 rec_stack.add(node_id) 182 183 for dependent in self.reverse_dependencies.get(node_id, []): 184 if dependent not in visited: 185 if dfs(dependent): 186 return True 187 elif dependent in rec_stack: 188 return True 189 190 rec_stack.remove(node_id) 191 return False 192 193 for node_id in self.nodes: 194 if node_id not in visited: 195 if dfs(node_id): 196 return True 197 198 return False 199 200 def _get_reachable_nodes(self) -> Set[str]: 201 """Get set of all reachable nodes from initial nodes.""" 202 reachable = set() 203 queue = self.get_initial_nodes() 204 205 while queue: 206 node_id = queue.pop(0) 207 if node_id in reachable: 208 continue 209 reachable.add(node_id) 210 queue.extend(self.reverse_dependencies.get(node_id, [])) 211 212 return reachable 213 214 async def execute_workflow( 215 self, 216 workflow_state: WorkflowExecutionState, 217 workflow_context: WorkflowExecutionContext, 218 ): 219 """ 220 Execute workflow DAG until completion or failure. 221 Main execution loop. 222 """ 223 log_id = f"{self.host.log_identifier}[Workflow:{workflow_state.execution_id}]" 224 225 while True: 226 # Check for cancellation before proceeding 227 if workflow_context.is_cancelled(): 228 log.info(f"{log_id} Workflow cancelled, stopping execution") 229 return 230 231 # Get next nodes to execute 232 next_nodes = self.get_next_nodes(workflow_state) 233 234 if not next_nodes: 235 # Check if workflow is complete 236 if len(workflow_state.completed_nodes) == len(self.nodes): 237 log.info(f"{log_id} Workflow completed successfully") 238 await self.host.finalize_workflow_success(workflow_context) 239 return 240 241 # Check if workflow is stuck 242 if ( 243 not workflow_state.pending_nodes 244 and not workflow_state.active_branches 245 ): 246 # If we have conditional nodes, it's possible some nodes are skipped. 247 # We need to check if we are truly stuck or just finished a path. 248 # For MVP, we assume if pending is empty and not all nodes are done, 249 # but no next nodes are available, we might be done if the remaining 250 # nodes are unreachable due to conditional branches. 251 # However, a simpler check is: are there any nodes running? 252 # If pending_nodes is empty, nothing is running. 253 # If we are not "complete" (all nodes visited), but nothing is running, 254 # and no next nodes, then we are done with this execution path. 255 log.info( 256 f"{log_id} Workflow execution path completed (some nodes may have been skipped)." 257 ) 258 await self.host.finalize_workflow_success(workflow_context) 259 return 260 261 # Wait for pending nodes to complete 262 log.debug( 263 f"{log_id} Waiting for {len(workflow_state.pending_nodes)} pending nodes" 264 ) 265 return # Execution will resume on node completion 266 267 # Execute next nodes with implicit parallelism detection and branch inheritance 268 # Determine parallel group and branch assignments for each node 269 node_parallel_info = {} # node_id -> (parallel_group_id, branch_index) 270 271 if len(next_nodes) > 1: 272 # Multiple nodes ready = implicit parallel execution (like implicit fork) 273 # Check if all nodes share the same single dependency in the same branch 274 # If so, this is a new fork point 275 parallel_group_id = f"implicit_parallel_{workflow_state.execution_id}_{uuid.uuid4().hex[:8]}" 276 log.info(f"{log_id} Implicit parallel execution: {len(next_nodes)} nodes, group={parallel_group_id}") 277 278 # Assign each node to a separate branch 279 for branch_idx, nid in enumerate(next_nodes): 280 node_parallel_info[nid] = (parallel_group_id, branch_idx) 281 # Track in workflow state 282 if parallel_group_id not in workflow_state.parallel_branch_assignments: 283 workflow_state.parallel_branch_assignments[parallel_group_id] = {} 284 workflow_state.parallel_branch_assignments[parallel_group_id][nid] = branch_idx 285 else: 286 # Single node - check if it should inherit a branch from its dependencies 287 for nid in next_nodes: 288 inherited_info = self._get_inherited_branch(nid, workflow_state) 289 if inherited_info: 290 parallel_group_id, branch_idx = inherited_info 291 node_parallel_info[nid] = inherited_info 292 # Track in workflow state 293 workflow_state.parallel_branch_assignments[parallel_group_id][nid] = branch_idx 294 log.debug(f"{log_id} Node {nid} inherits branch {branch_idx} from group {parallel_group_id}") 295 296 for node_id in next_nodes: 297 parallel_info = node_parallel_info.get(node_id) 298 pg_id = parallel_info[0] if parallel_info else None 299 branch_idx = parallel_info[1] if parallel_info else None 300 await self.execute_node(node_id, workflow_state, workflow_context, pg_id, branch_idx) 301 302 # Update pending nodes 303 # Only add if NOT completed (i.e. it was an async node that started) 304 if node_id not in workflow_state.completed_nodes: 305 workflow_state.pending_nodes.append(node_id) 306 307 # Persist state 308 await self.host._update_workflow_state(workflow_context, workflow_state) 309 310 def _get_inherited_branch( 311 self, 312 node_id: str, 313 workflow_state: WorkflowExecutionState, 314 ) -> Optional[tuple]: 315 """ 316 Check if a node should inherit a branch assignment from its dependencies. 317 318 A node inherits a branch if ALL its dependencies are in the same branch 319 of the same parallel group. If dependencies span multiple branches or 320 parallel groups, this is an implicit join point and no inheritance occurs. 321 322 Returns: (parallel_group_id, branch_index) or None 323 """ 324 node = self.nodes[node_id] 325 if not node.depends_on: 326 return None 327 328 # Find which parallel group/branch each dependency is in 329 dep_branches = [] # List of (parallel_group_id, branch_index) for each dep 330 for dep_id in node.depends_on: 331 # Skip dependencies that were skipped (conditional branches) 332 if dep_id in workflow_state.skipped_nodes: 333 continue 334 335 # Find if this dependency is in any parallel group 336 found = False 337 for pg_id, assignments in workflow_state.parallel_branch_assignments.items(): 338 if dep_id in assignments: 339 dep_branches.append((pg_id, assignments[dep_id])) 340 found = True 341 break 342 343 if not found: 344 # Dependency is not in any parallel group 345 dep_branches.append(None) 346 347 # Filter out None entries (deps not in parallel groups) 348 parallel_deps = [b for b in dep_branches if b is not None] 349 350 if not parallel_deps: 351 # No dependencies are in parallel groups 352 return None 353 354 # Check if all parallel dependencies are in the same group and branch 355 first_group, first_branch = parallel_deps[0] 356 for pg_id, branch_idx in parallel_deps[1:]: 357 if pg_id != first_group or branch_idx != first_branch: 358 # Dependencies span multiple branches - this is an implicit join 359 return None 360 361 # All parallel dependencies are in the same branch - inherit it 362 return (first_group, first_branch) 363 364 async def execute_node( 365 self, 366 node_id: str, 367 workflow_state: WorkflowExecutionState, 368 workflow_context: WorkflowExecutionContext, 369 implicit_parallel_group_id: Optional[str] = None, 370 implicit_branch_index: Optional[int] = None, 371 ): 372 """Execute a single workflow node.""" 373 log_id = f"{self.host.log_identifier}[Node:{node_id}]" 374 375 # Check for cancellation before executing node 376 if workflow_context.is_cancelled(): 377 log.info(f"{log_id} Workflow cancelled, not executing node") 378 return 379 380 try: 381 node = self.nodes[node_id] 382 383 # Generate sub-task ID for agent and workflow nodes to link events 384 sub_task_id = None 385 if node.type in ("agent", "workflow"): 386 sub_task_id = f"wf_{workflow_state.execution_id}_{node.id}_{uuid.uuid4().hex[:8]}" 387 388 # Publish start event 389 start_data_args = { 390 "type": "workflow_node_execution_start", 391 "node_id": node_id, 392 "node_type": node.type, 393 "agent_name": getattr(node, "agent_name", None) or getattr(node, "workflow_name", None), 394 "sub_task_id": sub_task_id, 395 } 396 397 # Include implicit parallel group ID and branch index for agent and workflow nodes (used for visualization) 398 if implicit_parallel_group_id and node.type in ("agent", "workflow"): 399 start_data_args["parallel_group_id"] = implicit_parallel_group_id 400 if implicit_branch_index is not None: 401 start_data_args["iteration_index"] = implicit_branch_index 402 403 if node.type == "switch": 404 # Include switch case info for visualization 405 from ..common.data_parts import SwitchCaseInfo 406 start_data_args["cases"] = [ 407 SwitchCaseInfo(condition=case.condition, node=case.node) 408 for case in node.cases 409 ] 410 start_data_args["default_branch"] = node.default 411 412 elif node.type == "loop": 413 # Include loop configuration for visualization 414 start_data_args["condition"] = node.condition 415 start_data_args["max_iterations"] = node.max_iterations 416 start_data_args["loop_delay"] = node.delay 417 418 # Generate parallel_group_id for map nodes so the frontend can group children 419 parallel_group_id = None 420 if node.type == "map": 421 parallel_group_id = f"map_{node.id}_{workflow_state.execution_id}" 422 start_data_args["parallel_group_id"] = parallel_group_id 423 424 start_data = WorkflowNodeExecutionStartData(**start_data_args) 425 await self.host.publish_workflow_event(workflow_context, start_data) 426 427 # Handle different node types 428 if node.type == "agent": 429 await self._execute_agent_node(node, workflow_state, workflow_context, sub_task_id) 430 elif node.type == "workflow": 431 await self._execute_workflow_node(node, workflow_state, workflow_context, sub_task_id) 432 elif node.type == "switch": 433 await self._execute_switch_node(node, workflow_state, workflow_context) 434 elif node.type == "loop": 435 await self._execute_loop_node(node, workflow_state, workflow_context) 436 elif node.type == "map": 437 await self._execute_map_node(node, workflow_state, workflow_context, parallel_group_id) 438 else: 439 raise ValueError(f"Unknown node type: {node.type}") 440 441 except Exception as e: 442 log.error(f"{log_id} Node execution failed: {e}") 443 444 # Set error state 445 workflow_state.error_state = { 446 "failed_node_id": node_id, 447 "failure_reason": "execution_error", 448 "error_message": str(e), 449 "timestamp": datetime.now(timezone.utc).isoformat(), 450 } 451 452 # Propagate error 453 raise WorkflowNodeFailureError(node_id, str(e)) from e 454 455 async def _execute_agent_node( 456 self, 457 node: AgentNode, 458 workflow_state: WorkflowExecutionState, 459 workflow_context: WorkflowExecutionContext, 460 sub_task_id: Optional[str] = None, 461 ): 462 """Execute an agent node by calling the agent.""" 463 log_id = f"{self.host.log_identifier}[Agent:{node.id}]" 464 465 # Check 'when' clause if present (Argo-style conditional) 466 if node.when: 467 from .flow_control.conditional import evaluate_condition 468 469 try: 470 should_execute = evaluate_condition(node.when, workflow_state) 471 except Exception as e: 472 log.warning(f"{log_id} 'when' clause evaluation failed: {e}") 473 should_execute = False 474 475 if not should_execute: 476 log.info( 477 f"{log_id} Skipping node due to 'when' clause: {node.when}" 478 ) 479 # Mark as skipped 480 workflow_state.skipped_nodes[node.id] = f"when_clause_false: {node.when}" 481 workflow_state.completed_nodes[node.id] = "SKIPPED_BY_WHEN" 482 workflow_state.node_outputs[node.id] = { 483 "output": None, 484 "skipped": True, 485 "skip_reason": "when_clause_false", 486 } 487 488 # Publish skipped event 489 result_data = WorkflowNodeExecutionResultData( 490 type="workflow_node_execution_result", 491 node_id=node.id, 492 status="skipped", 493 metadata={"skip_reason": "when_clause_false", "when": node.when}, 494 ) 495 await self.host.publish_workflow_event(workflow_context, result_data) 496 497 # Continue workflow 498 await self.execute_workflow(workflow_state, workflow_context) 499 return 500 501 await self.host.agent_caller.call_agent( 502 node, workflow_state, workflow_context, sub_task_id 503 ) 504 505 async def _execute_workflow_node( 506 self, 507 node: WorkflowInvokeNode, 508 workflow_state: WorkflowExecutionState, 509 workflow_context: WorkflowExecutionContext, 510 sub_task_id: Optional[str] = None, 511 ): 512 """Execute a workflow node by calling the sub-workflow.""" 513 log_id = f"{self.host.log_identifier}[Workflow:{node.id}]" 514 515 # Check for direct recursion (workflow invoking itself) 516 if node.workflow_name == self.host.workflow_name: 517 error_msg = ( 518 f"Direct recursion detected: workflow '{self.host.workflow_name}' " 519 f"cannot invoke itself via node '{node.id}'" 520 ) 521 log.error(f"{log_id} {error_msg}") 522 raise WorkflowExecutionError(error_msg) 523 524 # Check 'when' clause if present (Argo-style conditional) 525 if node.when: 526 from .flow_control.conditional import evaluate_condition 527 528 try: 529 should_execute = evaluate_condition(node.when, workflow_state) 530 except Exception as e: 531 log.warning(f"{log_id} 'when' clause evaluation failed: {e}") 532 should_execute = False 533 534 if not should_execute: 535 log.info( 536 f"{log_id} Skipping workflow node due to 'when' clause: {node.when}" 537 ) 538 # Mark as skipped 539 workflow_state.skipped_nodes[node.id] = f"when_clause_false: {node.when}" 540 workflow_state.completed_nodes[node.id] = "SKIPPED_BY_WHEN" 541 workflow_state.node_outputs[node.id] = { 542 "output": None, 543 "skipped": True, 544 "skip_reason": "when_clause_false", 545 } 546 547 # Publish skipped event 548 result_data = WorkflowNodeExecutionResultData( 549 type="workflow_node_execution_result", 550 node_id=node.id, 551 status="skipped", 552 metadata={"skip_reason": "when_clause_false", "when": node.when}, 553 ) 554 await self.host.publish_workflow_event(workflow_context, result_data) 555 556 # Continue workflow 557 await self.execute_workflow(workflow_state, workflow_context) 558 return 559 560 # Call the sub-workflow using the agent caller 561 # Workflows register as agents, so we use call_workflow which adapts the invocation 562 await self.host.agent_caller.call_workflow( 563 node, workflow_state, workflow_context, sub_task_id 564 ) 565 566 async def _execute_switch_node( 567 self, 568 node: SwitchNode, 569 workflow_state: WorkflowExecutionState, 570 workflow_context: WorkflowExecutionContext, 571 ): 572 """Execute switch node for multi-way branching.""" 573 log_id = f"{self.host.log_identifier}[Switch:{node.id}]" 574 575 from .flow_control.conditional import evaluate_condition 576 577 selected_branch = None 578 selected_case_index = None 579 580 # Evaluate cases in order, first match wins 581 for i, case in enumerate(node.cases): 582 try: 583 result = evaluate_condition(case.condition, workflow_state) 584 if result: 585 selected_branch = case.node 586 selected_case_index = i 587 log.info( 588 f"{log_id} Case {i} condition '{case.condition}' matched, " 589 f"selecting branch '{case.node}'" 590 ) 591 break 592 except Exception as e: 593 log.warning(f"{log_id} Case {i} evaluation failed: {e}") 594 continue 595 596 # Use default if no case matched 597 if selected_branch is None and node.default: 598 selected_branch = node.default 599 log.info(f"{log_id} No case matched, using default branch '{node.default}'") 600 601 # Mark switch as complete 602 workflow_state.completed_nodes[node.id] = "switch_evaluated" 603 workflow_state.node_outputs[node.id] = { 604 "output": { 605 "selected_branch": selected_branch, 606 "selected_case_index": selected_case_index, 607 } 608 } 609 610 # Publish result event 611 result_data = WorkflowNodeExecutionResultData( 612 type="workflow_node_execution_result", 613 node_id=node.id, 614 status="success", 615 metadata={ 616 "selected_branch": selected_branch, 617 "selected_case_index": selected_case_index, 618 }, 619 ) 620 await self.host.publish_workflow_event(workflow_context, result_data) 621 622 # Skip all non-selected branches 623 all_branches = [case.node for case in node.cases] 624 if node.default: 625 all_branches.append(node.default) 626 627 for branch_id in all_branches: 628 if branch_id != selected_branch: 629 await self._skip_branch(branch_id, workflow_state) 630 631 # Continue execution 632 await self.execute_workflow(workflow_state, workflow_context) 633 634 async def _execute_loop_node( 635 self, 636 node: LoopNode, 637 workflow_state: WorkflowExecutionState, 638 workflow_context: WorkflowExecutionContext, 639 ): 640 """Execute loop node for while-loop iteration.""" 641 log_id = f"{self.host.log_identifier}[Loop:{node.id}]" 642 643 # Check for cancellation before starting loop iteration 644 if workflow_context.is_cancelled(): 645 log.info(f"{log_id} Workflow cancelled, not starting loop iteration") 646 return 647 648 from .flow_control.conditional import evaluate_condition 649 from .utils import parse_duration 650 651 # Initialize or get iteration count 652 if node.id not in workflow_state.loop_iterations: 653 workflow_state.loop_iterations[node.id] = 0 654 655 iteration = workflow_state.loop_iterations[node.id] 656 657 # Check max iterations 658 if iteration >= node.max_iterations: 659 log.warning( 660 f"{log_id} Max iterations ({node.max_iterations}) reached, stopping loop" 661 ) 662 workflow_state.completed_nodes[node.id] = "loop_max_iterations" 663 workflow_state.node_outputs[node.id] = { 664 "output": { 665 "iterations_completed": iteration, 666 "stopped_reason": "max_iterations", 667 } 668 } 669 # Continue workflow 670 await self.execute_workflow(workflow_state, workflow_context) 671 return 672 673 # Evaluate loop condition 674 # On the first iteration (iteration=0), skip condition check and always run 675 # This makes the loop behave like a "do-while" - condition is checked after first run 676 if iteration == 0: 677 should_continue = True 678 else: 679 try: 680 should_continue = evaluate_condition(node.condition, workflow_state) 681 except Exception as e: 682 log.error(f"{log_id} Loop condition evaluation failed: {e}") 683 should_continue = False 684 685 if not should_continue: 686 workflow_state.completed_nodes[node.id] = "loop_condition_false" 687 if node.id in workflow_state.pending_nodes: 688 workflow_state.pending_nodes.remove(node.id) 689 workflow_state.node_outputs[node.id] = { 690 "output": { 691 "iterations_completed": iteration, 692 "stopped_reason": "condition_false", 693 } 694 } 695 696 # Publish result event for the loop node completion 697 result_data = WorkflowNodeExecutionResultData( 698 type="workflow_node_execution_result", 699 node_id=node.id, 700 status="success", 701 metadata={ 702 "iterations_completed": iteration, 703 "stopped_reason": "condition_false", 704 }, 705 ) 706 await self.host.publish_workflow_event(workflow_context, result_data) 707 708 # Continue workflow 709 await self.execute_workflow(workflow_state, workflow_context) 710 return 711 712 # Apply delay if configured 713 if node.delay and iteration > 0: # No delay on first iteration 714 delay_seconds = parse_duration(node.delay) 715 log.debug(f"{log_id} Applying delay of {delay_seconds}s before iteration {iteration}") 716 await asyncio.sleep(delay_seconds) 717 718 # Increment iteration count 719 workflow_state.loop_iterations[node.id] = iteration + 1 720 721 log.info(f"{log_id} Starting iteration {iteration + 1}") 722 723 # Execute inner node 724 target_node = self.nodes[node.node] 725 iter_node = target_node.model_copy() 726 # Assign unique ID for this iteration 727 iter_node.id = f"{node.id}_iter_{iteration}" 728 729 # Store iteration context 730 workflow_state.node_outputs["_loop_iteration"] = {"output": iteration} 731 732 # Generate sub-task ID 733 sub_task_id = f"wf_{workflow_state.execution_id}_{iter_node.id}_{uuid.uuid4().hex[:8]}" 734 735 # Handle both AgentInvokeNode (has agent_name) and WorkflowInvokeNode (has workflow_name) 736 target_name = getattr(iter_node, "agent_name", getattr(iter_node, "workflow_name", None)) 737 738 # Emit start event for loop iteration child 739 start_data = WorkflowNodeExecutionStartData( 740 type="workflow_node_execution_start", 741 node_id=iter_node.id, 742 node_type="agent", 743 agent_name=target_name, 744 iteration_index=iteration, 745 sub_task_id=sub_task_id, 746 parent_node_id=node.id, 747 ) 748 await self.host.publish_workflow_event(workflow_context, start_data) 749 750 # Track in active branches for completion handling 751 workflow_state.active_branches[node.id] = [ 752 { 753 "iteration": iteration, 754 "sub_task_id": sub_task_id, 755 "type": "loop", 756 } 757 ] 758 759 # Execute the inner node 760 if iter_node.type == "workflow": 761 await self.host.agent_caller.call_workflow( 762 iter_node, workflow_state, workflow_context, sub_task_id=sub_task_id 763 ) 764 else: 765 await self.host.agent_caller.call_agent( 766 iter_node, workflow_state, workflow_context, sub_task_id=sub_task_id 767 ) 768 769 async def _skip_branch( 770 self, node_id: str, workflow_state: WorkflowExecutionState 771 ): 772 """Recursively mark a branch as skipped.""" 773 if node_id in workflow_state.completed_nodes: 774 return 775 776 # Mark as skipped (using a special value in completed_nodes) 777 workflow_state.completed_nodes[node_id] = "SKIPPED" 778 779 # Publish skipped event (optional, but good for visualization) 780 # We need context to publish, but _skip_branch doesn't have it passed down. 781 # For now, we skip publishing "skipped" events to avoid signature changes, 782 # or we can rely on the UI inferring it from the conditional result. 783 # Actually, let's leave it implicit for now. 784 785 # Find children 786 children = self.reverse_dependencies.get(node_id, []) 787 for child_id in children: 788 # Only skip child if ALL its dependencies are skipped 789 child_deps = self.dependencies.get(child_id, []) 790 791 all_deps_skipped = True 792 for dep in child_deps: 793 # If dependency is not completed, or completed but not skipped, then child might still run 794 if dep not in workflow_state.completed_nodes: 795 all_deps_skipped = False 796 break 797 if workflow_state.completed_nodes[dep] != "SKIPPED": 798 all_deps_skipped = False 799 break 800 801 if all_deps_skipped: 802 await self._skip_branch(child_id, workflow_state) 803 804 async def _execute_map_node( 805 self, 806 node: MapNode, 807 workflow_state: WorkflowExecutionState, 808 workflow_context: WorkflowExecutionContext, 809 parallel_group_id: str, 810 ): 811 """Execute map node with concurrency control.""" 812 log_id = f"{self.host.log_identifier}[Map:{node.id}]" 813 814 # Resolve items array 815 items = self.resolve_value(node.items, workflow_state) 816 817 if items is None: 818 log.warning(f"{log_id} Map target resolved to None. Treating as empty list.") 819 items = [] 820 821 if not isinstance(items, list): 822 raise ValueError(f"Map target must be array, got: {type(items)}") 823 824 # Check item limit 825 max_items = ( 826 node.max_items or self.host.get_config("default_max_map_items", 100) 827 ) 828 829 if len(items) > max_items: 830 raise WorkflowExecutionError( 831 f"Map '{node.id}' exceeds max items: " f"{len(items)} > {max_items}" 832 ) 833 834 log.info(f"{log_id} Starting map with {len(items)} items") 835 836 # Initialize tracking state 837 # We store the full list of items and their status 838 map_state = { 839 "items": items, 840 "results": [None] * len(items), # Placeholders for results 841 "pending_indices": list(range(len(items))), # Indices waiting to run 842 "active_indices": set(), # Indices currently running 843 "completed_count": 0, 844 "concurrency_limit": node.concurrency_limit, 845 "target_node_id": node.node, 846 "parallel_group_id": parallel_group_id, 847 } 848 849 # Store in active_branches (using a dict instead of list for map state) 850 # Store map state in metadata field (designed for node-specific extensible state). 851 # active_branches tracks the currently executing sub-tasks for this map node. 852 workflow_state.metadata[f"map_state_{node.id}"] = map_state 853 workflow_state.active_branches[node.id] = [] 854 855 # Launch initial batch 856 await self._launch_map_iterations(node.id, workflow_state, workflow_context) 857 858 async def _launch_map_iterations( 859 self, 860 map_node_id: str, 861 workflow_state: WorkflowExecutionState, 862 workflow_context: WorkflowExecutionContext, 863 ): 864 """Launch pending map iterations up to concurrency limit.""" 865 # Check for cancellation before launching new iterations 866 if workflow_context.is_cancelled(): 867 log.info(f"{self.host.log_identifier}[Map:{map_node_id}] Workflow cancelled, not launching new iterations") 868 return 869 870 map_state = workflow_state.metadata.get(f"map_state_{map_node_id}") 871 if not map_state: 872 return 873 874 concurrency_limit = map_state["concurrency_limit"] 875 active_indices = map_state["active_indices"] 876 pending_indices = map_state["pending_indices"] 877 items = map_state["items"] 878 target_node_id = map_state["target_node_id"] 879 parallel_group_id = map_state.get("parallel_group_id") 880 881 # Determine how many to launch 882 while pending_indices: 883 if concurrency_limit and len(active_indices) >= concurrency_limit: 884 break 885 886 index = pending_indices.pop(0) 887 item = items[index] 888 active_indices.add(index) 889 890 # Create iteration state 891 iteration_state = workflow_state.model_copy(deep=False) 892 iteration_state.node_outputs = { 893 **workflow_state.node_outputs, 894 "_map_item": {"output": item}, 895 "_map_index": {"output": index}, 896 } 897 898 target_node = self.nodes[target_node_id] 899 iter_node = target_node.model_copy() 900 # Assign unique ID for this iteration to ensure distinct events and tracking 901 iter_node.id = f"{map_node_id}_{index}" 902 903 # Generate sub-task ID 904 sub_task_id = f"wf_{workflow_state.execution_id}_{iter_node.id}_{uuid.uuid4().hex[:8]}" 905 906 # Emit start event for iteration BEFORE execution 907 target_name = getattr(iter_node, "agent_name", getattr(iter_node, "workflow_name", None)) 908 start_data = WorkflowNodeExecutionStartData( 909 type="workflow_node_execution_start", 910 node_id=iter_node.id, 911 node_type="agent", 912 agent_name=target_name, 913 iteration_index=index, 914 sub_task_id=sub_task_id, 915 parent_node_id=map_node_id, 916 parallel_group_id=parallel_group_id, 917 ) 918 await self.host.publish_workflow_event(workflow_context, start_data) 919 920 # Execute 921 if iter_node.type == "workflow": 922 await self.host.agent_caller.call_workflow( 923 iter_node, iteration_state, workflow_context, sub_task_id=sub_task_id 924 ) 925 else: 926 await self.host.agent_caller.call_agent( 927 iter_node, iteration_state, workflow_context, sub_task_id=sub_task_id 928 ) 929 930 # Track active sub-task 931 workflow_state.active_branches[map_node_id].append( 932 { 933 "index": index, 934 "sub_task_id": sub_task_id, 935 } 936 ) 937 938 def resolve_value( 939 self, value_def: Any, workflow_state: WorkflowExecutionState 940 ) -> Any: 941 """ 942 Resolve a value definition, handling templates and operators. 943 Supports: 944 - Literal values 945 - Template strings: "{{...}}" 946 - Operators: coalesce, concat 947 - Nested dicts and lists (recursively resolved) 948 """ 949 # Handle template string 950 if isinstance(value_def, str) and value_def.startswith("{{"): 951 return self._resolve_template(value_def, workflow_state) 952 953 # Handle intrinsic functions (operators) 954 if isinstance(value_def, dict) and len(value_def) == 1: 955 op = next(iter(value_def)) 956 args = value_def[op] 957 958 if op == "coalesce": 959 if not isinstance(args, list): 960 raise ValueError("'coalesce' operator requires a list of values") 961 962 for arg in args: 963 resolved = self.resolve_value(arg, workflow_state) 964 if resolved is not None: 965 return resolved 966 return None 967 968 if op == "concat": 969 if not isinstance(args, list): 970 raise ValueError("'concat' operator requires a list of values") 971 972 parts = [] 973 for arg in args: 974 resolved = self.resolve_value(arg, workflow_state) 975 if resolved is not None: 976 parts.append(str(resolved)) 977 return "".join(parts) 978 979 # Handle nested dicts - recursively resolve all values 980 if isinstance(value_def, dict): 981 resolved_dict = {} 982 for key, value in value_def.items(): 983 resolved_dict[key] = self.resolve_value(value, workflow_state) 984 return resolved_dict 985 986 # Handle nested lists - recursively resolve all items 987 if isinstance(value_def, list): 988 return [self.resolve_value(item, workflow_state) for item in value_def] 989 990 # Return literal 991 return value_def 992 993 def _resolve_template( 994 self, template: str, workflow_state: WorkflowExecutionState 995 ) -> Any: 996 """ 997 Resolve template variable. 998 Format: {{node_id.output.field_path}} or {{workflow.input.field_path}} 999 1000 Supports Argo-style aliases: 1001 - {{item}} -> {{_map_item}} 1002 - {{workflow.parameters.x}} -> {{workflow.input.x}} 1003 """ 1004 # Apply Argo-compatible aliases 1005 from .flow_control.conditional import _apply_template_aliases 1006 1007 template = _apply_template_aliases(template) 1008 1009 # Extract variable path 1010 # Use fullmatch to ensure the template takes up the entire string 1011 # and handle optional whitespace inside braces: {{ value }} 1012 match = re.fullmatch(r"\{\{\s*(.+?)\s*\}\}", template) 1013 if not match: 1014 return template 1015 1016 path = match.group(1) 1017 parts = path.split(".") 1018 1019 # Navigate path in workflow state 1020 if parts[0] == "workflow" and parts[1] == "input": 1021 # Reference to workflow input 1022 # Workflow input is stored in node_outputs["workflow_input"] 1023 if "workflow_input" not in workflow_state.node_outputs: 1024 raise ValueError("Workflow input has not been initialized") 1025 1026 # Navigate from workflow_input.output.field_path 1027 data = workflow_state.node_outputs["workflow_input"]["output"] 1028 for part in parts[2:]: # Skip "workflow" and "input" 1029 if isinstance(data, dict) and part in data: 1030 data = data[part] 1031 else: 1032 # Return None if input field is missing (allows coalesce to work) 1033 return None 1034 return data 1035 else: 1036 # Reference to node output 1037 node_id = parts[0] 1038 if node_id not in workflow_state.node_outputs: 1039 # Check if it's a map/loop variable 1040 if node_id in ["_map_item", "_map_index", "_loop_iteration"]: 1041 pass # Allow it 1042 else: 1043 # Return None for skipped/incomplete nodes to allow for safe navigation/coalescing 1044 return None 1045 1046 # Navigate remaining path 1047 # Special handling for map/loop variables: unwrap 'output' immediately 1048 if node_id in ["_map_item", "_map_index", "_loop_iteration"]: 1049 data = workflow_state.node_outputs[node_id].get("output") 1050 else: 1051 data = workflow_state.node_outputs[node_id] 1052 1053 for part in parts[1:]: 1054 if isinstance(data, dict) and part in data: 1055 data = data[part] 1056 else: 1057 raise ValueError( 1058 f"Output field '{part}' not found in node '{node_id}' for path: {path}" 1059 ) 1060 1061 return data 1062 1063 async def handle_node_completion( 1064 self, 1065 workflow_context: WorkflowExecutionContext, 1066 sub_task_id: str, 1067 result: StructuredInvocationResult, 1068 ): 1069 """Handle completion of a workflow node.""" 1070 log_id = f"{self.host.log_identifier}[Workflow:{workflow_context.workflow_task_id}]" 1071 1072 # Check for cancellation - don't process results if workflow is cancelled 1073 if workflow_context.is_cancelled(): 1074 log.info(f"{log_id} Workflow cancelled, ignoring node completion for sub-task {sub_task_id}") 1075 return 1076 1077 # Find which node this sub-task corresponds to 1078 node_id = workflow_context.get_node_id_for_sub_task(sub_task_id) 1079 1080 if not node_id: 1081 log.error(f"{log_id} Received result for unknown sub-task: {sub_task_id}") 1082 return 1083 1084 workflow_state = workflow_context.workflow_state 1085 1086 # Check result status 1087 if result.status == "error": 1088 log.error(f"{log_id} Node '{node_id}' failed: {result.error_message}") 1089 1090 # Set error state 1091 workflow_state.error_state = { 1092 "failed_node_id": node_id, 1093 "failure_reason": "node_execution_failed", 1094 "error_message": result.error_message, 1095 "timestamp": datetime.now(timezone.utc).isoformat(), 1096 } 1097 1098 # Fail entire workflow 1099 await self.host.finalize_workflow_failure( 1100 workflow_context, 1101 WorkflowNodeFailureError(node_id, result.error_message), 1102 ) 1103 return 1104 1105 # Node succeeded 1106 log.debug(f"{log_id} Node '{node_id}' completed successfully") 1107 1108 # Publish success event 1109 result_data = WorkflowNodeExecutionResultData( 1110 type="workflow_node_execution_result", 1111 node_id=node_id, 1112 status="success", 1113 output_artifact_ref=result.output_artifact_ref, 1114 ) 1115 await self.host.publish_workflow_event(workflow_context, result_data) 1116 1117 # Check if this was part of a Fork or Map 1118 # We need to find if this node_id is being tracked in active_branches 1119 # But wait, node_id is the ID of the node definition. 1120 # For Fork, the node_id IS the branch ID. 1121 # For Map, the node_id IS the map body node ID. 1122 1123 # Check if this node is part of an active fork/map 1124 parent_control_node_id = None 1125 for control_node_id, branches in workflow_state.active_branches.items(): 1126 for branch in branches: 1127 if branch.get("sub_task_id") == sub_task_id: 1128 parent_control_node_id = control_node_id 1129 break 1130 if parent_control_node_id: 1131 break 1132 1133 if parent_control_node_id: 1134 await self._handle_control_node_child_completion( 1135 parent_control_node_id, 1136 sub_task_id, 1137 result, 1138 workflow_state, 1139 workflow_context, 1140 ) 1141 else: 1142 # Standard node completion 1143 artifact_name = result.output_artifact_ref.name if result.output_artifact_ref else None 1144 workflow_state.completed_nodes[node_id] = artifact_name 1145 if node_id in workflow_state.pending_nodes: 1146 workflow_state.pending_nodes.remove(node_id) 1147 1148 # Cache node output for value references 1149 if result.output_artifact_ref and result.output_artifact_ref.name: 1150 artifact_data = await self.host._load_node_output( 1151 node_id, 1152 result.output_artifact_ref.name, 1153 result.output_artifact_ref.version, 1154 workflow_context, 1155 ) 1156 workflow_state.node_outputs[node_id] = {"output": artifact_data} 1157 1158 # Continue workflow execution 1159 await self.execute_workflow(workflow_state, workflow_context) 1160 1161 async def _handle_control_node_child_completion( 1162 self, 1163 control_node_id: str, 1164 sub_task_id: str, 1165 result: StructuredInvocationResult, 1166 workflow_state: WorkflowExecutionState, 1167 workflow_context: WorkflowExecutionContext, 1168 ): 1169 """Handle completion of a child task within a Fork or Map.""" 1170 log_id = f"{self.host.log_identifier}[ControlNode:{control_node_id}]" 1171 1172 # Check for cancellation - don't continue processing if workflow is cancelled 1173 if workflow_context.is_cancelled(): 1174 log.info(f"{log_id} Workflow cancelled, not processing child completion for sub-task {sub_task_id}") 1175 return 1176 1177 branches = workflow_state.active_branches.get(control_node_id, []) 1178 1179 # Find the specific branch/iteration 1180 completed_branch = None 1181 for branch in branches: 1182 if branch["sub_task_id"] == sub_task_id: 1183 completed_branch = branch 1184 break 1185 1186 if not completed_branch: 1187 log.error(f"{log_id} Could not find branch for sub-task {sub_task_id}") 1188 return 1189 1190 # Check for duplicate completion 1191 if "result" in completed_branch: 1192 log.warning( 1193 f"{log_id} Sub-task {sub_task_id} already completed. Ignoring duplicate response." 1194 ) 1195 return 1196 1197 # Update result 1198 completed_branch["result"] = { 1199 "artifact_name": result.output_artifact_ref.name if result.output_artifact_ref else None, 1200 "artifact_version": result.output_artifact_ref.version if result.output_artifact_ref else None, 1201 } 1202 1203 control_node = self.nodes[control_node_id] 1204 1205 if control_node.type == "loop": 1206 # Handle Loop iteration completion 1207 iteration = completed_branch.get("iteration") 1208 log.info(f"{log_id} Loop iteration {iteration} completed") 1209 1210 # Load result and store in node_outputs for condition evaluation 1211 if result.output_artifact_ref and result.output_artifact_ref.name: 1212 artifact_data = await self.host._load_node_output( 1213 node_id=control_node_id, 1214 artifact_name=result.output_artifact_ref.name, 1215 artifact_version=result.output_artifact_ref.version, 1216 workflow_context=workflow_context, 1217 sub_task_id=sub_task_id, 1218 ) 1219 # Store result under the inner node's original ID so conditions can reference it 1220 # e.g., {{check_task_status.output.ready}} will find the result 1221 inner_node_id = control_node.node # The original inner node ID from workflow definition 1222 workflow_state.node_outputs[inner_node_id] = { 1223 "output": artifact_data 1224 } 1225 log.debug(f"{log_id} Stored loop iteration result under '{inner_node_id}'") 1226 1227 # Clear active branches for this loop 1228 del workflow_state.active_branches[control_node_id] 1229 1230 # Re-execute loop node to check condition for next iteration 1231 await self._execute_loop_node( 1232 control_node, workflow_state, workflow_context 1233 ) 1234 elif control_node.type == "map": 1235 # Handle Map logic (concurrency, state update) 1236 map_state = workflow_state.metadata.get(f"map_state_{control_node_id}") 1237 if map_state: 1238 index = completed_branch["index"] 1239 # Safely remove index (idempotency check above should prevent this, but being safe) 1240 if index in map_state["active_indices"]: 1241 map_state["active_indices"].remove(index) 1242 else: 1243 log.warning(f"{log_id} Index {index} not found in active_indices during completion.") 1244 1245 map_state["completed_count"] += 1 1246 # Store result in map_state for final aggregation 1247 map_state["results"][index] = completed_branch 1248 1249 # Publish map progress 1250 progress_data = WorkflowMapProgressData( 1251 type="workflow_map_progress", 1252 node_id=control_node_id, 1253 total_items=len(map_state["items"]), 1254 completed_items=map_state["completed_count"], 1255 status="in-progress", 1256 ) 1257 await self.host.publish_workflow_event(workflow_context, progress_data) 1258 1259 # Launch next pending items 1260 await self._launch_map_iterations( 1261 control_node_id, workflow_state, workflow_context 1262 ) 1263 1264 # Check if ALL items are complete 1265 if map_state["completed_count"] == len(map_state["items"]): 1266 log.info(f"{log_id} All map items completed") 1267 await self._finalize_map_node( 1268 control_node_id, map_state, workflow_state, workflow_context 1269 ) 1270 1271 async def _finalize_map_node( 1272 self, 1273 map_node_id: str, 1274 map_state: Dict, 1275 workflow_state: WorkflowExecutionState, 1276 workflow_context: WorkflowExecutionContext, 1277 ): 1278 """Aggregate map results.""" 1279 log_id = f"{self.host.log_identifier}[Map:{map_node_id}]" 1280 1281 results_list = [] 1282 # map_state["results"] is already ordered by index 1283 for iter_info in map_state["results"]: 1284 if not iter_info or "result" not in iter_info: 1285 # Should not happen if completed_count check is correct 1286 log.error(f"{log_id} Missing result for item") 1287 results_list.append(None) 1288 continue 1289 1290 artifact_name = iter_info["result"]["artifact_name"] 1291 artifact_version = iter_info["result"]["artifact_version"] 1292 1293 artifact_data = await self.host._load_node_output( 1294 node_id=map_node_id, 1295 artifact_name=artifact_name, 1296 artifact_version=artifact_version, 1297 workflow_context=workflow_context, 1298 sub_task_id=iter_info["sub_task_id"], 1299 ) 1300 results_list.append(artifact_data) 1301 1302 # Create aggregated artifact 1303 merged_artifact_name = f"map_{map_node_id}_results.json" 1304 merged_bytes = json.dumps({"results": results_list}).encode("utf-8") 1305 1306 await save_artifact_with_metadata( 1307 artifact_service=self.host.artifact_service, 1308 app_name=self.host.workflow_name, 1309 user_id=workflow_context.a2a_context["user_id"], 1310 session_id=workflow_context.a2a_context["session_id"], 1311 filename=merged_artifact_name, 1312 content_bytes=merged_bytes, 1313 mime_type="application/json", 1314 metadata_dict={ 1315 "description": f"Aggregated results from map node '{map_node_id}'", 1316 "source": "workflow_map_aggregate", 1317 "node_id": map_node_id, 1318 }, 1319 timestamp=datetime.now(timezone.utc), 1320 tags=[ARTIFACT_TAG_WORKING], 1321 ) 1322 1323 # Publish result event 1324 result_data = WorkflowNodeExecutionResultData( 1325 type="workflow_node_execution_result", 1326 node_id=map_node_id, 1327 status="success", 1328 output_artifact_ref=ArtifactRef(name=merged_artifact_name), 1329 ) 1330 await self.host.publish_workflow_event(workflow_context, result_data) 1331 1332 workflow_state.completed_nodes[map_node_id] = merged_artifact_name 1333 if map_node_id in workflow_state.pending_nodes: 1334 workflow_state.pending_nodes.remove(map_node_id) 1335 workflow_state.node_outputs[map_node_id] = {"output": {"results": results_list}} 1336 1337 # Cleanup state 1338 del workflow_state.active_branches[map_node_id] 1339 del workflow_state.metadata[f"map_state_{map_node_id}"] 1340 1341 await self.execute_workflow(workflow_state, workflow_context)