/ src / solace_agent_mesh / agent / adk / intelligent_mcp_callbacks.py
intelligent_mcp_callbacks.py
  1  """
  2  Intelligent MCP Callback Functions
  3  
  4  This module contains the refactored MCP callback functions that use intelligent
  5  content processing to save MCP tool responses as appropriately typed artifacts.
  6  """
  7  
  8  import logging
  9  import json
 10  import uuid
 11  from datetime import datetime, timezone
 12  from typing import Any, Dict, TYPE_CHECKING, List, Optional
 13  from enum import Enum
 14  from pydantic import BaseModel
 15  
 16  from google.adk.tools import ToolContext, BaseTool
 17  
 18  from .mcp_content_processor import MCPContentProcessor, MCPContentProcessorConfig
 19  from ...agent.utils.artifact_helpers import (
 20      save_artifact_with_metadata,
 21      DEFAULT_SCHEMA_MAX_KEYS,
 22      DEFAULT_SCHEMA_INFERENCE_DEPTH,
 23  )
 24  from ...agent.utils.context_helpers import get_original_session_id
 25  
 26  log = logging.getLogger(__name__)
 27  
 28  if TYPE_CHECKING:
 29      from ...agent.sac.component import SamAgentComponent
 30  
 31  
 32  def _get_schema_config_from_tool_or_agent(
 33      tool: BaseTool,
 34      host_component: "SamAgentComponent",
 35      config_key: str,
 36      default_value: Any,
 37  ) -> Any:
 38      """
 39      Get schema configuration from tool_config if available, otherwise fall back to agent config.
 40  
 41      This allows per-tool override of schema settings like schema_inference_depth.
 42  
 43      Args:
 44          tool: The MCP tool instance (may have _tool_config attribute)
 45          host_component: The agent component for fallback config
 46          config_key: The configuration key to look up
 47          default_value: Default value if not found in either location
 48  
 49      Returns:
 50          The configuration value from tool config, agent config, or default
 51      """
 52      # Check if tool has tool_config with this setting
 53      tool_config = getattr(tool, "_tool_config", None)
 54      if tool_config and isinstance(tool_config, dict):
 55          if config_key in tool_config:
 56              log.debug(
 57                  "Using per-tool config for %s: %s",
 58                  config_key,
 59                  tool_config[config_key],
 60              )
 61              return tool_config[config_key]
 62  
 63      # Fall back to agent-level config
 64      return host_component.get_config(config_key, default_value)
 65  
 66  
class McpSaveStatus(str, Enum):
    """
    Enumeration for the status of an MCP save operation.

    The enum also subclasses ``str``, so each member is itself a string equal
    to its value (e.g. ``McpSaveStatus.SUCCESS == "success"``).
    """

    # The save operation completed without any recorded failure.
    SUCCESS = "success"
    # Only part of the intended work succeeded (e.g. some items failed, or a
    # raw JSON fallback was used instead of intelligent processing).
    PARTIAL_SUCCESS = "partial_success"
    # The save operation failed.
    ERROR = "error"
 73  
 74  
class SavedArtifactInfo(BaseModel):
    """
    A Pydantic model to hold the details of a successfully saved artifact.
    This mirrors the dictionary structure returned by save_artifact_with_metadata.

    Instances are built in this module by unpacking a save-result dictionary
    (``SavedArtifactInfo(**result_dict)``), and only when that dictionary's
    ``status`` is "success" or "partial_success". Every field is required, so
    the dictionary must provide all of the keys below.
    """

    # Status string reported by the save call ("success" or "partial_success"
    # for the dictionaries this module converts).
    status: str
    # Name and version of the stored data artifact.
    data_filename: str
    data_version: int
    # Name and version of the companion metadata artifact.
    # NOTE(review): presumably written by save_artifact_with_metadata next to
    # the data artifact -- confirm against that helper.
    metadata_filename: str
    metadata_version: int
    # Human-readable message carried over from the save result.
    message: str
 87  
 88  
class McpSaveResult(BaseModel):
    """
    The definitive, type-safe result of an MCP response save operation.

    Attributes:
        status: The overall status of the save operation.
        message: A human-readable summary of the outcome.
        artifacts_saved: A list of successfully created "intelligent" artifacts.
        fallback_artifact: An optional artifact representing the raw JSON response.
                           It is set when a raw JSON save succeeds; that save is
                           attempted when intelligent processing is disabled,
                           produces no content items, or raises, and also when
                           the "save_raw_alongside_intelligent" option is enabled.
    """

    status: McpSaveStatus
    message: str
    # NOTE(review): mutable default; presumably safe because pydantic copies
    # field defaults per instance -- confirm for the pydantic version in use.
    artifacts_saved: List[SavedArtifactInfo] = []
    # Remains None when no raw JSON artifact was saved.
    fallback_artifact: Optional[SavedArtifactInfo] = None
105  
106  
async def _attempt_raw_fallback(
    tool: BaseTool,
    tool_context: ToolContext,
    host_component: "SamAgentComponent",
    mcp_response_dict: Dict[str, Any],
    original_tool_args: Dict[str, Any],
) -> Optional[SavedArtifactInfo]:
    """
    Save the raw MCP response as JSON and report whether it was actually stored.

    Returns:
        The saved artifact details when the raw save reported "success" or
        "partial_success"; None when it reported any other status.
    """
    fallback_dict = await _save_raw_mcp_response_fallback(
        tool,
        tool_context,
        host_component,
        mcp_response_dict,
        original_tool_args,
    )
    if fallback_dict.get("status") in ["success", "partial_success"]:
        return SavedArtifactInfo(**fallback_dict)
    return None


async def save_mcp_response_as_artifact_intelligent(
    tool: BaseTool,
    tool_context: ToolContext,
    host_component: "SamAgentComponent",
    mcp_response_dict: Dict[str, Any],
    original_tool_args: Dict[str, Any],
) -> McpSaveResult:
    """
    Intelligently processes and saves MCP tool response content as typed artifacts.

    This function uses intelligent content processing to:
    - Detect and parse different content types (text, image, audio, resource)
    - Create appropriately typed artifacts with proper MIME types
    - Generate enhanced metadata based on content analysis
    - Fall back to raw JSON saving if intelligent processing fails

    The reported status and message always reflect what was really stored: a
    result never claims that a raw JSON fallback was saved unless
    ``fallback_artifact`` is populated.

    Args:
        tool: The MCPTool instance that generated the response.
        tool_context: The ADK ToolContext.
        host_component: The A2A_ADK_HostComponent instance for accessing config and services.
        mcp_response_dict: The raw MCP tool response dictionary.
        original_tool_args: The original arguments passed to the MCP tool.

    Returns:
        An McpSaveResult object containing the structured result of the operation,
        including status, a list of successfully saved artifacts, and any
        fallback artifact. Errors are reported through the result's status
        rather than raised, except for failures while reading the processor
        configuration, which happen before the guarded section.
    """
    log_identifier = f"[IntelligentMCPCallback:{tool.name}]"
    log.debug("%s Starting intelligent MCP response artifact saving...", log_identifier)

    processor_config_dict = host_component.get_config("mcp_intelligent_processing", {})
    processor_config = MCPContentProcessorConfig.from_dict(processor_config_dict)

    saved_artifacts: List[SavedArtifactInfo] = []
    failed_artifacts: List[Dict[str, Any]] = []
    fallback_artifact: Optional[SavedArtifactInfo] = None
    overall_status = McpSaveStatus.SUCCESS

    # Every raw-fallback attempt uses the same arguments.
    raw_fallback_args = (
        tool,
        tool_context,
        host_component,
        mcp_response_dict,
        original_tool_args,
    )

    try:
        if not processor_config.enable_intelligent_processing:
            log.info(
                "%s Intelligent processing disabled, using raw JSON fallback.",
                log_identifier,
            )
            fallback_artifact = await _attempt_raw_fallback(*raw_fallback_args)
            if fallback_artifact is None:
                return McpSaveResult(
                    status=McpSaveStatus.ERROR,
                    message="Intelligent processing disabled; failed to save raw JSON fallback.",
                )
            return McpSaveResult(
                status=McpSaveStatus.SUCCESS,
                message="Intelligent processing disabled; saved raw JSON as fallback.",
                fallback_artifact=fallback_artifact,
            )

        processor = MCPContentProcessor(tool.name, original_tool_args)
        content_items = processor.process_mcp_response(mcp_response_dict)

        if not content_items:
            log.warning(
                "%s No content items found, falling back to raw JSON.", log_identifier
            )
            fallback_artifact = await _attempt_raw_fallback(*raw_fallback_args)
            if fallback_artifact is None:
                # Nothing at all was stored, so this is not a partial success.
                return McpSaveResult(
                    status=McpSaveStatus.ERROR,
                    message="No content items found in MCP response; failed to save raw JSON fallback.",
                )
            return McpSaveResult(
                status=McpSaveStatus.PARTIAL_SUCCESS,
                message="No content items found in MCP response; saved raw JSON as fallback.",
                fallback_artifact=fallback_artifact,
            )

        log.info(
            "%s Processing %d content items intelligently.",
            log_identifier,
            len(content_items),
        )

        for item in content_items:
            try:
                # Normalise URI objects to plain strings, but leave an absent
                # URI alone so it does not become the literal string "None".
                if getattr(item, "uri", None) is not None:
                    item.uri = str(item.uri)
                result_dict = await _save_content_item_as_artifact(
                    item, tool, tool_context, host_component
                )
                if result_dict.get("status") in ["success", "partial_success"]:
                    saved_artifacts.append(SavedArtifactInfo(**result_dict))
                else:
                    log.warning(
                        "%s Failed to save content item: %s",
                        log_identifier,
                        result_dict.get("message", "Unknown error"),
                    )
                    overall_status = McpSaveStatus.PARTIAL_SUCCESS
                    failed_artifacts.append(result_dict)
            except Exception as e:
                # Without the raw fallback option a per-item failure aborts the
                # whole run; the outer handler turns it into an ERROR result.
                if not processor_config.fallback_to_raw_on_error:
                    raise
                log.exception("%s Error saving content item: %s", log_identifier, e)
                overall_status = McpSaveStatus.PARTIAL_SUCCESS
                failed_artifacts.append({"status": "error", "message": str(e)})

        if not saved_artifacts:
            if failed_artifacts:
                first_error_msg = failed_artifacts[0].get("message", "Unknown error")
                log.warning(
                    "%s No items saved successfully. First error: %s",
                    log_identifier,
                    first_error_msg,
                )
                return McpSaveResult(
                    status=McpSaveStatus.ERROR,
                    message=f"Content processing failed. First error: {first_error_msg}",
                )

            fallback_artifact = await _attempt_raw_fallback(*raw_fallback_args)
            if fallback_artifact is None:
                return McpSaveResult(
                    status=McpSaveStatus.ERROR,
                    message="Content processing failed for all items; failed to save raw JSON fallback.",
                )
            return McpSaveResult(
                status=McpSaveStatus.PARTIAL_SUCCESS,
                message="Content processing failed for all items; saved raw JSON as fallback.",
                fallback_artifact=fallback_artifact,
            )

        if processor_config_dict.get("save_raw_alongside_intelligent", False):
            # Best effort: a failure here must not discard the typed artifacts.
            try:
                fallback_artifact = await _attempt_raw_fallback(*raw_fallback_args)
            except Exception as e:
                log.warning(
                    "%s Failed to save raw JSON alongside: %s", log_identifier, e
                )

        log.info(
            "%s Intelligent processing complete: %d artifacts saved, status: %s",
            log_identifier,
            len(saved_artifacts),
            overall_status.value,
        )
        if failed_artifacts:
            summary = (
                f"Processed {len(saved_artifacts)} content items successfully; "
                f"{len(failed_artifacts)} failed."
            )
        else:
            summary = f"Successfully processed {len(saved_artifacts)} content items."
        return McpSaveResult(
            status=overall_status,
            artifacts_saved=saved_artifacts,
            fallback_artifact=fallback_artifact,
            message=summary,
        )

    except Exception as e:
        log.exception(
            "%s Error in intelligent MCP response processing: %s", log_identifier, e
        )
        if processor_config.fallback_to_raw_on_error:
            log.info(
                "%s Falling back to raw JSON due to processing error.", log_identifier
            )
            try:
                fallback_artifact = await _attempt_raw_fallback(*raw_fallback_args)
            except Exception as fallback_error:
                log.exception(
                    "%s Fallback also failed: %s", log_identifier, fallback_error
                )
            else:
                if fallback_artifact is not None:
                    return McpSaveResult(
                        status=McpSaveStatus.PARTIAL_SUCCESS,
                        artifacts_saved=saved_artifacts,
                        fallback_artifact=fallback_artifact,
                        message=f"Intelligent processing failed, saved raw JSON as fallback: {e}",
                    )
                log.warning(
                    "%s Raw JSON fallback did not save an artifact.", log_identifier
                )

        return McpSaveResult(
            status=McpSaveStatus.ERROR,
            artifacts_saved=saved_artifacts,
            fallback_artifact=None,
            message=f"Failed to save MCP response as artifact: {e}",
        )
314  
315  
async def _save_content_item_as_artifact(
    content_item,
    tool: BaseTool,
    tool_context: ToolContext,
    host_component: "SamAgentComponent",
) -> Dict[str, Any]:
    """
    Persist one processed content item through the host's artifact service.

    Args:
        content_item: Processed item; its ``filename``, ``content_bytes``,
            ``mime_type``, ``metadata`` and ``content_type`` attributes are read.
        tool: The tool that produced the item (consulted for per-tool schema settings).
        tool_context: The ADK ToolContext supplying user/session information.
        host_component: The agent component providing the artifact service and config.

    Returns:
        The dictionary returned by ``save_artifact_with_metadata`` on success.
        If anything raises, a dictionary with ``status`` "error", the item's
        ``data_filename`` and a ``message`` describing the failure.
    """

    log_prefix = f"[IntelligentMCPCallback:SaveContentItem:{content_item.filename}]"

    try:
        service = host_component.artifact_service
        if not service:
            raise ValueError("ArtifactService is not available on host_component.")

        agent_name = host_component.agent_name
        invocation_ctx = tool_context._invocation_context
        owner_id = invocation_ctx.user_id
        original_session = get_original_session_id(invocation_ctx)

        # Schema limits: a per-tool override wins over the agent-level setting.
        schema_limits = {
            "schema_max_keys": _get_schema_config_from_tool_or_agent(
                tool, host_component, "schema_max_keys", DEFAULT_SCHEMA_MAX_KEYS
            ),
            "schema_inference_depth": _get_schema_config_from_tool_or_agent(
                tool,
                host_component,
                "schema_inference_depth",
                DEFAULT_SCHEMA_INFERENCE_DEPTH,
            ),
        }
        saved_at = datetime.now(timezone.utc)

        log.debug(
            "%s Saving content item: type=%s, mime_type=%s, size=%d bytes",
            log_prefix,
            content_item.content_type,
            content_item.mime_type,
            len(content_item.content_bytes),
        )

        outcome = await save_artifact_with_metadata(
            artifact_service=service,
            app_name=agent_name,
            user_id=owner_id,
            session_id=original_session,
            filename=content_item.filename,
            content_bytes=content_item.content_bytes,
            mime_type=content_item.mime_type,
            metadata_dict=content_item.metadata,
            timestamp=saved_at,
            tool_context=tool_context,
            **schema_limits,
        )

        log.info(
            "%s Content item saved as artifact '%s' (version %s). Status: %s",
            log_prefix,
            outcome.get("data_filename", content_item.filename),
            outcome.get("data_version", "N/A"),
            outcome.get("status"),
        )
        return outcome

    except Exception as e:
        # Failures are reported to the caller as data instead of being raised.
        log.exception("%s Error saving content item as artifact: %s", log_prefix, e)
        failure: Dict[str, Any] = {
            "status": "error",
            "data_filename": content_item.filename,
            "message": f"Failed to save content item as artifact: {e}",
        }
        return failure
383  
384  
async def _save_raw_mcp_response_fallback(
    tool: BaseTool,
    tool_context: ToolContext,
    host_component: "SamAgentComponent",
    mcp_response_dict: Dict[str, Any],
    original_tool_args: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Fallback function to save the raw MCP response as a JSON artifact.
    This is the original behavior, used when intelligent processing is disabled or fails.

    The artifact is named ``<task-suffix>_<tool name>_raw_<random>.json``, where
    the task suffix is the last six characters of the logical task id found in
    the tool context state ("unknownTask" when none is available).

    Args:
        tool: The MCP tool that produced the response.
        tool_context: The ADK ToolContext supplying state and user/session info.
        host_component: The agent component providing the artifact service and config.
        mcp_response_dict: The raw MCP tool response to serialize.
        original_tool_args: The tool arguments, recorded in the artifact metadata.

    Returns:
        The dictionary returned by ``save_artifact_with_metadata`` on success.
        If anything raises, a dictionary with ``status`` "error", a
        ``data_filename`` (or "unknown_filename" when the failure happened
        before a name was built) and a ``message`` describing the failure.
    """
    log_identifier = f"[IntelligentMCPCallback:{tool.name}:RawFallback]"
    log.debug("%s Saving raw MCP response as JSON artifact...", log_identifier)

    # Initialised up front so the error path can report it without having to
    # introspect locals().
    filename = "unknown_filename"

    try:
        # Tolerate a missing or None a2a_context / logical_task_id rather than
        # failing the last-resort save with an opaque attribute/type error.
        a2a_context = tool_context.state.get("a2a_context") or {}
        logical_task_id = str(a2a_context.get("logical_task_id") or "unknownTask")
        task_id_suffix = logical_task_id[-6:]
        random_suffix = uuid.uuid4().hex[:6]
        filename = f"{task_id_suffix}_{tool.name}_raw_{random_suffix}.json"

        # default=str is deliberate: this is the last-resort dump, so values
        # that are not JSON-native are stringified instead of aborting the save.
        content_bytes = json.dumps(mcp_response_dict, indent=2, default=str).encode(
            "utf-8"
        )
        mime_type = "application/json"
        artifact_timestamp = datetime.now(timezone.utc)

        metadata_for_saving = {
            "description": f"Raw JSON response from MCP tool {tool.name}",
            "source_tool_name": tool.name,
            "source_tool_args": original_tool_args,
            "processing_type": "raw_fallback",
        }

        artifact_service = host_component.artifact_service
        # Same availability check as the content-item saver, so a missing
        # service yields a clear error message.
        if not artifact_service:
            raise ValueError("ArtifactService is not available on host_component.")

        app_name = host_component.agent_name
        user_id = tool_context._invocation_context.user_id
        session_id = get_original_session_id(tool_context._invocation_context)
        # Get schema config from tool_config (per-tool) or agent config (fallback)
        schema_max_keys = _get_schema_config_from_tool_or_agent(
            tool, host_component, "schema_max_keys", DEFAULT_SCHEMA_MAX_KEYS
        )
        schema_inference_depth = _get_schema_config_from_tool_or_agent(
            tool, host_component, "schema_inference_depth", DEFAULT_SCHEMA_INFERENCE_DEPTH
        )

        save_result = await save_artifact_with_metadata(
            artifact_service=artifact_service,
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            content_bytes=content_bytes,
            mime_type=mime_type,
            metadata_dict=metadata_for_saving,
            timestamp=artifact_timestamp,
            schema_max_keys=schema_max_keys,
            schema_inference_depth=schema_inference_depth,
            tool_context=tool_context,
        )

        log.info(
            "%s Raw MCP response saved as artifact '%s' (version %s). Status: %s",
            log_identifier,
            save_result.get("data_filename", filename),
            save_result.get("data_version", "N/A"),
            save_result.get("status"),
        )

        return save_result

    except Exception as e:
        # Failures are reported to the caller as data instead of being raised.
        log.exception(
            "%s Error saving raw MCP response as artifact: %s", log_identifier, e
        )
        return {
            "status": "error",
            "data_filename": filename,
            "message": f"Failed to save raw MCP response as artifact: {e}",
        }