/ src / solace_agent_mesh / workflow / agent_caller.py
agent_caller.py
  1  """
  2  AgentCaller component for invoking agents via A2A.
  3  """
  4  
  5  import logging
  6  import uuid
  7  import re
  8  import json
  9  from datetime import datetime, timezone
 10  from typing import Any, Dict, Optional, TYPE_CHECKING
 11  
 12  from a2a.types import MessageSendParams, SendMessageRequest, Message as A2AMessage
 13  
 14  from ..common import a2a
 15  from ..common.constants import ARTIFACT_TAG_WORKING
 16  from ..common.data_parts import StructuredInvocationRequest
 17  from ..common.agent_card_utils import get_schemas_from_agent_card
 18  from ..agent.utils.artifact_helpers import (
 19      save_artifact_with_metadata,
 20      format_artifact_uri,
 21  )
 22  from .app import WorkflowNode, WorkflowInvokeNode
 23  from .workflow_execution_context import WorkflowExecutionContext, WorkflowExecutionState
 24  
if TYPE_CHECKING:
    # Imported only for static type checking: AgentCaller.__init__ refers to this
    # class through the string annotation "WorkflowExecutorComponent", so it is
    # never needed at runtime.
    from .component import WorkflowExecutorComponent

# Module-level logger used by AgentCaller for request/artifact/template diagnostics.
log = logging.getLogger(__name__)
 29  
 30  
 31  class AgentCaller:
 32      """Manages A2A calls to agents from workflow."""
 33  
 34      def __init__(self, host_component: "WorkflowExecutorComponent"):
 35          self.host = host_component
 36  
 37      def _resolve_string_with_templates(
 38          self, template_string: str, workflow_state: WorkflowExecutionState
 39      ) -> Optional[str]:
 40          """
 41          Resolve a string that may contain embedded template expressions.
 42  
 43          Unlike dag_executor.resolve_value which only handles strings that ARE templates,
 44          this method handles strings that CONTAIN templates (e.g., "Hello {{name}}!").
 45  
 46          Args:
 47              template_string: A string that may contain {{...}} template expressions
 48              workflow_state: Current workflow state for resolving variables
 49  
 50          Returns:
 51              The string with all template expressions resolved, or None if resolution fails
 52          """
 53          if not template_string:
 54              return None
 55  
 56          # Pattern to match {{...}} template expressions
 57          template_pattern = re.compile(r"\{\{\s*(.+?)\s*\}\}")
 58  
 59          def replace_template(match: re.Match) -> str:
 60              """Replace a single template match with its resolved value."""
 61              full_match = match.group(0)  # The full {{...}} string
 62              try:
 63                  # Use dag_executor to resolve the full template
 64                  resolved = self.host.dag_executor.resolve_value(
 65                      full_match, workflow_state
 66                  )
 67                  if resolved is None:
 68                      # Keep the original template if resolution fails
 69                      return full_match
 70                  return str(resolved)
 71              except Exception as e:
 72                  log.warning(
 73                      f"{self.host.log_identifier} Failed to resolve template "
 74                      f"'{full_match}': {e}"
 75                  )
 76                  return full_match
 77  
 78          # Replace all template expressions in the string
 79          resolved = template_pattern.sub(replace_template, template_string)
 80          return resolved
 81  
 82      async def call_agent(
 83          self,
 84          node: WorkflowNode,
 85          workflow_state: WorkflowExecutionState,
 86          workflow_context: WorkflowExecutionContext,
 87          sub_task_id: Optional[str] = None,
 88      ) -> str:
 89          """
 90          Invoke an agent for a workflow node.
 91          Returns sub-task ID for correlation.
 92          """
 93          log_id = f"{self.host.log_identifier}[CallAgent:{node.agent_name}]"
 94  
 95          # Generate sub-task ID if not provided
 96          if not sub_task_id:
 97              sub_task_id = (
 98                  f"wf_{workflow_state.execution_id}_{node.id}_{uuid.uuid4().hex[:8]}"
 99              )
100          # Resolve input data
101          input_data = await self._resolve_node_input(node, workflow_state)
102  
103          # Resolve instruction template if present
104          # Handles both full templates ({{...}}) and embedded templates within strings
105          resolved_instruction = None
106          if hasattr(node, "instruction") and node.instruction:
107              resolved_instruction = self._resolve_string_with_templates(
108                  node.instruction, workflow_state
109              )
110  
111          # Get agent card - required for proper structured invocation
112          agent_card = self.host.agent_registry.get_agent(node.agent_name)
113          if not agent_card:
114              raise ValueError(
115                  f"Agent '{node.agent_name}' not found in registry. "
116                  f"Ensure the agent is running and has published its agent card before "
117                  f"starting the workflow."
118              )
119  
120          # Get schemas from agent card extensions
121          card_input_schema, card_output_schema = get_schemas_from_agent_card(agent_card)
122  
123          # Use override schemas if provided, otherwise use schemas from agent card
124          input_schema = node.input_schema_override or card_input_schema
125          output_schema = node.output_schema_override or card_output_schema
126  
127          # Construct A2A message
128          message = await self._construct_agent_message(
129              node,
130              input_data,
131              input_schema,
132              output_schema,
133              workflow_state,
134              sub_task_id,
135              workflow_context,
136              resolved_instruction,
137          )
138  
139          # Publish request
140          await self._publish_agent_request(
141              node.agent_name, message, sub_task_id, workflow_context
142          )
143  
144          # Track in workflow context
145          workflow_context.track_agent_call(node.id, sub_task_id)
146  
147          return sub_task_id
148  
149      async def call_workflow(
150          self,
151          node: WorkflowInvokeNode,
152          workflow_state: WorkflowExecutionState,
153          workflow_context: WorkflowExecutionContext,
154          sub_task_id: Optional[str] = None,
155      ) -> str:
156          """
157          Invoke a sub-workflow.
158  
159          Workflows register as agents, so this method adapts the workflow node
160          to use the agent calling mechanism.
161  
162          Returns sub-task ID for correlation.
163          """
164          log_id = f"{self.host.log_identifier}[CallWorkflow:{node.workflow_name}]"
165  
166          # Generate sub-task ID if not provided
167          if not sub_task_id:
168              sub_task_id = (
169                  f"wf_{workflow_state.execution_id}_{node.id}_{uuid.uuid4().hex[:8]}"
170              )
171  
172          # Create an adapter object that makes WorkflowInvokeNode compatible
173          # with the existing _resolve_node_input and _construct_agent_message methods
174          class WorkflowNodeAdapter:
175              """Adapter to make WorkflowInvokeNode work with agent calling infrastructure."""
176  
177              def __init__(self, wf_node: WorkflowInvokeNode):
178                  self.id = wf_node.id
179                  self.type = "workflow"
180                  self.agent_name = wf_node.workflow_name  # Map workflow_name to agent_name
181                  self.input = wf_node.input
182                  self.instruction = wf_node.instruction
183                  self.input_schema_override = wf_node.input_schema_override
184                  self.output_schema_override = wf_node.output_schema_override
185                  self.depends_on = wf_node.depends_on
186  
187          adapted_node = WorkflowNodeAdapter(node)
188  
189          # Resolve input data
190          input_data = await self._resolve_node_input(adapted_node, workflow_state)
191  
192          # Resolve instruction template if present
193          resolved_instruction = None
194          if node.instruction:
195              resolved_instruction = self._resolve_string_with_templates(
196                  node.instruction, workflow_state
197              )
198  
199          # Get agent card - required for proper structured invocation
200          # Workflows publish their schemas in their agent cards
201          agent_card = self.host.agent_registry.get_agent(node.workflow_name)
202          if not agent_card:
203              raise ValueError(
204                  f"Workflow '{node.workflow_name}' not found in registry. "
205                  f"Ensure the sub-workflow is running and has published its agent card before "
206                  f"starting the parent workflow."
207              )
208  
209          # Get schemas from agent card extensions
210          card_input_schema, card_output_schema = get_schemas_from_agent_card(agent_card)
211  
212          # Use override schemas if provided, otherwise use schemas from agent card
213          input_schema = node.input_schema_override or card_input_schema
214          output_schema = node.output_schema_override or card_output_schema
215  
216          # Construct A2A message
217          message = await self._construct_agent_message(
218              adapted_node,
219              input_data,
220              input_schema,
221              output_schema,
222              workflow_state,
223              sub_task_id,
224              workflow_context,
225              resolved_instruction,
226          )
227  
228          # Publish request to the sub-workflow
229          await self._publish_agent_request(
230              node.workflow_name, message, sub_task_id, workflow_context
231          )
232  
233          # Track in workflow context
234          workflow_context.track_agent_call(node.id, sub_task_id)
235  
236          log.info(
237              f"{log_id} Invoked sub-workflow '{node.workflow_name}' (sub_task_id: {sub_task_id})"
238          )
239  
240          return sub_task_id
241  
242      async def _resolve_node_input(
243          self, node: WorkflowNode, workflow_state: WorkflowExecutionState
244      ) -> Dict[str, Any]:
245          """
246          Resolve input mapping for a node.
247          If input is not provided, infer it from dependencies.
248          """
249          # Case 1: Explicit Input Mapping
250          if node.input is not None:
251              resolved_input = {}
252              for key, value in node.input.items():
253                  # Use DAGExecutor's resolve_value to handle templates and operators
254                  resolved_value = self.host.dag_executor.resolve_value(
255                      value, workflow_state
256                  )
257                  resolved_input[key] = resolved_value
258              return resolved_input
259  
260          # Case 2: Implicit Input Inference
261          log.debug(
262              f"{self.host.log_identifier} Node '{node.id}' has no explicit input. Inferring from dependencies."
263          )
264  
265          # Case 2a: No dependencies (Initial Node) -> Use Workflow Input
266          if not node.depends_on:
267              if "workflow_input" not in workflow_state.node_outputs:
268                  raise ValueError("Workflow input has not been initialized")
269              return workflow_state.node_outputs["workflow_input"]["output"]
270  
271          # Case 2b: Single Dependency -> Use Dependency Output
272          if len(node.depends_on) == 1:
273              dep_id = node.depends_on[0]
274  
275              # Check if dependency is a switch node - use workflow input instead of switch metadata
276              dep_node = self.host.dag_executor.nodes.get(dep_id)
277              if dep_node and dep_node.type == "switch":
278                  log.debug(
279                      f"{self.host.log_identifier} Node '{node.id}' depends on switch '{dep_id}'. Using workflow input."
280                  )
281                  if "workflow_input" not in workflow_state.node_outputs:
282                      raise ValueError("Workflow input has not been initialized")
283                  return workflow_state.node_outputs["workflow_input"]["output"]
284  
285              if dep_id not in workflow_state.node_outputs:
286                  raise ValueError(f"Dependency '{dep_id}' has not completed")
287              return workflow_state.node_outputs[dep_id]["output"]
288  
289          # Case 2c: Multiple Dependencies -> Ambiguous
290          raise ValueError(
291              f"Node '{node.id}' has multiple dependencies {node.depends_on} but no explicit 'input' mapping. "
292              "Implicit input inference is only supported for nodes with 0 or 1 dependency. "
293              "Please provide an explicit 'input' mapping."
294          )
295  
296      def _generate_result_embed_reminder(
297          self, output_schema: Optional[Dict[str, Any]]
298      ) -> str:
299          """Generate user-facing reminder about result embed requirement."""
300          if output_schema:
301              return """
302  REMINDER: When you complete this task, you MUST end your response with:
303  «result:artifact=<your_artifact_name>:<version> status=success»
304  
305  For example: «result:artifact=analysis_results.json:0 status=success»
306  
307  This is required for the workflow to continue. Without this result embed, the workflow will fail.
308  """
309          else:
310              return """
311  REMINDER: When you complete this task, you MUST end your response with:
312  «result:artifact=<your_artifact_name>:<version> status=success»
313  
314  This is MANDATORY for the workflow to continue.
315  """
316  
317      async def _construct_agent_message(
318          self,
319          node: WorkflowNode,
320          input_data: Dict[str, Any],
321          input_schema: Optional[Dict[str, Any]],
322          output_schema: Optional[Dict[str, Any]],
323          workflow_state: WorkflowExecutionState,
324          sub_task_id: str,
325          workflow_context: WorkflowExecutionContext,
326          resolved_instruction: Optional[str] = None,
327      ) -> A2AMessage:
328          """Construct A2A message for agent."""
329  
330          # Build message parts
331          parts = []
332  
333          # Generate unique output filename for this workflow node
334          # Use last 8 chars of sub_task_id for uniqueness (contains UUID)
335          unique_suffix = sub_task_id[-8:] if len(sub_task_id) >= 8 else sub_task_id
336          # Sanitize workflow name (replace spaces/special chars with underscore)
337          safe_workflow_name = re.sub(
338              r"[^a-zA-Z0-9_-]", "_", workflow_state.workflow_name
339          )
340          # node.id already includes iteration index for map nodes (e.g., "generate_data_0")
341          suggested_output_filename = f"{safe_workflow_name}_{node.id}_{unique_suffix}.json"
342  
343          # 1. Structured invocation request (must be first)
344          invocation_request = StructuredInvocationRequest(
345              type="structured_invocation_request",
346              workflow_name=workflow_state.workflow_name,
347              node_id=node.id,
348              input_schema=input_schema,
349              output_schema=output_schema,
350              suggested_output_filename=suggested_output_filename,
351          )
352          parts.append(a2a.create_data_part(data=invocation_request.model_dump()))
353  
354          # 2. Add instruction text part if provided
355          if resolved_instruction and resolved_instruction.strip():
356              parts.append(a2a.create_text_part(text=resolved_instruction))
357  
358          # Determine if we should send as structured artifact or text
359          # For structured invocations (workflow calls), we ALWAYS send input as FilePart
360          # unless it's explicitly a single text schema. This ensures the receiver can
361          # properly handle the structured input even if we don't have the agent's schema
362          # yet (e.g., due to timing issues with agent card discovery).
363          should_send_artifact = True
364          if input_schema:
365              # Only use text mode if schema is explicitly a single text field
366              is_single_text = (
367                  input_schema.get("type") == "object"
368                  and len(input_schema.get("properties", {})) == 1
369                  and "text" in input_schema.get("properties", {})
370                  and input_schema["properties"]["text"].get("type") == "string"
371              )
372              if is_single_text:
373                  should_send_artifact = False
374  
375          if should_send_artifact:
376              # Create and save input artifact, then add FilePart with URI
377              filename = f"input_{node.id}_{sub_task_id}.json"
378              content_bytes = json.dumps(input_data).encode("utf-8")
379              user_id = workflow_context.a2a_context["user_id"]
380              session_id = workflow_context.a2a_context["session_id"]
381  
382              try:
383                  save_result = await save_artifact_with_metadata(
384                      artifact_service=self.host.artifact_service,
385                      app_name=self.host.workflow_name,
386                      user_id=user_id,
387                      session_id=session_id,
388                      filename=filename,
389                      content_bytes=content_bytes,
390                      mime_type="application/json",
391                      metadata_dict={
392                          "description": f"Input for node {node.id}",
393                          "source": "workflow_execution",
394                      },
395                      timestamp=datetime.now(timezone.utc),
396                      tags=[ARTIFACT_TAG_WORKING],
397                  )
398  
399                  if save_result["status"] == "success":
400                      version = save_result["data_version"]
401                      uri = format_artifact_uri(
402                          app_name=self.host.workflow_name,
403                          user_id=user_id,
404                          session_id=session_id,
405                          filename=filename,
406                          version=version,
407                      )
408                      parts.append(
409                          a2a.create_file_part_from_uri(
410                              uri=uri, name=filename, mime_type="application/json"
411                          )
412                      )
413                      log.info(
414                          f"{self.host.log_identifier} Created input artifact for node "
415                          f"{node.id}: {filename}"
416                      )
417                  else:
418                      raise RuntimeError(
419                          f"Failed to save input artifact: {save_result.get('message')}"
420                      )
421  
422              except Exception as e:
423                  log.error(
424                      f"{self.host.log_identifier} Error saving input artifact for node "
425                      f"{node.id}: {e}"
426                  )
427                  raise e
428  
429          else:
430              # Send as text/data parts (Chat Mode)
431              if "query" in input_data:
432                  parts.append(a2a.create_text_part(text=input_data["query"]))
433              elif "text" in input_data:
434                  parts.append(a2a.create_text_part(text=input_data["text"]))
435              else:
436                  # Fallback for unstructured data without 'query'/'text' keys
437                  text_parts = []
438                  for key, value in input_data.items():
439                      text_parts.append(f"{key}: {value}")
440                  if text_parts:
441                      parts.append(a2a.create_text_part(text="\n".join(text_parts)))
442  
443          # Add reminder about result embed requirement
444          reminder_text = self._generate_result_embed_reminder(output_schema)
445          parts.append(a2a.create_text_part(text=reminder_text))
446  
447          # Construct message using helper function
448          # Use the original workflow session ID as context_id so that RUN_BASED sessions
449          # will be created as {workflow_session_id}:{sub_task_id}:run, allowing the workflow
450          # to find artifacts saved by the node using get_original_session_id()
451          message = a2a.create_user_message(
452              parts=parts,
453              task_id=sub_task_id,
454              context_id=workflow_context.a2a_context["session_id"],
455              metadata={
456                  "workflow_name": workflow_state.workflow_name,
457                  "node_id": node.id,
458                  "sub_task_id": sub_task_id,
459                  "parentTaskId": workflow_context.workflow_task_id,
460              },
461          )
462  
463          return message
464  
465      async def _publish_agent_request(
466          self,
467          agent_name: str,
468          message: A2AMessage,
469          sub_task_id: str,
470          workflow_context: WorkflowExecutionContext,
471      ):
472          """Publish A2A request to agent."""
473          log_id = f"{self.host.log_identifier}[PublishAgentRequest:{agent_name}]"
474  
475          # Get agent request topic
476          request_topic = a2a.get_agent_request_topic(self.host.namespace, agent_name)
477  
478          # Create SendMessageRequest
479          send_params = MessageSendParams(message=message)
480          a2a_request = SendMessageRequest(id=sub_task_id, params=send_params)
481  
482          # Construct reply-to and status topics
483          reply_to_topic = a2a.get_agent_response_topic(
484              self.host.namespace, self.host.workflow_name, sub_task_id
485          )
486          status_topic = a2a.get_peer_agent_status_topic(
487              self.host.namespace, self.host.workflow_name, sub_task_id
488          )
489  
490          # Get current call depth and increment for outgoing request
491          current_depth = workflow_context.a2a_context.get("call_depth", 0)
492  
493          # User properties
494          user_properties = {
495              "replyTo": reply_to_topic,
496              "a2aStatusTopic": status_topic,
497              "userId": workflow_context.a2a_context["user_id"],
498              "a2aUserConfig": workflow_context.a2a_context.get("a2a_user_config", {}),
499              "callDepth": current_depth + 1,
500          }
501  
502          # Publish request
503          self.host.publish_a2a_message(
504              payload=a2a_request.model_dump(by_alias=True, exclude_none=True),
505              topic=request_topic,
506              user_properties=user_properties,
507          )
508  
509          log.debug(
510              f"{log_id} Published agent request to {request_topic} (sub_task_id: {sub_task_id})"
511          )
512  
513          # Set timeout tracking
514          timeout_seconds = self.host.get_config("default_node_timeout_seconds", 300)
515          self.host.cache_service.add_data(
516              key=sub_task_id,
517              value=workflow_context.workflow_task_id,
518              expiry=timeout_seconds,
519              component=self.host,
520          )