/ scripts / migrate_declarative_tests.py
migrate_declarative_tests.py
  1  """
  2  This script automates the migration of declarative test YAML files to conform
  3  to the new A2A specification format. It identifies files in the old format
  4  and transforms their `expected_gateway_output` sections in-place, while
  5  preserving comments and formatting.
  6  """
  7  
  8  import argparse
  9  import sys
 10  from pathlib import Path
 11  from ruamel.yaml import YAML
 12  
 13  
 14  def is_migrated(data: dict) -> bool:
 15      """Checks if the YAML data appears to be in the new format."""
 16      if "expected_gateway_output" not in data:
 17          return True  # No section to migrate, so it's "done"
 18  
 19      events = data.get("expected_gateway_output")
 20      if not events:  # Handles empty list case: expected_gateway_output: []
 21          return True  # Nothing to migrate, so consider it done.
 22  
 23      for event in events:
 24          if isinstance(event, dict) and event.get("kind") == "task":
 25              return True
 26      return False
 27  
 28  
 29  def transform_data(data: dict) -> tuple[dict, bool]:
 30      """Transforms the YAML data from the old format to the new format."""
 31      if "expected_gateway_output" not in data:
 32          return data, False
 33  
 34      original_events = data.get("expected_gateway_output")
 35      if not original_events:
 36          return data, False
 37  
 38      transformed_events = []
 39      was_transformed = False
 40  
 41      for event in original_events:
 42          # Transform if it's a dictionary and doesn't have the new format's 'kind' key.
 43          if isinstance(event, dict) and event.get("kind") != "task":
 44              was_transformed = True
 45  
 46              # 1. Create the new root structure
 47              new_event = {
 48                  "type": event.get("type"),
 49                  "kind": "task",
 50                  "id": "*",
 51              }
 52  
 53              # 2. Derive contextId
 54              context_id = (
 55                  data.get("gateway_input", {})
 56                  .get("external_context", {})
 57                  .get("a2a_session_id")
 58              )
 59              if not context_id:
 60                  # Fallback for older tests that might not have this structure
 61                  test_case_id = data.get("test_case_id", "unknown_test")
 62                  context_id = f"session_{test_case_id}"  # Best guess
 63                  print(
 64                      f"    [WARNING] Could not find a2a_session_id for '{test_case_id}'. Falling back to generated contextId: {context_id}",
 65                      file=sys.stderr,
 66                  )
 67              new_event["contextId"] = context_id
 68  
 69              # 3. Build the status object, inferring task_state if missing.
 70              task_state = event.get("task_state")
 71              if task_state is None:
 72                  if event.get("type") == "final_response":
 73                      task_state = "completed"
 74                  else:
 75                      task_state = "unknown"
 76                      print(
 77                          f"    [WARNING] 'task_state' missing in non-final_response event. Defaulting to 'unknown'. Event: {event}",
 78                          file=sys.stderr,
 79                      )
 80              status = {"state": task_state}
 81  
 82              # 4. Build the status.message object
 83              message = {
 84                  "kind": "message",
 85                  "messageId": "*",
 86                  "role": "agent",
 87              }
 88  
 89              # 5. Relocate content_parts
 90              content_parts = event.get("content_parts", [])
 91              if content_parts:
 92                  # Handle text_contains string-to-list conversion
 93                  for part in content_parts:
 94                      if "text_contains" in part and isinstance(
 95                          part["text_contains"], str
 96                      ):
 97                          part["text_contains"] = [part["text_contains"]]
 98                  message["parts"] = content_parts
 99  
100              status["message"] = message
101              new_event["status"] = status
102  
103              # 6. Handle other assertions
104              for key, value in event.items():
105                  if key not in ["type", "task_state", "content_parts"]:
106                      new_event[key] = value
107  
108              transformed_events.append(new_event)
109          else:
110              # If it's already migrated or not a dict, keep it as is
111              transformed_events.append(event)
112  
113      if was_transformed:
114          data["expected_gateway_output"] = transformed_events
115  
116      return data, was_transformed
117  
118  
def main():
    """Scan a directory tree for ``*.yaml`` files and migrate old-format ones.

    Command-line arguments:
        path: Root directory to scan recursively.
        --dry-run: Only report which files would be migrated; write nothing.
        -v / --verbose: Print a per-file status line.

    Each file is loaded, classified with ``is_migrated`` and, unless this is a
    dry run, rewritten in place with the result of ``transform_data``. A file
    that raises during processing is reported on stderr and the scan
    continues. A summary is printed at the end.

    Exits with status 1 when ``path`` is not a directory.
    """
    # Local import keeps this function self-contained; it is only needed to
    # serialise into memory before touching the target file.
    import io

    parser = argparse.ArgumentParser(
        description="Migrate declarative test YAML files to the new A2A spec format.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "path",
        type=str,
        help="The root directory to scan for YAML files (e.g., 'tests/integration/scenarios_declarative/test_data').",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Perform a dry run without modifying any files. Reports which files would be changed.",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose output.",
    )
    args = parser.parse_args()

    root_path = Path(args.path)
    if not root_path.is_dir():
        print(f"Error: Path '{root_path}' is not a valid directory.", file=sys.stderr)
        sys.exit(1)

    yaml = YAML()
    yaml.preserve_quotes = True
    yaml.indent(mapping=2, sequence=4, offset=2)
    # Set a very large width to prevent line wrapping, which can break long JSON strings.
    yaml.width = 4096

    files_to_migrate = []  # detected as old format (would be / were attempted)
    files_skipped = []  # nothing to migrate
    files_modified = []  # actually rewritten on disk
    files_failed = []  # raised while loading, transforming or writing

    print(f"Scanning for YAML files in '{root_path}'...")
    yaml_files = sorted(root_path.glob("**/*.yaml"))

    for file_path in yaml_files:
        if args.verbose:
            print(f"\n--- Processing: {file_path.relative_to(root_path)} ---")
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                content = yaml.load(f)

            if is_migrated(content):
                if args.verbose:
                    print("Status: Already migrated. Skipping.")
                files_skipped.append(file_path)
                continue

            if args.verbose:
                print("Status: Needs migration.")
            files_to_migrate.append(file_path)

            if args.dry_run:
                continue

            transformed_content, was_transformed = transform_data(content)
            if not was_transformed:
                if args.verbose:
                    print("Action: No transformation was applied despite detection.")
                continue

            # Serialise into memory first: opening the target with "w"
            # truncates it, so a dump error must not happen after that point.
            buffer = io.StringIO()
            yaml.dump(transformed_content, buffer)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(buffer.getvalue())
            files_modified.append(file_path)
            if args.verbose:
                print("Action: Transformed and saved.")

        except Exception as e:
            # Best effort: report the file and keep going with the rest.
            files_failed.append(file_path)
            print(f"Error processing file {file_path}: {e}", file=sys.stderr)

    print("\n--- Migration Summary ---")
    print(f"Total YAML files found: {len(yaml_files)}")
    print(f"Files already migrated (skipped): {len(files_skipped)}")
    print(f"Files to be migrated: {len(files_to_migrate)}")
    if files_failed:
        print(f"Files with errors: {len(files_failed)}")

    if files_to_migrate:
        print("\nFiles to be migrated:")
        for f in files_to_migrate:
            print(f"  - {f.relative_to(root_path.parent)}")

    if args.dry_run:
        print("\n** DRY RUN COMPLETE. No files were modified. **")
    else:
        if files_to_migrate:
            # Report what was really written, not merely what was detected.
            print(
                f"\n** MIGRATION COMPLETE. {len(files_modified)} files were modified. **"
            )
        else:
            print("\n** No files needed migration. **")
211  
212  
# Run the migration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()