# src/solace_agent_mesh/agent/tools/workflow_tool.py
  1  """
  2  ADK Tool implementation for invoking Workflow agents via A2A.
  3  """
  4  
  5  import logging
  6  import json
  7  import uuid
  8  from datetime import datetime, timezone
  9  from typing import Any, Dict, Optional, Tuple
 10  
 11  import jsonschema
 12  from google.adk.tools import BaseTool, ToolContext
 13  from google.genai import types as adk_types
 14  
 15  from ...common import a2a
 16  from ...common.constants import ARTIFACT_TAG_WORKING, DEFAULT_COMMUNICATION_TIMEOUT
 17  from ...common.exceptions import MessageSizeExceededError
 18  from ...common.data_parts import StructuredInvocationRequest
 19  from ...agent.utils.artifact_helpers import (
 20      save_artifact_with_metadata,
 21      format_artifact_uri,
 22  )
 23  
 24  log = logging.getLogger(__name__)
 25  
 26  WORKFLOW_TOOL_PREFIX = "workflow_"
 27  CORRELATION_DATA_PREFIX = "a2a_subtask_"
 28  
 29  
 30  class WorkflowAgentTool(BaseTool):
 31      """
 32      An ADK Tool that represents a discovered Workflow agent.
 33      Supports dual-mode invocation:
 34      1. Parameter Mode: Pass arguments directly (validated against schema).
 35      2. Artifact Mode: Pass an 'input_artifact' reference.
 36      """
 37  
 38      is_long_running = True
 39  
 40      def __init__(
 41          self,
 42          target_agent_name: str,
 43          input_schema: Dict[str, Any],
 44          host_component,
 45      ):
 46          """
 47          Initializes the WorkflowAgentTool.
 48  
 49          Args:
 50              target_agent_name: The name of the workflow agent.
 51              input_schema: The JSON schema defining the workflow's input parameters.
 52              host_component: A reference to the SamAgentComponent instance.
 53          """
 54          tool_name = f"{WORKFLOW_TOOL_PREFIX}{target_agent_name}"
 55          # Sanitize tool name if necessary (replace hyphens with underscores)
 56          tool_name = tool_name.replace("-", "_")
 57  
 58          super().__init__(
 59              name=tool_name,
 60              description=f"Invoke the '{target_agent_name}' workflow.",
 61              is_long_running=True,
 62          )
 63          self.target_agent_name = target_agent_name
 64          self.input_schema = input_schema
 65          self.host_component = host_component
 66          self.log_identifier = (
 67              f"{host_component.log_identifier}[WorkflowTool:{target_agent_name}]"
 68          )
 69  
 70      def _json_schema_to_adk_schema(
 71          self, json_schema: Dict[str, Any], nullable: bool = False
 72      ) -> adk_types.Schema:
 73          """
 74          Recursively converts a JSON schema to an ADK Schema, preserving nested structure.
 75  
 76          Args:
 77              json_schema: The JSON schema definition to convert.
 78              nullable: Whether the schema should be nullable.
 79  
 80          Returns:
 81              An ADK Schema object representing the JSON schema.
 82          """
 83          json_type = json_schema.get("type", "string")
 84          description = json_schema.get("description", "")
 85  
 86          # Map JSON schema type to ADK type
 87          type_mapping = {
 88              "string": adk_types.Type.STRING,
 89              "integer": adk_types.Type.INTEGER,
 90              "number": adk_types.Type.NUMBER,
 91              "boolean": adk_types.Type.BOOLEAN,
 92              "array": adk_types.Type.ARRAY,
 93              "object": adk_types.Type.OBJECT,
 94          }
 95          adk_type = type_mapping.get(json_type, adk_types.Type.STRING)
 96  
 97          schema_kwargs = {
 98              "type": adk_type,
 99              "description": description,
100              "nullable": nullable,
101          }
102  
103          # Handle array items
104          if json_type == "array" and "items" in json_schema:
105              schema_kwargs["items"] = self._json_schema_to_adk_schema(
106                  json_schema["items"]
107              )
108  
109          # Handle object properties
110          if json_type == "object" and "properties" in json_schema:
111              nested_properties = {}
112              for prop_name, prop_def in json_schema["properties"].items():
113                  nested_properties[prop_name] = self._json_schema_to_adk_schema(prop_def)
114              schema_kwargs["properties"] = nested_properties
115  
116          return adk_types.Schema(**schema_kwargs)
117  
118      def _get_declaration(self) -> adk_types.FunctionDeclaration:
119          """
120          Dynamically generates the FunctionDeclaration based on the workflow's input schema.
121          Adds 'input_artifact' as an optional parameter and marks all parameters as optional
122          to support dual-mode invocation.
123          """
124          properties = self.input_schema.get("properties", {})
125          adk_properties = {}
126  
127          # Add input_artifact parameter
128          adk_properties["input_artifact"] = adk_types.Schema(
129              type=adk_types.Type.STRING,
130              description="Filename of an existing artifact containing the input JSON data. Use this OR individual parameters.",
131              nullable=True,
132          )
133  
134          for prop_name, prop_def in properties.items():
135              adk_properties[prop_name] = self._json_schema_to_adk_schema(
136                  prop_def, nullable=True
137              )
138  
139          parameters_schema = adk_types.Schema(
140              type=adk_types.Type.OBJECT,
141              properties=adk_properties,
142              required=[],  # All optional
143          )
144  
145          return adk_types.FunctionDeclaration(
146              name=self.name,
147              description=f"Invoke the '{self.target_agent_name}' workflow. Dual-mode: provide parameters directly OR 'input_artifact'.",
148              parameters=parameters_schema,
149          )
150  
151      async def run_async(
152          self, *, args: Dict[str, Any], tool_context: ToolContext
153      ) -> Any:
154          """
155          Handles the workflow invocation.
156          """
157          sub_task_id = f"{CORRELATION_DATA_PREFIX}{uuid.uuid4().hex}"
158          log_identifier = f"{self.log_identifier}[SubTask:{sub_task_id}]"
159  
160          try:
161              # 1. Prepare Input Artifact
162              try:
163                  (
164                      payload_artifact_name,
165                      payload_artifact_version,
166                  ) = await self._prepare_input_artifact(
167                      args, tool_context, log_identifier
168                  )
169              except jsonschema.ValidationError as e:
170                  log.warning(
171                      "%s Input validation failed | message=%s",
172                      log_identifier,
173                      e.message,
174                  )
175                  error_response = {
176                      "status": "error",
177                      "message": f"Input validation failed: {e.message}. Please provide required parameters or use input_artifact.",
178                  }
179                  return error_response
180  
181              # 2. Prepare Context
182              original_task_context = tool_context.state.get("a2a_context", {})
183              main_logical_task_id = original_task_context.get(
184                  "logical_task_id", "unknown_task"
185              )
186              invocation_id = tool_context._invocation_context.invocation_id
187              user_id = tool_context._invocation_context.user_id
188              user_config = original_task_context.get("a2a_user_config", {})
189  
190              # 3. Prepare Message
191              session_id = tool_context._invocation_context.session.id
192              a2a_message = self._prepare_a2a_message(
193                  payload_artifact_name,
194                  payload_artifact_version,
195                  user_id,
196                  session_id,
197                  main_logical_task_id,
198                  original_task_context,
199              )
200  
201              # 4. Submit Task
202              try:
203                  self._submit_workflow_task(
204                      sub_task_id,
205                      main_logical_task_id,
206                      invocation_id,
207                      tool_context,
208                      original_task_context,
209                      a2a_message,
210                      user_id,
211                      user_config,
212                      log_identifier,
213                  )
214              except MessageSizeExceededError as e:
215                  log.error("%s Message size exceeded: %s", log_identifier, e)
216                  return {
217                      "status": "error",
218                      "message": f"Error: {str(e)}. Message size exceeded.",
219                  }
220  
221              return None  # Fire-and-forget
222  
223          except Exception as e:
224              log.exception("%s Error in WorkflowAgentTool: %s", log_identifier, e)
225              return {
226                  "status": "error",
227                  "message": f"Failed to invoke workflow '{self.target_agent_name}': {e}",
228              }
229  
230      async def _prepare_input_artifact(
231          self, args: Dict[str, Any], tool_context: ToolContext, log_identifier: str
232      ) -> Tuple[str, Optional[int]]:
233          """
234          Determines input mode, validates parameters, and creates implicit artifact if needed.
235          Returns (artifact_name, artifact_version).
236          """
237          input_artifact_name = args.get("input_artifact")
238  
239          if input_artifact_name:
240              log.info(
241                  "%s Invoking in Artifact Mode with '%s'",
242                  log_identifier,
243                  input_artifact_name,
244              )
245              return input_artifact_name, None
246  
247          # Parameter Mode - Validate against strict schema
248          try:
249              jsonschema.validate(instance=args, schema=self.input_schema)
250          except jsonschema.ValidationError as ve:
251              log.warning(
252                  "%s Schema validation failed | error=%s | path=%s",
253                  log_identifier,
254                  ve.message,
255                  list(ve.absolute_path),
256              )
257              raise
258  
259          # Create implicit artifact
260          payload_data = args
261          payload_bytes = json.dumps(payload_data).encode("utf-8")
262  
263          # Generate unique filename using UUID to avoid collisions in parallel invocations
264          sanitized_wf_name = "".join(
265              c for c in self.target_agent_name if c.isalnum() or c in "_-"
266          )
267          unique_suffix = uuid.uuid4().hex[:8]
268          payload_artifact_name = f"wi_{sanitized_wf_name}_{unique_suffix}.json"
269  
270          # Save artifact
271          user_id = tool_context._invocation_context.user_id
272          session_id = tool_context._invocation_context.session.id
273  
274          save_result = await save_artifact_with_metadata(
275              artifact_service=self.host_component.artifact_service,
276              app_name=self.host_component.agent_name,
277              user_id=user_id,
278              session_id=session_id,
279              filename=payload_artifact_name,
280              content_bytes=payload_bytes,
281              mime_type="application/json",
282              metadata_dict={
283                  "description": f"Auto-generated input for workflow '{self.target_agent_name}'",
284                  "source": "workflow_tool_implicit_creation",
285              },
286              timestamp=datetime.now(timezone.utc),
287              tags=[ARTIFACT_TAG_WORKING],
288          )
289  
290          if save_result["status"] != "success":
291              raise RuntimeError(
292                  f"Failed to save implicit input artifact: {save_result.get('message')}"
293              )
294  
295          payload_artifact_version = save_result.get("data_version")
296  
297          log.info(
298              "%s Created implicit input artifact: %s v%s",
299              log_identifier,
300              payload_artifact_name,
301              payload_artifact_version,
302          )
303  
304          return payload_artifact_name, payload_artifact_version
305  
306      def _prepare_a2a_message(
307          self,
308          payload_artifact_name: str,
309          payload_artifact_version: Optional[int],
310          user_id: str,
311          session_id: str,
312          main_logical_task_id: str,
313          original_task_context: Dict[str, Any],
314      ) -> Any:
315          """Constructs the A2A message with StructuredInvocationRequest and FilePart."""
316          parts = []
317  
318          # 1. Add StructuredInvocationRequest DataPart (triggers structured invocation)
319          invocation_request = StructuredInvocationRequest(
320              type="structured_invocation_request",
321              workflow_name=self.host_component.agent_name,
322              node_id=f"workflow_tool_{self.target_agent_name}",
323              input_schema=self.input_schema,
324              output_schema=None,  # Workflow defines its own output schema
325              suggested_output_filename=f"wf_{self.target_agent_name}_result.json",
326          )
327          parts.append(a2a.create_data_part(data=invocation_request.model_dump()))
328  
329          # 2. Add FilePart with artifact URI (contains the input data)
330          if payload_artifact_name and payload_artifact_version is not None:
331              uri = format_artifact_uri(
332                  app_name=self.host_component.agent_name,
333                  user_id=user_id,
334                  session_id=session_id,
335                  filename=payload_artifact_name,
336                  version=payload_artifact_version,
337              )
338              parts.append(
339                  a2a.create_file_part_from_uri(
340                      uri=uri,
341                      name=payload_artifact_name,
342                      mime_type="application/json",
343                  )
344              )
345  
346          a2a_metadata = {
347              "sessionBehavior": "RUN_BASED",
348              "parentTaskId": main_logical_task_id,
349              "agent_name": self.target_agent_name,
350          }
351  
352          return a2a.create_user_message(
353              parts=parts,
354              metadata=a2a_metadata,
355              context_id=original_task_context.get("contextId"),
356          )
357  
358      def _submit_workflow_task(
359          self,
360          sub_task_id: str,
361          main_logical_task_id: str,
362          invocation_id: str,
363          tool_context: ToolContext,
364          original_task_context: Dict[str, Any],
365          a2a_message: Any,
366          user_id: str,
367          user_config: Dict[str, Any],
368          log_identifier: str,
369      ):
370          """Handles task registration, correlation data, and submission."""
371          # Register parallel call
372          task_context_obj = None
373          with self.host_component.active_tasks_lock:
374              task_context_obj = self.host_component.active_tasks.get(
375                  main_logical_task_id
376              )
377  
378          if not task_context_obj:
379              log.error(
380                  "%s TaskExecutionContext NOT FOUND | main_task_id=%s | active_tasks_keys=%s",
381                  log_identifier,
382                  main_logical_task_id,
383                  list(self.host_component.active_tasks.keys()),
384              )
385              raise ValueError(
386                  f"TaskExecutionContext not found for task '{main_logical_task_id}'"
387              )
388  
389          # NOTE: register_parallel_call_sent is now called in
390          # preregister_long_running_tools_callback (after_model_callback)
391          # BEFORE tool execution begins. This prevents race conditions where
392          # one tool completes before another registers.
393  
394          # Submit Task
395          correlation_data = {
396              "adk_function_call_id": tool_context.function_call_id,
397              "original_task_context": original_task_context,
398              "peer_tool_name": self.name,
399              "peer_agent_name": self.target_agent_name,
400              "logical_task_id": main_logical_task_id,
401              "invocation_id": invocation_id,
402          }
403  
404          task_context_obj.register_peer_sub_task(sub_task_id, correlation_data)
405  
406          timeout_sec = self.host_component.get_config(
407              "inter_agent_communication", {}
408          ).get("request_timeout_seconds", DEFAULT_COMMUNICATION_TIMEOUT)
409  
410          self.host_component.cache_service.add_data(
411              key=sub_task_id,
412              value=main_logical_task_id,
413              expiry=timeout_sec,
414              component=self.host_component,
415          )
416  
417          self.host_component.submit_a2a_task(
418              target_agent_name=self.target_agent_name,
419              a2a_message=a2a_message,
420              user_id=user_id,
421              user_config=user_config,
422              sub_task_id=sub_task_id,
423          )
424  
425          log.info(
426              "%s Workflow task submitted for agent: %s", log_identifier, self.target_agent_name
427          )