   1  #!/usr/bin/env python3
   2  """
   3  RL Training Tools Module
   4  
   5  This module provides tools for running RL training through Tinker-Atropos.
   6  Directly manages training processes without requiring a separate API server.
   7  
   8  Features:
   9  - Environment discovery (AST-based scanning for BaseEnv subclasses)
  10  - Configuration management with locked infrastructure settings
  11  - Training run lifecycle via subprocess management
  12  - WandB metrics monitoring
  13  
  14  Required environment variables:
  15  - TINKER_API_KEY: API key for Tinker service
  16  - WANDB_API_KEY: API key for Weights & Biases metrics
  17  
  18  Usage:
  19      from tools.rl_training_tool import (
  20          rl_list_environments,
  21          rl_select_environment,
  22          rl_get_current_config,
  23          rl_edit_config,
  24          rl_start_training,
  25          rl_check_status,
  26          rl_stop_training,
        rl_get_results,
        rl_list_runs,
        rl_test_inference,
    )
  29  """
  30  
  31  import ast
import asyncio
import copy
  33  import importlib.util
  34  import json
  35  import os
  36  import subprocess
  37  import sys
  38  import time
  39  import uuid
  40  import logging
  41  from datetime import datetime
  42  import yaml
  43  from dataclasses import dataclass
  44  from pathlib import Path
  45  from typing import Any, Dict, List, Optional
  46  
  47  from hermes_constants import get_hermes_home
  48  
  49  logger = logging.getLogger(__name__)
  50  
  51  # ============================================================================
  52  # Path Configuration
  53  # ============================================================================
  54  
  55  # Path to tinker-atropos submodule (relative to hermes-agent root)
  56  HERMES_ROOT = Path(__file__).parent.parent
  57  TINKER_ATROPOS_ROOT = HERMES_ROOT / "tinker-atropos"
  58  ENVIRONMENTS_DIR = TINKER_ATROPOS_ROOT / "tinker_atropos" / "environments"
  59  CONFIGS_DIR = TINKER_ATROPOS_ROOT / "configs"
  60  LOGS_DIR = get_hermes_home() / "logs" / "rl_training"
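
# Illustrative resolution of these paths, assuming this file lives at
# <hermes-agent>/tools/rl_training_tool.py and get_hermes_home() resolves to ~/.hermes:
#   HERMES_ROOT          -> <hermes-agent>
#   TINKER_ATROPOS_ROOT  -> <hermes-agent>/tinker-atropos
#   ENVIRONMENTS_DIR     -> <hermes-agent>/tinker-atropos/tinker_atropos/environments
#   CONFIGS_DIR          -> <hermes-agent>/tinker-atropos/configs
#   LOGS_DIR             -> ~/.hermes/logs/rl_training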
  61  
  62  def _ensure_logs_dir():
  63      """Lazily create logs directory on first use (avoid side effects at import time)."""
  64      if TINKER_ATROPOS_ROOT.exists():
        LOGS_DIR.mkdir(parents=True, exist_ok=True)
  66  
  67  # ============================================================================
  68  # Locked Configuration (Infrastructure Settings)
  69  # ============================================================================
  70  
  71  # These fields cannot be changed by the model - they're tuned for our infrastructure
  72  LOCKED_FIELDS = {
  73      "env": {
  74          "tokenizer_name": "Qwen/Qwen3-8B",
  75          "rollout_server_url": "http://localhost:8000",
  76          "use_wandb": True,
  77          "max_token_length": 8192,
  78          "max_num_workers": 2048,
  79          "worker_timeout": 3600,
  80          "total_steps": 2500,
  81          "steps_per_eval": 25,
  82          "max_batches_offpolicy": 3,
  83          "inference_weight": 1.0,
  84          "eval_limit_ratio": 0.1,
  85      },
  86      "openai": [
  87          {
  88              "model_name": "Qwen/Qwen3-8B",
  89              "base_url": "http://localhost:8001/v1",
  90              "api_key": "x",
  91              "weight": 1.0,
  92              "num_requests_for_eval": 256,
  93              "timeout": 3600,
  94              "server_type": "sglang",  # Tinker uses sglang for actual training
  95          }
  96      ],
  97      "tinker": {
  98          "lora_rank": 32,
  99          "learning_rate": 0.00004,
 100          "max_token_trainer_length": 9000,
 101          "checkpoint_dir": "./temp/",
 102          "save_checkpoint_interval": 25,
 103      },
 104      "slurm": False,
 105      "testing": False,
 106  }
 107  
 108  LOCKED_FIELD_NAMES = set(LOCKED_FIELDS.get("env", {}).keys())
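
# For reference, rl_start_training() deep-copies LOCKED_FIELDS, merges in the
# configurable overrides, and writes the result to configs/run_<id>.yaml via
# yaml.dump(). A trimmed, illustrative fragment of that file (keys are sorted
# alphabetically by yaml.dump; values come from the dict above):
#
#   env:
#     max_token_length: 8192
#     rollout_server_url: http://localhost:8000
#     tokenizer_name: Qwen/Qwen3-8B
#     total_steps: 2500
#     wandb_name: <env>-<timestamp>
#     ...
#   openai:
#   - base_url: http://localhost:8001/v1
#     model_name: Qwen/Qwen3-8B
#     server_type: sglang
#     ...
#   slurm: false
#   testing: false
#   tinker:
#     lora_rank: 32
#     wandb_project: atropos-tinker
#     ...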
 109  
 110  
 111  # ============================================================================
 112  # State Management
 113  # ============================================================================
 114  
 115  @dataclass
 116  class EnvironmentInfo:
 117      """Information about a discovered environment."""
 118      name: str
 119      class_name: str
 120      file_path: str
 121      description: str = ""
 122      config_class: str = "BaseEnvConfig"
 123  
 124  
 125  @dataclass
 126  class RunState:
 127      """State for a training run."""
 128      run_id: str
 129      environment: str
 130      config: Dict[str, Any]
 131      status: str = "pending"  # pending, starting, running, stopping, stopped, completed, failed
 132      error_message: str = ""
 133      wandb_project: str = ""
 134      wandb_run_name: str = ""
 135      start_time: float = 0.0
    # Process handles
    api_process: Optional[subprocess.Popen] = None
    trainer_process: Optional[subprocess.Popen] = None
    env_process: Optional[subprocess.Popen] = None
    # Log file handles for subprocess stdout; opened in _spawn_training_run
    # and closed in _stop_training_run
    api_log_file: Optional[Any] = None
    trainer_log_file: Optional[Any] = None
    env_log_file: Optional[Any] = None
 140  
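# Status lifecycle, as driven by the functions below (illustrative):
#   "starting" (rl_start_training) -> "running" (_spawn_training_run)
#       -> "completed" / "failed" (_monitor_training_run) or "stopped" (rl_stop_training)
#   Any step can short-circuit to "failed" with error_message set.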
 141  
 142  # Global state
 143  _environments: List[EnvironmentInfo] = []
 144  _current_env: Optional[str] = None
 145  _current_config: Dict[str, Any] = {}
 146  _env_config_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
 147  _active_runs: Dict[str, RunState] = {}
 148  _last_status_check: Dict[str, float] = {}
 149  
 150  # Rate limiting for status checks (30 minutes)
 151  MIN_STATUS_CHECK_INTERVAL = 30 * 60
 152  
 153  
 154  # ============================================================================
 155  # Environment Discovery
 156  # ============================================================================
 157  
 158  def _scan_environments() -> List[EnvironmentInfo]:
 159      """
 160      Scan the environments directory for BaseEnv subclasses using AST.
 161      """
 162      environments = []
 163      
 164      if not ENVIRONMENTS_DIR.exists():
 165          return environments
 166      
 167      for py_file in ENVIRONMENTS_DIR.glob("*.py"):
 168          if py_file.name.startswith("_"):
 169              continue
 170          
 171          try:
 172              with open(py_file, "r") as f:
 173                  tree = ast.parse(f.read())
 174              
 175              for node in ast.walk(tree):
 176                  if isinstance(node, ast.ClassDef):
 177                      # Check if class has BaseEnv as base
 178                      for base in node.bases:
 179                          base_name = ""
 180                          if isinstance(base, ast.Name):
 181                              base_name = base.id
 182                          elif isinstance(base, ast.Attribute):
 183                              base_name = base.attr
 184                          
 185                          if base_name == "BaseEnv":
 186                              # Extract name from class attribute if present
 187                              env_name = py_file.stem
 188                              description = ""
 189                              config_class = "BaseEnvConfig"
 190                              
 191                              for item in node.body:
 192                                  if isinstance(item, ast.Assign):
 193                                      for target in item.targets:
 194                                          if isinstance(target, ast.Name):
 195                                              if target.id == "name" and isinstance(item.value, ast.Constant):
 196                                                  env_name = item.value.value
 197                                              elif target.id == "env_config_cls" and isinstance(item.value, ast.Name):
 198                                                  config_class = item.value.id
 199                                  
 200                                  # Get docstring
 201                                  if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant):
 202                                      if isinstance(item.value.value, str) and not description:
 203                                          description = item.value.value.split("\n")[0].strip()
 204                              
 205                              environments.append(EnvironmentInfo(
 206                                  name=env_name,
 207                                  class_name=node.name,
 208                                  file_path=str(py_file),
 209                                  description=description or f"Environment from {py_file.name}",
 210                                  config_class=config_class,
 211                              ))
 212                              break
 213          except Exception as e:
 214              logger.warning("Could not parse %s: %s", py_file, e)
 215      
 216      return environments
 217  
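# Illustrative only: for a hypothetical environments/gsm8k_env.py that defines
# `class GSM8kEnv(BaseEnv)` with a class attribute `name = "gsm8k"`,
# _scan_environments() would yield roughly:
#
#   EnvironmentInfo(
#       name="gsm8k",
#       class_name="GSM8kEnv",
#       file_path=".../tinker_atropos/environments/gsm8k_env.py",
#       description="<first line of the class docstring, if any>",
#       config_class="BaseEnvConfig",
#   )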
 218  
 219  def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
 220      """
 221      Dynamically import an environment and extract its config fields.
 222      
 223      Uses config_init() to get the actual config class, with fallback to
 224      directly importing BaseEnvConfig if config_init fails.
 225      """
 226      try:
 227          # Load the environment module
 228          spec = importlib.util.spec_from_file_location("env_module", env_file_path)
 229          module = importlib.util.module_from_spec(spec)
 230          sys.modules["env_module"] = module
 231          spec.loader.exec_module(module)
 232          
 233          # Find the BaseEnv subclass
 234          env_class = None
 235          for name, obj in vars(module).items():
 236              if isinstance(obj, type) and name != "BaseEnv":
 237                  if hasattr(obj, "config_init") and callable(getattr(obj, "config_init")):
 238                      env_class = obj
 239                      break
 240          
 241          if not env_class:
 242              return {}
 243          
 244          # Try calling config_init to get the actual config class
 245          config_class = None
 246          try:
 247              env_config, server_configs = env_class.config_init()
 248              config_class = type(env_config)
 249          except Exception as config_error:
 250              # Fallback: try to import BaseEnvConfig directly from atroposlib
 251              logger.info("config_init failed (%s), using BaseEnvConfig defaults", config_error)
 252              try:
 253                  from atroposlib.envs.base import BaseEnvConfig
 254                  config_class = BaseEnvConfig
 255              except ImportError:
 256                  return {}
 257          
 258          if not config_class:
 259              return {}
 260          
 261          # Helper to make values JSON-serializable (handle enums, etc.)
 262          def make_serializable(val):
 263              if val is None:
 264                  return None
 265              if hasattr(val, 'value'):  # Enum
 266                  return val.value
 267              if hasattr(val, 'name') and hasattr(val, '__class__') and 'Enum' in str(type(val)):
 268                  return val.name
 269              return val
 270          
 271          # Extract fields from the Pydantic model
 272          fields = {}
 273          for field_name, field_info in config_class.model_fields.items():
 274              field_type = field_info.annotation
 275              default = make_serializable(field_info.default)
 276              description = field_info.description or ""
 277              
 278              is_locked = field_name in LOCKED_FIELD_NAMES
 279              
 280              # Convert type to string
 281              type_name = getattr(field_type, "__name__", str(field_type))
 282              if hasattr(field_type, "__origin__"):
 283                  type_name = str(field_type)
 284              
 285              locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default)
 286              current_value = make_serializable(locked_value) if is_locked else default
 287              
 288              fields[field_name] = {
 289                  "type": type_name,
 290                  "default": default,
 291                  "description": description,
 292                  "locked": is_locked,
 293                  "current_value": current_value,
 294              }
 295          
 296          return fields
 297          
 298      except Exception as e:
 299          logger.warning("Could not introspect environment config: %s", e)
 300          return {}
 301  
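# Illustrative only: one entry of the returned dict might look like the sketch
# below. The field name ("group_size"), default, and description depend entirely
# on the environment's Pydantic config class:
#
#   "group_size": {
#       "type": "int",
#       "default": 16,
#       "description": "Completions sampled per item",
#       "locked": False,
#       "current_value": 16,
#   }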
 302  
 303  def _initialize_environments():
 304      """Initialize environment list on first use."""
 305      global _environments
 306      if not _environments:
 307          _environments = _scan_environments()
 308  
 309  
 310  # ============================================================================
 311  # Subprocess Management
 312  # ============================================================================
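
# Process topology for one training run (ports per LOCKED_FIELDS; illustrative):
#
#   run-api                   Atropos trajectory API on http://localhost:8000
#   launch_training.py        Tinker trainer; serves inference on http://localhost:8001/v1
#   <environment>.py serve    registers with the API on :8000 and samples from :8001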
 313  
 314  async def _spawn_training_run(run_state: RunState, config_path: Path):
 315      """
 316      Spawn the three processes needed for training:
 317      1. run-api (Atropos API server)
 318      2. launch_training.py (Tinker trainer + inference server)
 319      3. environment.py serve (the Atropos environment)
 320      """
 321      run_id = run_state.run_id
 322      
 323      _ensure_logs_dir()
 324  
 325      # Log file paths
 326      api_log = LOGS_DIR / f"api_{run_id}.log"
 327      trainer_log = LOGS_DIR / f"trainer_{run_id}.log"
 328      env_log = LOGS_DIR / f"env_{run_id}.log"
 329      
 330      try:
 331          # Step 1: Start the Atropos API server (run-api)
 332          logger.info("[%s] Starting Atropos API server (run-api)...", run_id)
 333          
 334          # File must stay open while the subprocess runs; we store the handle
 335          # on run_state so _stop_training_run() can close it when done.
 336          api_log_file = open(api_log, "w")  # closed by _stop_training_run
 337          run_state.api_log_file = api_log_file
 338          run_state.api_process = subprocess.Popen(
 339              ["run-api"],
 340              stdout=api_log_file,
 341              stderr=subprocess.STDOUT,
 342              cwd=str(TINKER_ATROPOS_ROOT),
 343          )
 344          
 345          # Wait for API to start
 346          await asyncio.sleep(5)
 347          
 348          if run_state.api_process.poll() is not None:
 349              run_state.status = "failed"
 350              run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}"
 351              _stop_training_run(run_state)
 352              return
 353          
 354          logger.info("[%s] Atropos API server started", run_id)
 355          
 356          # Step 2: Start the Tinker trainer
 357          logger.info("[%s] Starting Tinker trainer: launch_training.py --config %s", run_id, config_path)
 358          
 359          trainer_log_file = open(trainer_log, "w")  # closed by _stop_training_run
 360          run_state.trainer_log_file = trainer_log_file
 361          run_state.trainer_process = subprocess.Popen(
 362              [sys.executable, "launch_training.py", "--config", str(config_path)],
 363              stdout=trainer_log_file,
 364              stderr=subprocess.STDOUT,
 365              cwd=str(TINKER_ATROPOS_ROOT),
 366              env={**os.environ, "TINKER_API_KEY": os.getenv("TINKER_API_KEY", "")},
 367          )
 368          
 369          # Wait for trainer to initialize (it starts FastAPI inference server on 8001)
 370          logger.info("[%s] Waiting 30 seconds for trainer to initialize...", run_id)
 371          await asyncio.sleep(30)
 372          
 373          if run_state.trainer_process.poll() is not None:
 374              run_state.status = "failed"
 375              run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}"
 376              _stop_training_run(run_state)
 377              return
 378          
 379          logger.info("[%s] Trainer started, inference server on port 8001", run_id)
 380          
 381          # Step 3: Start the environment
 382          logger.info("[%s] Waiting 90 more seconds before starting environment...", run_id)
 383          await asyncio.sleep(90)
 384          
 385          # Find the environment file
 386          env_info = None
 387          for env in _environments:
 388              if env.name == run_state.environment:
 389                  env_info = env
 390                  break
 391          
 392          if not env_info:
 393              run_state.status = "failed"
 394              run_state.error_message = f"Environment '{run_state.environment}' not found"
 395              _stop_training_run(run_state)
 396              return
 397          
 398          logger.info("[%s] Starting environment: %s serve", run_id, env_info.file_path)
 399          
 400          env_log_file = open(env_log, "w")  # closed by _stop_training_run
 401          run_state.env_log_file = env_log_file
 402          run_state.env_process = subprocess.Popen(
 403              [sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)],
 404              stdout=env_log_file,
 405              stderr=subprocess.STDOUT,
 406              cwd=str(TINKER_ATROPOS_ROOT),
 407          )
 408          
 409          # Wait for environment to connect
 410          await asyncio.sleep(10)
 411          
 412          if run_state.env_process.poll() is not None:
 413              run_state.status = "failed"
 414              run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}"
 415              _stop_training_run(run_state)
 416              return
 417          
 418          run_state.status = "running"
 419          run_state.start_time = time.time()
 420          logger.info("[%s] Training run started successfully!", run_id)
 421          
 422          # Start background monitoring
 423          asyncio.create_task(_monitor_training_run(run_state))
 424          
 425      except Exception as e:
 426          run_state.status = "failed"
 427          run_state.error_message = str(e)
 428          _stop_training_run(run_state)
 429  
 430  
 431  async def _monitor_training_run(run_state: RunState):
 432      """Background task to monitor a training run."""
 433      while run_state.status == "running":
 434          await asyncio.sleep(30)  # Check every 30 seconds
 435          
 436          # Check if any process has died
 437          if run_state.env_process and run_state.env_process.poll() is not None:
 438              exit_code = run_state.env_process.returncode
 439              if exit_code == 0:
 440                  run_state.status = "completed"
 441              else:
 442                  run_state.status = "failed"
 443                  run_state.error_message = f"Environment process exited with code {exit_code}"
 444              _stop_training_run(run_state)
 445              break
 446          
 447          if run_state.trainer_process and run_state.trainer_process.poll() is not None:
 448              exit_code = run_state.trainer_process.returncode
 449              if exit_code == 0:
 450                  run_state.status = "completed"
 451              else:
 452                  run_state.status = "failed"
 453                  run_state.error_message = f"Trainer process exited with code {exit_code}"
 454              _stop_training_run(run_state)
 455              break
 456          
 457          if run_state.api_process and run_state.api_process.poll() is not None:
 458              run_state.status = "failed"
 459              run_state.error_message = "API server exited unexpectedly"
 460              _stop_training_run(run_state)
 461              break
 462  
 463  
 464  def _stop_training_run(run_state: RunState):
 465      """Stop all processes for a training run."""
 466      # Stop in reverse order: env -> trainer -> api
 467      if run_state.env_process and run_state.env_process.poll() is None:
 468          logger.info("[%s] Stopping environment process...", run_state.run_id)
 469          run_state.env_process.terminate()
 470          try:
 471              run_state.env_process.wait(timeout=10)
 472          except subprocess.TimeoutExpired:
 473              run_state.env_process.kill()
 474      
 475      if run_state.trainer_process and run_state.trainer_process.poll() is None:
 476          logger.info("[%s] Stopping trainer process...", run_state.run_id)
 477          run_state.trainer_process.terminate()
 478          try:
 479              run_state.trainer_process.wait(timeout=10)
 480          except subprocess.TimeoutExpired:
 481              run_state.trainer_process.kill()
 482      
 483      if run_state.api_process and run_state.api_process.poll() is None:
 484          logger.info("[%s] Stopping API server...", run_state.run_id)
 485          run_state.api_process.terminate()
 486          try:
 487              run_state.api_process.wait(timeout=10)
 488          except subprocess.TimeoutExpired:
 489              run_state.api_process.kill()
 490      
 491      if run_state.status == "running":
 492          run_state.status = "stopped"
 493  
 494      # Close log file handles that were opened for subprocess stdout.
 495      for attr in ("env_log_file", "trainer_log_file", "api_log_file"):
 496          fh = getattr(run_state, attr, None)
 497          if fh is not None:
 498              try:
 499                  fh.close()
 500              except Exception:
 501                  pass
 502              setattr(run_state, attr, None)
 503  
 504  
 505  # ============================================================================
 506  # Environment Discovery Tools
 507  # ============================================================================
 508  
 509  async def rl_list_environments() -> str:
 510      """
 511      List all available RL environments.
 512      
 513      Scans tinker-atropos/tinker_atropos/environments/ for Python files
 514      containing classes that inherit from BaseEnv.
 515      
 516      Returns information about each environment including:
 517      - name: Environment identifier
 518      - class_name: Python class name
 519      - file_path: Path to the environment file
 520      - description: Brief description if available
 521      
 522      TIP: To create or modify RL environments:
 523      1. Use terminal/file tools to inspect existing environments
 524      2. Study how they load datasets, define verifiers, and structure rewards
 525      3. Inspect HuggingFace datasets to understand data formats
 526      4. Copy an existing environment as a template
 527      
 528      Returns:
 529          JSON string with list of environments
 530      """
 531      _initialize_environments()
 532      
 533      response = {
 534          "environments": [
 535              {
 536                  "name": env.name,
 537                  "class_name": env.class_name,
 538                  "file_path": env.file_path,
 539                  "description": env.description,
 540              }
 541              for env in _environments
 542          ],
 543          "count": len(_environments),
 544          "tips": [
 545              "Use rl_select_environment(name) to select an environment",
 546              "Read the file_path with file tools to understand how each environment works",
 547              "Look for load_dataset(), score_answer(), get_next_item() methods",
 548          ]
 549      }
 550      
 551      return json.dumps(response, indent=2)
 552  
 553  
 554  async def rl_select_environment(name: str) -> str:
 555      """
 556      Select an RL environment for training.
 557      
 558      This loads the environment's configuration fields into memory.
 559      After selecting, use rl_get_current_config() to see all configurable options
 560      and rl_edit_config() to modify specific fields.
 561      
 562      Args:
 563          name: Name of the environment to select (from rl_list_environments)
 564      
 565      Returns:
 566          JSON string with selection result, file path, and configurable field count
 567      
 568      TIP: Read the returned file_path to understand how the environment works.
 569      """
 570      global _current_env, _current_config
 571      
 572      _initialize_environments()
 573      
 574      env_info = None
 575      for env in _environments:
 576          if env.name == name:
 577              env_info = env
 578              break
 579      
 580      if not env_info:
 581          return json.dumps({
 582              "error": f"Environment '{name}' not found",
 583              "available": [e.name for e in _environments],
 584          }, indent=2)
 585      
 586      _current_env = name
 587      
 588      # Dynamically discover config fields
 589      config_fields = _get_env_config_fields(env_info.file_path)
 590      _env_config_cache[name] = config_fields
 591      
 592      # Initialize current config with defaults for non-locked fields
 593      _current_config = {}
 594      for field_name, field_info in config_fields.items():
 595          if not field_info.get("locked", False):
 596              _current_config[field_name] = field_info.get("default")
 597      
 598      # Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps
 599      timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
 600      _current_config["wandb_name"] = f"{name}-{timestamp}"
 601      
    return json.dumps({
        "message": f"Selected environment: {name}",
        "environment": name,
        "file_path": env_info.file_path,
        "configurable_field_count": sum(
            1 for f in config_fields.values() if not f.get("locked", False)
        ),
    }, indent=2)
 607  
 608  
 609  # ============================================================================
 610  # Configuration Tools
 611  # ============================================================================
 612  
 613  async def rl_get_current_config() -> str:
 614      """
 615      Get the current environment configuration.
 616      
 617      Returns all configurable fields for the selected environment.
 618      Each environment may have different configuration options.
 619      
 620      Fields are divided into:
 621      - configurable_fields: Can be changed with rl_edit_config()
 622      - locked_fields: Infrastructure settings that cannot be changed
 623      
 624      Returns:
 625          JSON string with configurable and locked fields
 626      """
 627      if not _current_env:
 628          return json.dumps({
 629              "error": "No environment selected. Use rl_select_environment(name) first.",
 630          }, indent=2)
 631      
 632      config_fields = _env_config_cache.get(_current_env, {})
 633      
 634      configurable = []
 635      locked = []
 636      
 637      for field_name, field_info in config_fields.items():
 638          field_data = {
 639              "name": field_name,
 640              "type": field_info.get("type", "unknown"),
 641              "default": field_info.get("default"),
 642              "description": field_info.get("description", ""),
 643              "current_value": _current_config.get(field_name, field_info.get("default")),
 644          }
 645          
 646          if field_info.get("locked", False):
 647              field_data["locked_value"] = LOCKED_FIELDS.get("env", {}).get(field_name)
 648              locked.append(field_data)
 649          else:
 650              configurable.append(field_data)
 651      
 652      return json.dumps({
 653          "environment": _current_env,
 654          "configurable_fields": configurable,
 655          "locked_fields": locked,
 656          "tip": "Use rl_edit_config(field, value) to change any configurable field.",
 657      }, indent=2)
 658  
 659  
 660  async def rl_edit_config(field: str, value: Any) -> str:
 661      """
 662      Update a configuration field.
 663      
 664      Use rl_get_current_config() first to see available fields for the
 665      selected environment. Each environment has different options.
 666      
 667      Locked fields (infrastructure settings) cannot be changed.
 668      
 669      Args:
 670          field: Name of the field to update (from rl_get_current_config)
 671          value: New value for the field
 672      
 673      Returns:
 674          JSON string with updated config or error message
 675      """
 676      if not _current_env:
 677          return json.dumps({
 678              "error": "No environment selected. Use rl_select_environment(name) first.",
 679          }, indent=2)
 680      
 681      config_fields = _env_config_cache.get(_current_env, {})
 682      
 683      if field not in config_fields:
 684          return json.dumps({
 685              "error": f"Unknown field '{field}'",
 686              "available_fields": list(config_fields.keys()),
 687          }, indent=2)
 688      
 689      field_info = config_fields[field]
 690      if field_info.get("locked", False):
 691          return json.dumps({
 692              "error": f"Field '{field}' is locked and cannot be changed",
 693              "locked_value": LOCKED_FIELDS.get("env", {}).get(field),
 694          }, indent=2)
 695      
 696      _current_config[field] = value
 697      
 698      return json.dumps({
 699          "message": f"Updated {field} = {value}",
 700          "field": field,
 701          "value": value,
 702          "config": _current_config,
 703      }, indent=2)
 704  
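# Illustrative usage of the config tools from an async caller. The environment
# name "gsm8k" and the field "group_size" are hypothetical examples; use
# rl_list_environments() and rl_get_current_config() to see what actually exists:
#
#   await rl_select_environment("gsm8k")
#   print(await rl_get_current_config())
#   await rl_edit_config("group_size", 8)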
 705  
 706  # ============================================================================
 707  # Training Management Tools
 708  # ============================================================================
 709  
 710  async def rl_start_training() -> str:
 711      """
 712      Start a new RL training run with the current environment and config.
 713      
 714      Requires an environment to be selected first using rl_select_environment().
 715      Use rl_edit_config() to adjust configuration before starting.
 716      
 717      This spawns three processes:
 718      1. run-api (Atropos trajectory API)
 719      2. launch_training.py (Tinker trainer + inference server)
 720      3. environment.py serve (the selected environment)
 721      
 722      WARNING: Training runs take hours. Use rl_check_status() to monitor
 723      progress (recommended: check every 30 minutes at most).
 724      
 725      Returns:
 726          JSON string with run_id and initial status
 727      """
 728      if not _current_env:
 729          return json.dumps({
 730              "error": "No environment selected. Use rl_select_environment(name) first.",
 731          }, indent=2)
 732      
    # Check API keys
    if not os.getenv("TINKER_API_KEY"):
        return json.dumps({
            "error": "TINKER_API_KEY not set. Add it to ~/.hermes/.env",
        }, indent=2)
    if not os.getenv("WANDB_API_KEY"):
        return json.dumps({
            "error": "WANDB_API_KEY not set. Add it to ~/.hermes/.env",
        }, indent=2)
 738      
 739      # Find environment file
 740      env_info = None
 741      for env in _environments:
 742          if env.name == _current_env:
 743              env_info = env
 744              break
 745      
 746      if not env_info or not Path(env_info.file_path).exists():
 747          return json.dumps({
 748              "error": f"Environment file not found for '{_current_env}'",
 749          }, indent=2)
 750      
 751      # Generate run ID
 752      run_id = str(uuid.uuid4())[:8]
 753      
 754      # Create config YAML
 755      CONFIGS_DIR.mkdir(exist_ok=True)
 756      config_path = CONFIGS_DIR / f"run_{run_id}.yaml"
 757      
 758      # Start with locked config as base
    run_config = copy.deepcopy(LOCKED_FIELDS)
 761      
 762      if "env" not in run_config:
 763          run_config["env"] = {}
 764      
 765      # Apply configurable fields
 766      for field_name, value in _current_config.items():
 767          if value is not None and value != "":
 768              run_config["env"][field_name] = value
 769      
 770      # Set WandB settings
 771      wandb_project = _current_config.get("wandb_project", "atropos-tinker")
 772      if "tinker" not in run_config:
 773          run_config["tinker"] = {}
 774      run_config["tinker"]["wandb_project"] = wandb_project
 775      run_config["tinker"]["wandb_run_name"] = f"{_current_env}-{run_id}"
 776      
 777      if "wandb_name" in _current_config and _current_config["wandb_name"]:
 778          run_config["env"]["wandb_name"] = _current_config["wandb_name"]
 779      
 780      with open(config_path, "w") as f:
 781          yaml.dump(run_config, f, default_flow_style=False)
 782      
 783      # Create run state
 784      run_state = RunState(
 785          run_id=run_id,
 786          environment=_current_env,
 787          config=_current_config.copy(),
 788          status="starting",
 789          wandb_project=wandb_project,
 790          wandb_run_name=f"{_current_env}-{run_id}",
 791      )
 792      
 793      _active_runs[run_id] = run_state
 794      
 795      # Start training in background
 796      asyncio.create_task(_spawn_training_run(run_state, config_path))
 797      
 798      return json.dumps({
 799          "run_id": run_id,
 800          "status": "starting",
 801          "environment": _current_env,
 802          "config": _current_config,
 803          "wandb_project": wandb_project,
 804          "wandb_run_name": f"{_current_env}-{run_id}",
 805          "config_path": str(config_path),
 806          "logs": {
 807              "api": str(LOGS_DIR / f"api_{run_id}.log"),
 808              "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"),
 809              "env": str(LOGS_DIR / f"env_{run_id}.log"),
 810          },
 811          "message": "Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes).",
 812      }, indent=2)
 813  
 814  
 815  async def rl_check_status(run_id: str) -> str:
 816      """
 817      Get status and metrics for a training run.
 818      
 819      RATE LIMITED: For long-running training, this function enforces a
 820      minimum 30-minute interval between checks for the same run_id.
 821      
 822      Args:
 823          run_id: The run ID returned by rl_start_training()
 824      
 825      Returns:
 826          JSON string with run status and metrics
 827      """
 828      # Check rate limiting
 829      now = time.time()
 830      if run_id in _last_status_check:
 831          elapsed = now - _last_status_check[run_id]
 832          if elapsed < MIN_STATUS_CHECK_INTERVAL:
 833              remaining = MIN_STATUS_CHECK_INTERVAL - elapsed
 834              return json.dumps({
 835                  "rate_limited": True,
 836                  "run_id": run_id,
 837                  "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.",
 838                  "next_check_in_seconds": remaining,
 839              }, indent=2)
 840      
 841      _last_status_check[run_id] = now
 842      
 843      if run_id not in _active_runs:
 844          return json.dumps({
 845              "error": f"Run '{run_id}' not found",
 846              "active_runs": list(_active_runs.keys()),
 847          }, indent=2)
 848      
 849      run_state = _active_runs[run_id]
 850      
 851      # Check process status
 852      processes = {
 853          "api": run_state.api_process.poll() if run_state.api_process else None,
 854          "trainer": run_state.trainer_process.poll() if run_state.trainer_process else None,
 855          "env": run_state.env_process.poll() if run_state.env_process else None,
 856      }
 857      
 858      running_time = time.time() - run_state.start_time if run_state.start_time else 0
 859      
 860      result = {
 861          "run_id": run_id,
 862          "status": run_state.status,
 863          "environment": run_state.environment,
 864          "running_time_minutes": running_time / 60,
 865          "processes": {
 866              name: "running" if code is None else f"exited ({code})"
 867              for name, code in processes.items()
 868          },
 869          "wandb_project": run_state.wandb_project,
 870          "wandb_run_name": run_state.wandb_run_name,
 871          "logs": {
 872              "api": str(LOGS_DIR / f"api_{run_id}.log"),
 873              "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"),
 874              "env": str(LOGS_DIR / f"env_{run_id}.log"),
 875          },
 876      }
 877      
 878      if run_state.error_message:
 879          result["error"] = run_state.error_message
 880      
 881      # Try to get WandB metrics if available
 882      try:
 883          import wandb
 884          api = wandb.Api()
 885          runs = api.runs(
 886              f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}",
 887              filters={"display_name": run_state.wandb_run_name}
 888          )
 889          if runs:
 890              wandb_run = runs[0]
 891              result["wandb_url"] = wandb_run.url
 892              result["metrics"] = {
 893                  "step": wandb_run.summary.get("_step", 0),
 894                  "reward_mean": wandb_run.summary.get("train/reward_mean"),
 895                  "percent_correct": wandb_run.summary.get("train/percent_correct"),
 896                  "eval_percent_correct": wandb_run.summary.get("eval/percent_correct"),
 897              }
 898      except Exception as e:
 899          result["wandb_error"] = str(e)
 900      
 901      return json.dumps(result, indent=2)
 902  
 903  
 904  async def rl_stop_training(run_id: str) -> str:
 905      """
 906      Stop a running training job.
 907      
 908      Args:
 909          run_id: The run ID to stop
 910      
 911      Returns:
 912          JSON string with stop confirmation
 913      """
 914      if run_id not in _active_runs:
 915          return json.dumps({
 916              "error": f"Run '{run_id}' not found",
 917              "active_runs": list(_active_runs.keys()),
 918          }, indent=2)
 919      
 920      run_state = _active_runs[run_id]
 921      
 922      if run_state.status not in ("running", "starting"):
 923          return json.dumps({
 924              "message": f"Run '{run_id}' is not running (status: {run_state.status})",
 925          }, indent=2)
 926      
 927      _stop_training_run(run_state)
 928      
 929      return json.dumps({
 930          "message": f"Stopped training run '{run_id}'",
 931          "run_id": run_id,
 932          "status": run_state.status,
 933      }, indent=2)
 934  
 935  
 936  async def rl_get_results(run_id: str) -> str:
 937      """
 938      Get final results and metrics for a training run.
 939      
 940      Args:
 941          run_id: The run ID to get results for
 942      
 943      Returns:
 944          JSON string with final results
 945      """
 946      if run_id not in _active_runs:
 947          return json.dumps({
 948              "error": f"Run '{run_id}' not found",
 949          }, indent=2)
 950      
 951      run_state = _active_runs[run_id]
 952      
 953      result = {
 954          "run_id": run_id,
 955          "status": run_state.status,
 956          "environment": run_state.environment,
 957          "wandb_project": run_state.wandb_project,
 958          "wandb_run_name": run_state.wandb_run_name,
 959      }
 960      
 961      # Get WandB metrics
 962      try:
 963          import wandb
 964          api = wandb.Api()
 965          runs = api.runs(
 966              f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}",
 967              filters={"display_name": run_state.wandb_run_name}
 968          )
 969          if runs:
 970              wandb_run = runs[0]
 971              result["wandb_url"] = wandb_run.url
 972              result["final_metrics"] = dict(wandb_run.summary)
 973              result["history"] = [dict(row) for row in wandb_run.history(samples=10)]
 974      except Exception as e:
 975          result["wandb_error"] = str(e)
 976      
 977      return json.dumps(result, indent=2)
 978  
 979  
 980  async def rl_list_runs() -> str:
 981      """
 982      List all training runs (active and completed).
 983      
 984      Returns:
 985          JSON string with list of runs and their status
 986      """
 987      runs = []
 988      for run_id, run_state in _active_runs.items():
 989          runs.append({
 990              "run_id": run_id,
 991              "environment": run_state.environment,
 992              "status": run_state.status,
 993              "wandb_run_name": run_state.wandb_run_name,
 994          })
 995      
 996      return json.dumps({
 997          "runs": runs,
 998          "count": len(runs),
 999      }, indent=2)
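
# Illustrative end-to-end lifecycle from an async caller. run_id is whatever
# rl_start_training() returns; the waiting strategy is up to the caller:
#
#   start = json.loads(await rl_start_training())
#   run_id = start["run_id"]
#   # ... training runs for hours; status checks are rate limited to one per
#   # 30 minutes per run_id ...
#   print(await rl_check_status(run_id))
#   print(await rl_get_results(run_id))   # or await rl_stop_training(run_id) to abort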
1000  
1001  
1002  # ============================================================================
1003  # Inference Testing (via Atropos `process` mode with OpenRouter)
1004  # ============================================================================
1005  
1006  # Test models at different scales for robustness testing
1007  # These are cheap, capable models on OpenRouter for testing parsing/scoring
1008  TEST_MODELS = [
1009      {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"},
1010      {"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"},
1011      {"id": "minimax/minimax-m2.7", "name": "MiniMax M2.7", "scale": "large"},
1012  ]
1013  
1014  # Default test parameters - quick but representative
1015  DEFAULT_NUM_STEPS = 3       # Number of steps (items) to test
1016  DEFAULT_GROUP_SIZE = 16     # Completions per item (like training)
1017  
1018  
1019  async def rl_test_inference(
1020      num_steps: int = DEFAULT_NUM_STEPS,
1021      group_size: int = DEFAULT_GROUP_SIZE,
1022      models: Optional[List[str]] = None,
1023  ) -> str:
1024      """
1025      Quick inference test for any environment using Atropos's `process` mode.
1026      
1027      Runs a few steps of inference + scoring to validate:
1028      - Environment loads correctly
1029      - Prompt construction works
1030      - Inference parsing is robust (tested with multiple model scales)
1031      - Verifier/scoring logic works
1032      
1033      Default: 3 steps × 16 completions = 48 total rollouts per model.
1034      Tests 3 models = 144 total rollouts. Quick sanity check.
1035      
    Default test models (see TEST_MODELS above; varying scales for robustness):
    - qwen/qwen3-8b (small)
    - z-ai/glm-4.7-flash (medium)
    - minimax/minimax-m2.7 (large)
1040      
1041      Args:
        num_steps: Steps to run (default: 3, which is the recommended maximum for testing)
1043          group_size: Completions per step (default: 16, like training)
1044          models: Optional model IDs to test. If None, uses all 3 test models.
1045      
1046      Returns:
1047          JSON with results per model: steps_tested, accuracy, scores
1048      """
1049      if not _current_env:
1050          return json.dumps({
1051              "error": "No environment selected. Use rl_select_environment(name) first.",
1052          }, indent=2)
1053      
1054      api_key = os.getenv("OPENROUTER_API_KEY")
1055      if not api_key:
1056          return json.dumps({
1057              "error": "OPENROUTER_API_KEY not set. Required for inference testing.",
1058          }, indent=2)
1059      
1060      # Find environment info
1061      env_info = None
1062      for env in _environments:
1063          if env.name == _current_env:
1064              env_info = env
1065              break
1066      
1067      if not env_info:
1068          return json.dumps({
1069              "error": f"Environment '{_current_env}' not found",
1070          }, indent=2)
1071      
1072      # Determine which models to test
1073      if models:
1074          test_models = [m for m in TEST_MODELS if m["id"] in models]
1075          if not test_models:
1076              test_models = [{"id": m, "name": m, "scale": "custom"} for m in models]
1077      else:
1078          test_models = TEST_MODELS
1079      
1080      # Calculate total rollouts for logging
1081      total_rollouts_per_model = num_steps * group_size
1082      total_rollouts = total_rollouts_per_model * len(test_models)
1083      
1084      results = {
1085          "environment": _current_env,
1086          "environment_file": env_info.file_path,
1087          "test_config": {
1088              "num_steps": num_steps,
1089              "group_size": group_size,
1090              "rollouts_per_model": total_rollouts_per_model,
1091              "total_rollouts": total_rollouts,
1092          },
1093          "models_tested": [],
1094      }
1095      
1096      # Create output directory for test results
1097      _ensure_logs_dir()
1098      test_output_dir = LOGS_DIR / "inference_tests"
1099      test_output_dir.mkdir(exist_ok=True)
1100      
1101      for model_info in test_models:
1102          model_id = model_info["id"]
1103          model_safe_name = model_id.replace("/", "_")
1104          
1105          print(f"\n{'='*60}")
1106          print(f"Testing with {model_info['name']} ({model_id})")
1107          print(f"{'='*60}")
1108          
1109          # Output file for this test run
1110          output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
1111          
1112          # Generate unique run ID for wandb
1113          test_run_id = str(uuid.uuid4())[:8]
1114          wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}"
1115          
1116          # Build the process command using Atropos's built-in CLI
1117          # This runs the environment's actual code with OpenRouter as the inference backend
1118          # We pass our locked settings + test-specific overrides via CLI args
1119          cmd = [
1120              sys.executable, env_info.file_path, "process",
1121              # Test-specific overrides
1122              "--env.total_steps", str(num_steps),
1123              "--env.group_size", str(group_size),
1124              "--env.use_wandb", "true",  # Enable wandb for test tracking
1125              "--env.wandb_name", wandb_run_name,
1126              "--env.data_path_to_save_groups", str(output_file),
1127              # Use locked settings from our config
1128              "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
1129              "--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]),
1130              "--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]),
1131              "--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]),
1132              # OpenRouter config for inference testing
1133              # IMPORTANT: Use server_type=openai for OpenRouter (not sglang)
1134              # sglang is only for actual training with Tinker's inference server
1135              "--openai.base_url", "https://openrouter.ai/api/v1",
1136              "--openai.api_key", api_key,
1137              "--openai.model_name", model_id,
1138              "--openai.server_type", "openai",  # OpenRouter is OpenAI-compatible
1139              "--openai.health_check", "false",  # OpenRouter doesn't have health endpoint
1140          ]
1141          
1142          # Debug: Print the full command
1143          cmd_str = " ".join(str(c) for c in cmd)
1144          # Hide API key in printed output
1145          cmd_display = cmd_str.replace(api_key, "***API_KEY***")
1146          print(f"Command: {cmd_display}")
1147          print(f"Working dir: {TINKER_ATROPOS_ROOT}")
1148          print(f"WandB run: {wandb_run_name}")
1149          print(f"  {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
1150          
1151          model_results = {
1152              "model": model_id,
1153              "name": model_info["name"],
1154              "scale": model_info["scale"],
1155              "wandb_run": wandb_run_name,
1156              "output_file": str(output_file),
1157              "steps": [],
1158              "steps_tested": 0,
1159              "total_completions": 0,
1160              "correct_completions": 0,
1161          }
1162          
1163          try:
1164              # Run the process command with real-time output streaming
1165              process = await asyncio.create_subprocess_exec(
1166                  *cmd,
1167                  stdout=asyncio.subprocess.PIPE,
1168                  stderr=asyncio.subprocess.PIPE,
1169                  cwd=str(TINKER_ATROPOS_ROOT),
1170              )
1171              
1172              # Stream output in real-time while collecting for logs
1173              stdout_lines = []
1174              stderr_lines = []
1175              log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
1176              
1177              async def read_stream(stream, lines_list, prefix=""):
1178                  """Read stream line by line and print in real-time."""
1179                  while True:
1180                      line = await stream.readline()
1181                      if not line:
1182                          break
1183                      decoded = line.decode().rstrip()
1184                      lines_list.append(decoded)
1185                      # Print progress-related lines in real-time
1186                      if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']):
1187                          print(f"  {prefix}{decoded}")
1188              
1189              # Read both streams concurrently with timeout
1190              try:
1191                  await asyncio.wait_for(
1192                      asyncio.gather(
1193                          read_stream(process.stdout, stdout_lines, "📊 "),
1194                          read_stream(process.stderr, stderr_lines, "⚠️ "),
1195                      ),
1196                      timeout=600,  # 10 minute timeout per model
1197                  )
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()
                raise
1201              
1202              await process.wait()
1203              
1204              # Combine output for logging
1205              stdout_text = "\n".join(stdout_lines)
1206              stderr_text = "\n".join(stderr_lines)
1207              
1208              # Write logs to files for inspection outside CLI
1209              with open(log_file, "w") as f:
1210                  f.write(f"Command: {cmd_display}\n")
1211                  f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
1212                  f.write(f"Return code: {process.returncode}\n")
1213                  f.write(f"\n{'='*60}\n")
1214                  f.write(f"STDOUT:\n{'='*60}\n")
1215                  f.write(stdout_text or "(empty)\n")
1216                  f.write(f"\n{'='*60}\n")
1217                  f.write(f"STDERR:\n{'='*60}\n")
1218                  f.write(stderr_text or "(empty)\n")
1219              
1220              print(f"  Log file: {log_file}")
1221              
1222              if process.returncode != 0:
1223                  model_results["error"] = f"Process exited with code {process.returncode}"
1224                  model_results["stderr"] = stderr_text[-1000:]
1225                  model_results["stdout"] = stdout_text[-1000:]
1226                  model_results["log_file"] = str(log_file)
1227                  print(f"\n  ❌ Error: {model_results['error']}")
1228                  # Print last few lines of stderr for debugging
1229                  if stderr_lines:
1230                      print("  Last errors:")
1231                      for line in stderr_lines[-5:]:
1232                          print(f"    {line}")
1233              else:
1234                  print("\n  ✅ Process completed successfully")
1235                  print(f"  Output file: {output_file}")
1236                  print(f"  File exists: {output_file.exists()}")
1237                  
1238                  # Parse the output JSONL file
1239                  if output_file.exists():
1240                      # Read JSONL file (one JSON object per line = one step)
1241                      with open(output_file, "r") as f:
1242                          for line in f:
1243                              line = line.strip()
1244                              if not line:
1245                                  continue
1246                              try:
1247                                  item = json.loads(line)
1248                                  scores = item.get("scores", [])
1249                                  model_results["steps_tested"] += 1
1250                                  model_results["total_completions"] += len(scores)
1251                                  correct = sum(1 for s in scores if s > 0)
1252                                  model_results["correct_completions"] += correct
1253                                  
1254                                  model_results["steps"].append({
1255                                      "step": model_results["steps_tested"],
1256                                      "completions": len(scores),
1257                                      "correct": correct,
1258                                      "scores": scores,
1259                                  })
1260                              except json.JSONDecodeError:
1261                                  continue
1262                      
1263                      print(f"  Completed {model_results['steps_tested']} steps")
1264                  else:
1265                      model_results["error"] = f"Output file not created: {output_file}"
1266                      
1267          except asyncio.TimeoutError:
1268              model_results["error"] = "Process timed out after 10 minutes"
1269              print("  Timeout!")
1270          except Exception as e:
1271              model_results["error"] = str(e)
1272              print(f"  Error: {e}")
1273          
1274          # Calculate stats
1275          if model_results["total_completions"] > 0:
1276              model_results["accuracy"] = round(
1277                  model_results["correct_completions"] / model_results["total_completions"], 3
1278              )
1279          else:
1280              model_results["accuracy"] = 0
1281              
1282          if model_results["steps_tested"] > 0:
1283              steps_with_correct = sum(1 for s in model_results["steps"] if s.get("correct", 0) > 0)
1284              model_results["steps_with_correct"] = steps_with_correct
1285              model_results["step_success_rate"] = round(
1286                  steps_with_correct / model_results["steps_tested"], 3
1287              )
1288          else:
1289              model_results["steps_with_correct"] = 0
1290              model_results["step_success_rate"] = 0
1291          
1292          print(f"  Results: {model_results['correct_completions']}/{model_results['total_completions']} correct")
1293          print(f"  Accuracy: {model_results['accuracy']:.1%}")
1294          
1295          results["models_tested"].append(model_results)
1296      
1297      # Overall summary
1298      working_models = [m for m in results["models_tested"] if m.get("steps_tested", 0) > 0]
1299      
1300      results["summary"] = {
1301          "steps_requested": num_steps,
1302          "models_tested": len(test_models),
1303          "models_succeeded": len(working_models),
1304          "best_model": max(working_models, key=lambda x: x.get("accuracy", 0))["model"] if working_models else None,
1305          "avg_accuracy": round(
1306              sum(m.get("accuracy", 0) for m in working_models) / len(working_models), 3
1307          ) if working_models else 0,
1308          "environment_working": bool(working_models),
1309          "output_directory": str(test_output_dir),
1310      }
1311      
1312      return json.dumps(results, indent=2)
1313  
1314  
1315  # ============================================================================
1316  # Requirements Check
1317  # ============================================================================
1318  
1319  def check_rl_python_version() -> bool:
1320      """
1321      Check if Python version meets the minimum for RL tools.
1322      
1323      tinker-atropos depends on the 'tinker' package which requires Python >= 3.11.
1324      """
1325      return sys.version_info >= (3, 11)
1326  
1327  
1328  def check_rl_api_keys() -> bool:
1329      """
1330      Check if required API keys and Python version are available.
1331      
1332      RL training requires:
1333      - Python >= 3.11 (tinker package requirement)
1334      - TINKER_API_KEY for the Tinker training API
1335      - WANDB_API_KEY for Weights & Biases metrics
1336      """
1337      if not check_rl_python_version():
1338          return False
1339      tinker_key = os.getenv("TINKER_API_KEY")
1340      wandb_key = os.getenv("WANDB_API_KEY")
1341      return bool(tinker_key) and bool(wandb_key)
1342  
1343  
1344  def get_missing_keys() -> List[str]:
1345      """
1346      Get list of missing requirements for RL tools (API keys and Python version).
1347      """
1348      missing = []
1349      if not check_rl_python_version():
1350          missing.append(f"Python >= 3.11 (current: {sys.version_info.major}.{sys.version_info.minor})")
1351      if not os.getenv("TINKER_API_KEY"):
1352          missing.append("TINKER_API_KEY")
1353      if not os.getenv("WANDB_API_KEY"):
1354          missing.append("WANDB_API_KEY")
1355      return missing
1356  
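      # Illustrative gate (hypothetical call site): the registry entries below pass
      # check_rl_api_keys as check_fn, and a caller could surface what is missing, e.g.
      #
      #     if not check_rl_api_keys():
      #         raise RuntimeError("RL tools unavailable: " + ", ".join(get_missing_keys()))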
1357  
1358  # ============================================================================
1359  # Schemas + Registry
1360  # ============================================================================
1361  from tools.registry import registry
1362  
1363  RL_LIST_ENVIRONMENTS_SCHEMA = {"name": "rl_list_environments", "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", "parameters": {"type": "object", "properties": {}, "required": []}}
1364  RL_SELECT_ENVIRONMENT_SCHEMA = {"name": "rl_select_environment", "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", "parameters": {"type": "object", "properties": {"name": {"type": "string", "description": "Name of the environment to select (from rl_list_environments)"}}, "required": ["name"]}}
1365  RL_GET_CURRENT_CONFIG_SCHEMA = {"name": "rl_get_current_config", "description": "Get the current environment configuration. Returns only the fields that can be modified for the selected environment (e.g. group_size, wandb_name); locked infrastructure settings are not included.", "parameters": {"type": "object", "properties": {}, "required": []}}
1366  RL_EDIT_CONFIG_SCHEMA = {"name": "rl_edit_config", "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.", "parameters": {"type": "object", "properties": {"field": {"type": "string", "description": "Name of the field to update (get available fields from rl_get_current_config)"}, "value": {"description": "New value for the field"}}, "required": ["field", "value"]}}
1367  RL_START_TRAINING_SCHEMA = {"name": "rl_start_training", "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.", "parameters": {"type": "object", "properties": {}, "required": []}}
1368  RL_CHECK_STATUS_SCHEMA = {"name": "rl_check_status", "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID from rl_start_training()"}}, "required": ["run_id"]}}
1369  RL_STOP_TRAINING_SCHEMA = {"name": "rl_stop_training", "description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to stop"}}, "required": ["run_id"]}}
1370  RL_GET_RESULTS_SCHEMA = {"name": "rl_get_results", "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to get results for"}}, "required": ["run_id"]}}
1371  RL_LIST_RUNS_SCHEMA = {"name": "rl_list_runs", "description": "List all training runs (active and completed) with their status.", "parameters": {"type": "object", "properties": {}, "required": []}}
1372  RL_TEST_INFERENCE_SCHEMA = {"name": "rl_test_inference", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": {"type": "object", "properties": {"num_steps": {"type": "integer", "description": "Number of steps to run (default: 3, recommended max for testing)", "default": 3}, "group_size": {"type": "integer", "description": "Completions per step (default: 16, like training)", "default": 16}, "models": {"type": "array", "items": {"type": "string"}, "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.7"}}, "required": []}}
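      # Each schema above uses the usual function-calling shape (name, description,
      # JSON-Schema "parameters"), so a tool call arrives as a plain arguments dict,
      # e.g. for rl_edit_config (hypothetical values):
      #
      #     {"field": "group_size", "value": 32}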
1373  
1374  _rl_env = ["TINKER_API_KEY", "WANDB_API_KEY"]
1375  
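      # Register the RL tools. check_fn gates availability on the API keys plus Python >= 3.11,
      # and each handler simply forwards the tool-call arguments dict as keyword args; e.g.
      # (hypothetical call; dispatch details live in tools/registry.py):
      #
      #     handler({"num_steps": 2, "group_size": 8})  # -> rl_test_inference(num_steps=2, group_size=8, models=None)
      #
      # is_async=True tells the registry the result is a coroutine to be awaited.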
1376  registry.register(name="rl_list_environments", emoji="🧪", toolset="rl", schema=RL_LIST_ENVIRONMENTS_SCHEMA,
1377      handler=lambda args, **kw: rl_list_environments(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1378  registry.register(name="rl_select_environment", emoji="🧪", toolset="rl", schema=RL_SELECT_ENVIRONMENT_SCHEMA,
1379      handler=lambda args, **kw: rl_select_environment(name=args.get("name", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1380  registry.register(name="rl_get_current_config", emoji="🧪", toolset="rl", schema=RL_GET_CURRENT_CONFIG_SCHEMA,
1381      handler=lambda args, **kw: rl_get_current_config(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1382  registry.register(name="rl_edit_config", emoji="🧪", toolset="rl", schema=RL_EDIT_CONFIG_SCHEMA,
1383      handler=lambda args, **kw: rl_edit_config(field=args.get("field", ""), value=args.get("value")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1384  registry.register(name="rl_start_training", emoji="🧪", toolset="rl", schema=RL_START_TRAINING_SCHEMA,
1385      handler=lambda args, **kw: rl_start_training(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1386  registry.register(name="rl_check_status", emoji="🧪", toolset="rl", schema=RL_CHECK_STATUS_SCHEMA,
1387      handler=lambda args, **kw: rl_check_status(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1388  registry.register(name="rl_stop_training", emoji="🧪", toolset="rl", schema=RL_STOP_TRAINING_SCHEMA,
1389      handler=lambda args, **kw: rl_stop_training(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1390  registry.register(name="rl_get_results", emoji="🧪", toolset="rl", schema=RL_GET_RESULTS_SCHEMA,
1391      handler=lambda args, **kw: rl_get_results(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1392  registry.register(name="rl_list_runs", emoji="🧪", toolset="rl", schema=RL_LIST_RUNS_SCHEMA,
1393      handler=lambda args, **kw: rl_list_runs(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)
1394  registry.register(name="rl_test_inference", emoji="🧪", toolset="rl", schema=RL_TEST_INFERENCE_SCHEMA,
1395      handler=lambda args, **kw: rl_test_inference(num_steps=args.get("num_steps", 3), group_size=args.get("group_size", 16), models=args.get("models")),
1396      check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)