validator.py
  1  """
  2  A2A Message Validator for integration tests.
  3  Patches message publishing methods to intercept and validate A2A messages.
  4  """
  5  
  6  import functools
  7  import json
  8  import importlib.resources
  9  from typing import Any, Dict, List
 10  from unittest.mock import patch
 11  
 12  import pytest
 13  from jsonschema import Draft7Validator, RefResolver, ValidationError
 14  
 15  
 16  
 17  METHOD_TO_SCHEMA_MAP = {
 18      "message/send": "SendMessageRequest",
 19      "message/stream": "SendStreamingMessageRequest",
 20      "tasks/get": "GetTaskRequest",
 21      "tasks/cancel": "CancelTaskRequest",
 22      "tasks/pushNotificationConfig/set": "SetTaskPushNotificationConfigRequest",
 23      "tasks/pushNotificationConfig/get": "GetTaskPushNotificationConfigRequest",
 24      "tasks/pushNotificationConfig/list": "ListTaskPushNotificationConfigRequest",
 25      "tasks/pushNotificationConfig/delete": "DeleteTaskPushNotificationConfigRequest",
 26      "tasks/resubscribe": "TaskResubscriptionRequest",
 27      "agent/getAuthenticatedExtendedCard": "GetAuthenticatedExtendedCardRequest",
 28  }
 29  
 30  
 31  class A2AMessageValidator:
 32      """
 33      Intercepts and validates A2A messages published by SAM components against the
 34      official a2a.json schema.
 35      """
 36  
 37      def __init__(self):
 38          self._patched_targets: List[Dict[str, Any]] = []
 39          self.active = False
 40          self.schema = self._load_schema()
 41          self.validator = self._create_validator(self.schema)
 42  
 43      def _load_schema(self) -> Dict[str, Any]:
 44          """Loads the A2A JSON schema from the installed package."""
 45          try:
 46              # Use importlib.resources to find the schema file within the package.
 47              # This works whether the package is installed or in editable mode.
 48              with importlib.resources.path(
 49                  "solace_agent_mesh.common.a2a_spec", "a2a.json"
 50              ) as schema_path:
 51                  with open(schema_path, "r", encoding="utf-8") as f:
 52                      return json.load(f)
 53          except (ModuleNotFoundError, FileNotFoundError):
 54              pytest.fail(
 55                  "A2A Validator: Schema file 'a2a.json' not found in package "
 56                  "'solace_agent_mesh.common.a2a_spec'. "
 57                  "Ensure the package is installed correctly or run 'scripts/sync_a2a_schema.py'."
 58              )
 59          except json.JSONDecodeError as e:
 60              pytest.fail(f"A2A Validator: Failed to parse schema file: {e}")
 61  
 62      def _create_validator(self, schema: Dict[str, Any]) -> Draft7Validator:
 63          """Creates a jsonschema validator with a resolver for local $refs."""
 64          resolver = RefResolver.from_schema(schema)
 65          return Draft7Validator(schema, resolver=resolver)
 66  
 67      def activate(self, components_to_patch: List[Any]):
 68          """
 69          Activates the validator by patching message publishing methods on components.
 70  
 71          Args:
 72              components_to_patch: A list of component instances.
 73                                   It will patch 'publish_a2a_message' on TestGatewayComponent instances
 74                                   and '_publish_a2a_message' on SamAgentComponent instances.
 75          """
 76          if self.active:
 77              self.deactivate()
 78          from solace_agent_mesh.agent.sac.component import SamAgentComponent
 79          from sam_test_infrastructure.gateway_interface.component import (
 80              TestGatewayComponent,
 81          )
 82          from solace_agent_mesh.agent.proxies.base.component import BaseProxyComponent
 83  
 84          for component_instance in components_to_patch:
 85              method_name_to_patch = None
 86              is_sam_agent_component = isinstance(component_instance, SamAgentComponent)
 87              is_test_gateway_component = isinstance(
 88                  component_instance, TestGatewayComponent
 89              )
 90              is_base_proxy_component = isinstance(component_instance, BaseProxyComponent)
 91  
 92              if is_sam_agent_component or is_base_proxy_component:
 93                  method_name_to_patch = "_publish_a2a_message"
 94              elif is_test_gateway_component:
 95                  method_name_to_patch = "publish_a2a_message"
 96              else:
 97                  print(
 98                      f"A2AMessageValidator: Warning - Component {type(component_instance)} is not a recognized type for patching."
 99                  )
100                  continue
101  
102              if not hasattr(component_instance, method_name_to_patch):
103                  print(
104                      f"A2AMessageValidator: Warning - Component {type(component_instance)} has no method {method_name_to_patch}"
105                  )
106                  continue
107  
108              original_method = getattr(component_instance, method_name_to_patch)
109  
110              def side_effect_with_validation(
111                  original_method_ref,
112                  component_instance_at_patch_time,
113                  current_method_name,
114                  *args,
115                  **kwargs,
116              ):
117                  return_value = original_method_ref(*args, **kwargs)
118  
119                  payload_to_validate = None
120                  topic_to_validate = None
121                  source_info = f"Patched {component_instance_at_patch_time.__class__.__name__}.{current_method_name}"
122  
123                  if current_method_name == "_publish_a2a_message":
124                      payload_to_validate = kwargs.get("payload")
125                      topic_to_validate = kwargs.get("topic")
126                      if payload_to_validate is None or topic_to_validate is None:
127                          if len(args) >= 2:
128                              payload_to_validate = args[0]
129                              topic_to_validate = args[1]
130                          else:
131                              pytest.fail(
132                                  f"A2A Validator: Incorrect args/kwargs for {source_info}. Expected payload, topic. Got args: {args}, kwargs: {kwargs}"
133                              )
134                  elif current_method_name == "publish_a2a_message":
135                      topic_to_validate = kwargs.get("topic")
136                      payload_to_validate = kwargs.get("payload")
137                      if payload_to_validate is None or topic_to_validate is None:
138                          if len(args) >= 2:
139                              topic_to_validate = args[0]
140                              payload_to_validate = args[1]
141                          else:
142                              pytest.fail(
143                                  f"A2A Validator: Incorrect args/kwargs for {source_info}. Expected topic, payload. Got args: {args}, kwargs: {kwargs}"
144                              )
145  
146                  if payload_to_validate is not None and topic_to_validate is not None:
147                      self.validate_message(
148                          payload_to_validate, topic_to_validate, source_info
149                      )
150                  else:
151                      print(
152                          f"A2AMessageValidator: Warning - Could not extract payload/topic from {source_info} call. Args: {args}, Kwargs: {kwargs}"
153                      )
154  
155                  return return_value
156  
157              try:
158                  patcher = patch.object(
159                      component_instance, method_name_to_patch, autospec=True
160                  )
161                  mock_method = patcher.start()
162                  bound_side_effect = functools.partial(
163                      side_effect_with_validation,
164                      original_method,
165                      component_instance,
166                      method_name_to_patch,
167                  )
168                  mock_method.side_effect = bound_side_effect
169  
170                  self._patched_targets.append(
171                      {
172                          "patcher": patcher,
173                          "component": component_instance,
174                          "method_name": method_name_to_patch,
175                      }
176                  )
177              except Exception as e:
178                  print(
179                      f"A2AMessageValidator: Failed to patch {method_name_to_patch} on {component_instance}: {e}"
180                  )
181                  self.deactivate()
182                  raise
183  
184          if self._patched_targets:
185              self.active = True
186              print(
187                  f"A2AMessageValidator: Activated. Monitoring {len(self._patched_targets)} methods."
188              )
189  
190      def deactivate(self):
191          """Deactivates the validator by stopping all active patches."""
192          for patch_info in self._patched_targets:
193              try:
194                  patch_info["patcher"].stop()
195              except RuntimeError:
196                  pass
197          self._patched_targets = []
198          self.active = False
199          print("A2AMessageValidator: Deactivated.")
200  
201      def validate_message(
202          self, payload: Dict, topic: str, source_info: str = "Unknown source"
203      ):
204          """
205          Validates a single A2A message payload against the official a2a.json schema.
206          Fails the test immediately using pytest.fail() if validation errors occur.
207          """
208          if "/discovery/agentcards" in topic:
209              return
210  
211          schema_to_use = None
212          is_request = "method" in payload
213  
214          try:
215              if is_request:
216                  method = payload.get("method")
217                  schema_name = METHOD_TO_SCHEMA_MAP.get(method)
218                  if schema_name and schema_name in self.schema["definitions"]:
219                      schema_to_use = self.schema["definitions"][schema_name]
220                  else:
221                      # Fallback to generic request if specific one not found
222                      schema_to_use = self.schema["definitions"]["JSONRPCRequest"]
223              else:
224                  # For responses, try to find a specific schema based on the result 'kind'.
225                  schema_to_use = self.schema["definitions"]["JSONRPCResponse"]  # Default
226                  result = payload.get("result")
227                  if isinstance(result, dict):
228                      kind = result.get("kind")
229                      if kind == "task":
230                          schema_to_use = self.schema["definitions"][
231                              "GetTaskSuccessResponse"
232                          ]
233                      elif kind == "message":
234                          schema_to_use = self.schema["definitions"][
235                              "SendMessageSuccessResponse"
236                          ]
237                      elif kind in ["status-update", "artifact-update"]:
238                          schema_to_use = self.schema["definitions"][
239                              "SendStreamingMessageSuccessResponse"
240                          ]
241  
242              self.validator.check_schema(schema_to_use)
243              self.validator.validate(payload, schema_to_use)
244  
245              # The JSON-RPC spec states that 'result' and 'error' MUST NOT coexist.
246              # The generated schema might use 'anyOf' which doesn't enforce this.
247              # We add an explicit check here to ensure compliance.
248              if not is_request and "result" in payload and "error" in payload:
249                  raise ValidationError(
250                      "'result' and 'error' are mutually exclusive and cannot be present in the same response.",
251                      validator="dependencies",
252                      validator_value={"result": ["error"], "error": ["result"]},
253                      instance=payload,
254                      schema=schema_to_use,
255                  )
256  
257          except ValidationError as e:
258              pytest.fail(
259                  f"A2A Schema Validation Error from {source_info} on topic '{topic}':\n"
260                  f"Message: {e.message}\n"
261                  f"Path: {list(e.path)}\n"
262                  f"Validator: {e.validator} = {e.validator_value}\n"
263                  f"Payload: {json.dumps(payload, indent=2)}"
264              )
265          except Exception as e:
266              pytest.fail(
267                  f"A2A Validation Error (Structure) from {source_info} on topic '{topic}': {e}\n"
268                  f"Payload: {json.dumps(payload, indent=2)}"
269              )
270  
271          print(
272              f"A2AMessageValidator: Successfully validated message from {source_info} on topic '{topic}' (ID: {payload.get('id')})"
273          )