rl_training_tool.py
1 #!/usr/bin/env python3 2 """ 3 RL Training Tools Module 4 5 This module provides tools for running RL training through Tinker-Atropos. 6 Directly manages training processes without requiring a separate API server. 7 8 Features: 9 - Environment discovery (AST-based scanning for BaseEnv subclasses) 10 - Configuration management with locked infrastructure settings 11 - Training run lifecycle via subprocess management 12 - WandB metrics monitoring 13 14 Required environment variables: 15 - TINKER_API_KEY: API key for Tinker service 16 - WANDB_API_KEY: API key for Weights & Biases metrics 17 18 Usage: 19 from tools.rl_training_tool import ( 20 rl_list_environments, 21 rl_select_environment, 22 rl_get_current_config, 23 rl_edit_config, 24 rl_start_training, 25 rl_check_status, 26 rl_stop_training, 27 rl_get_results, 28 ) 29 """ 30 31 import ast 32 import asyncio 33 import importlib.util 34 import json 35 import os 36 import subprocess 37 import sys 38 import time 39 import uuid 40 import logging 41 from datetime import datetime 42 import yaml 43 from dataclasses import dataclass 44 from pathlib import Path 45 from typing import Any, Dict, List, Optional 46 47 from hermes_constants import get_hermes_home 48 49 logger = logging.getLogger(__name__) 50 51 # ============================================================================ 52 # Path Configuration 53 # ============================================================================ 54 55 # Path to tinker-atropos submodule (relative to hermes-agent root) 56 HERMES_ROOT = Path(__file__).parent.parent 57 TINKER_ATROPOS_ROOT = HERMES_ROOT / "tinker-atropos" 58 ENVIRONMENTS_DIR = TINKER_ATROPOS_ROOT / "tinker_atropos" / "environments" 59 CONFIGS_DIR = TINKER_ATROPOS_ROOT / "configs" 60 LOGS_DIR = get_hermes_home() / "logs" / "rl_training" 61 62 def _ensure_logs_dir(): 63 """Lazily create logs directory on first use (avoid side effects at import time).""" 64 if TINKER_ATROPOS_ROOT.exists(): 65 LOGS_DIR.mkdir(exist_ok=True) 66 67 # ============================================================================ 68 # Locked Configuration (Infrastructure Settings) 69 # ============================================================================ 70 71 # These fields cannot be changed by the model - they're tuned for our infrastructure 72 LOCKED_FIELDS = { 73 "env": { 74 "tokenizer_name": "Qwen/Qwen3-8B", 75 "rollout_server_url": "http://localhost:8000", 76 "use_wandb": True, 77 "max_token_length": 8192, 78 "max_num_workers": 2048, 79 "worker_timeout": 3600, 80 "total_steps": 2500, 81 "steps_per_eval": 25, 82 "max_batches_offpolicy": 3, 83 "inference_weight": 1.0, 84 "eval_limit_ratio": 0.1, 85 }, 86 "openai": [ 87 { 88 "model_name": "Qwen/Qwen3-8B", 89 "base_url": "http://localhost:8001/v1", 90 "api_key": "x", 91 "weight": 1.0, 92 "num_requests_for_eval": 256, 93 "timeout": 3600, 94 "server_type": "sglang", # Tinker uses sglang for actual training 95 } 96 ], 97 "tinker": { 98 "lora_rank": 32, 99 "learning_rate": 0.00004, 100 "max_token_trainer_length": 9000, 101 "checkpoint_dir": "./temp/", 102 "save_checkpoint_interval": 25, 103 }, 104 "slurm": False, 105 "testing": False, 106 } 107 108 LOCKED_FIELD_NAMES = set(LOCKED_FIELDS.get("env", {}).keys()) 109 110 111 # ============================================================================ 112 # State Management 113 # ============================================================================ 114 115 @dataclass 116 class EnvironmentInfo: 117 """Information about a discovered environment.""" 118 name: str 119 class_name: str 120 file_path: str 121 description: str = "" 122 config_class: str = "BaseEnvConfig" 123 124 125 @dataclass 126 class RunState: 127 """State for a training run.""" 128 run_id: str 129 environment: str 130 config: Dict[str, Any] 131 status: str = "pending" # pending, starting, running, stopping, stopped, completed, failed 132 error_message: str = "" 133 wandb_project: str = "" 134 wandb_run_name: str = "" 135 start_time: float = 0.0 136 # Process handles 137 api_process: Optional[subprocess.Popen] = None 138 trainer_process: Optional[subprocess.Popen] = None 139 env_process: Optional[subprocess.Popen] = None 140 141 142 # Global state 143 _environments: List[EnvironmentInfo] = [] 144 _current_env: Optional[str] = None 145 _current_config: Dict[str, Any] = {} 146 _env_config_cache: Dict[str, Dict[str, Dict[str, Any]]] = {} 147 _active_runs: Dict[str, RunState] = {} 148 _last_status_check: Dict[str, float] = {} 149 150 # Rate limiting for status checks (30 minutes) 151 MIN_STATUS_CHECK_INTERVAL = 30 * 60 152 153 154 # ============================================================================ 155 # Environment Discovery 156 # ============================================================================ 157 158 def _scan_environments() -> List[EnvironmentInfo]: 159 """ 160 Scan the environments directory for BaseEnv subclasses using AST. 161 """ 162 environments = [] 163 164 if not ENVIRONMENTS_DIR.exists(): 165 return environments 166 167 for py_file in ENVIRONMENTS_DIR.glob("*.py"): 168 if py_file.name.startswith("_"): 169 continue 170 171 try: 172 with open(py_file, "r") as f: 173 tree = ast.parse(f.read()) 174 175 for node in ast.walk(tree): 176 if isinstance(node, ast.ClassDef): 177 # Check if class has BaseEnv as base 178 for base in node.bases: 179 base_name = "" 180 if isinstance(base, ast.Name): 181 base_name = base.id 182 elif isinstance(base, ast.Attribute): 183 base_name = base.attr 184 185 if base_name == "BaseEnv": 186 # Extract name from class attribute if present 187 env_name = py_file.stem 188 description = "" 189 config_class = "BaseEnvConfig" 190 191 for item in node.body: 192 if isinstance(item, ast.Assign): 193 for target in item.targets: 194 if isinstance(target, ast.Name): 195 if target.id == "name" and isinstance(item.value, ast.Constant): 196 env_name = item.value.value 197 elif target.id == "env_config_cls" and isinstance(item.value, ast.Name): 198 config_class = item.value.id 199 200 # Get docstring 201 if isinstance(item, ast.Expr) and isinstance(item.value, ast.Constant): 202 if isinstance(item.value.value, str) and not description: 203 description = item.value.value.split("\n")[0].strip() 204 205 environments.append(EnvironmentInfo( 206 name=env_name, 207 class_name=node.name, 208 file_path=str(py_file), 209 description=description or f"Environment from {py_file.name}", 210 config_class=config_class, 211 )) 212 break 213 except Exception as e: 214 logger.warning("Could not parse %s: %s", py_file, e) 215 216 return environments 217 218 219 def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: 220 """ 221 Dynamically import an environment and extract its config fields. 222 223 Uses config_init() to get the actual config class, with fallback to 224 directly importing BaseEnvConfig if config_init fails. 225 """ 226 try: 227 # Load the environment module 228 spec = importlib.util.spec_from_file_location("env_module", env_file_path) 229 module = importlib.util.module_from_spec(spec) 230 sys.modules["env_module"] = module 231 spec.loader.exec_module(module) 232 233 # Find the BaseEnv subclass 234 env_class = None 235 for name, obj in vars(module).items(): 236 if isinstance(obj, type) and name != "BaseEnv": 237 if hasattr(obj, "config_init") and callable(getattr(obj, "config_init")): 238 env_class = obj 239 break 240 241 if not env_class: 242 return {} 243 244 # Try calling config_init to get the actual config class 245 config_class = None 246 try: 247 env_config, server_configs = env_class.config_init() 248 config_class = type(env_config) 249 except Exception as config_error: 250 # Fallback: try to import BaseEnvConfig directly from atroposlib 251 logger.info("config_init failed (%s), using BaseEnvConfig defaults", config_error) 252 try: 253 from atroposlib.envs.base import BaseEnvConfig 254 config_class = BaseEnvConfig 255 except ImportError: 256 return {} 257 258 if not config_class: 259 return {} 260 261 # Helper to make values JSON-serializable (handle enums, etc.) 262 def make_serializable(val): 263 if val is None: 264 return None 265 if hasattr(val, 'value'): # Enum 266 return val.value 267 if hasattr(val, 'name') and hasattr(val, '__class__') and 'Enum' in str(type(val)): 268 return val.name 269 return val 270 271 # Extract fields from the Pydantic model 272 fields = {} 273 for field_name, field_info in config_class.model_fields.items(): 274 field_type = field_info.annotation 275 default = make_serializable(field_info.default) 276 description = field_info.description or "" 277 278 is_locked = field_name in LOCKED_FIELD_NAMES 279 280 # Convert type to string 281 type_name = getattr(field_type, "__name__", str(field_type)) 282 if hasattr(field_type, "__origin__"): 283 type_name = str(field_type) 284 285 locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default) 286 current_value = make_serializable(locked_value) if is_locked else default 287 288 fields[field_name] = { 289 "type": type_name, 290 "default": default, 291 "description": description, 292 "locked": is_locked, 293 "current_value": current_value, 294 } 295 296 return fields 297 298 except Exception as e: 299 logger.warning("Could not introspect environment config: %s", e) 300 return {} 301 302 303 def _initialize_environments(): 304 """Initialize environment list on first use.""" 305 global _environments 306 if not _environments: 307 _environments = _scan_environments() 308 309 310 # ============================================================================ 311 # Subprocess Management 312 # ============================================================================ 313 314 async def _spawn_training_run(run_state: RunState, config_path: Path): 315 """ 316 Spawn the three processes needed for training: 317 1. run-api (Atropos API server) 318 2. launch_training.py (Tinker trainer + inference server) 319 3. environment.py serve (the Atropos environment) 320 """ 321 run_id = run_state.run_id 322 323 _ensure_logs_dir() 324 325 # Log file paths 326 api_log = LOGS_DIR / f"api_{run_id}.log" 327 trainer_log = LOGS_DIR / f"trainer_{run_id}.log" 328 env_log = LOGS_DIR / f"env_{run_id}.log" 329 330 try: 331 # Step 1: Start the Atropos API server (run-api) 332 logger.info("[%s] Starting Atropos API server (run-api)...", run_id) 333 334 # File must stay open while the subprocess runs; we store the handle 335 # on run_state so _stop_training_run() can close it when done. 336 api_log_file = open(api_log, "w") # closed by _stop_training_run 337 run_state.api_log_file = api_log_file 338 run_state.api_process = subprocess.Popen( 339 ["run-api"], 340 stdout=api_log_file, 341 stderr=subprocess.STDOUT, 342 cwd=str(TINKER_ATROPOS_ROOT), 343 ) 344 345 # Wait for API to start 346 await asyncio.sleep(5) 347 348 if run_state.api_process.poll() is not None: 349 run_state.status = "failed" 350 run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}" 351 _stop_training_run(run_state) 352 return 353 354 logger.info("[%s] Atropos API server started", run_id) 355 356 # Step 2: Start the Tinker trainer 357 logger.info("[%s] Starting Tinker trainer: launch_training.py --config %s", run_id, config_path) 358 359 trainer_log_file = open(trainer_log, "w") # closed by _stop_training_run 360 run_state.trainer_log_file = trainer_log_file 361 run_state.trainer_process = subprocess.Popen( 362 [sys.executable, "launch_training.py", "--config", str(config_path)], 363 stdout=trainer_log_file, 364 stderr=subprocess.STDOUT, 365 cwd=str(TINKER_ATROPOS_ROOT), 366 env={**os.environ, "TINKER_API_KEY": os.getenv("TINKER_API_KEY", "")}, 367 ) 368 369 # Wait for trainer to initialize (it starts FastAPI inference server on 8001) 370 logger.info("[%s] Waiting 30 seconds for trainer to initialize...", run_id) 371 await asyncio.sleep(30) 372 373 if run_state.trainer_process.poll() is not None: 374 run_state.status = "failed" 375 run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}" 376 _stop_training_run(run_state) 377 return 378 379 logger.info("[%s] Trainer started, inference server on port 8001", run_id) 380 381 # Step 3: Start the environment 382 logger.info("[%s] Waiting 90 more seconds before starting environment...", run_id) 383 await asyncio.sleep(90) 384 385 # Find the environment file 386 env_info = None 387 for env in _environments: 388 if env.name == run_state.environment: 389 env_info = env 390 break 391 392 if not env_info: 393 run_state.status = "failed" 394 run_state.error_message = f"Environment '{run_state.environment}' not found" 395 _stop_training_run(run_state) 396 return 397 398 logger.info("[%s] Starting environment: %s serve", run_id, env_info.file_path) 399 400 env_log_file = open(env_log, "w") # closed by _stop_training_run 401 run_state.env_log_file = env_log_file 402 run_state.env_process = subprocess.Popen( 403 [sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)], 404 stdout=env_log_file, 405 stderr=subprocess.STDOUT, 406 cwd=str(TINKER_ATROPOS_ROOT), 407 ) 408 409 # Wait for environment to connect 410 await asyncio.sleep(10) 411 412 if run_state.env_process.poll() is not None: 413 run_state.status = "failed" 414 run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}" 415 _stop_training_run(run_state) 416 return 417 418 run_state.status = "running" 419 run_state.start_time = time.time() 420 logger.info("[%s] Training run started successfully!", run_id) 421 422 # Start background monitoring 423 asyncio.create_task(_monitor_training_run(run_state)) 424 425 except Exception as e: 426 run_state.status = "failed" 427 run_state.error_message = str(e) 428 _stop_training_run(run_state) 429 430 431 async def _monitor_training_run(run_state: RunState): 432 """Background task to monitor a training run.""" 433 while run_state.status == "running": 434 await asyncio.sleep(30) # Check every 30 seconds 435 436 # Check if any process has died 437 if run_state.env_process and run_state.env_process.poll() is not None: 438 exit_code = run_state.env_process.returncode 439 if exit_code == 0: 440 run_state.status = "completed" 441 else: 442 run_state.status = "failed" 443 run_state.error_message = f"Environment process exited with code {exit_code}" 444 _stop_training_run(run_state) 445 break 446 447 if run_state.trainer_process and run_state.trainer_process.poll() is not None: 448 exit_code = run_state.trainer_process.returncode 449 if exit_code == 0: 450 run_state.status = "completed" 451 else: 452 run_state.status = "failed" 453 run_state.error_message = f"Trainer process exited with code {exit_code}" 454 _stop_training_run(run_state) 455 break 456 457 if run_state.api_process and run_state.api_process.poll() is not None: 458 run_state.status = "failed" 459 run_state.error_message = "API server exited unexpectedly" 460 _stop_training_run(run_state) 461 break 462 463 464 def _stop_training_run(run_state: RunState): 465 """Stop all processes for a training run.""" 466 # Stop in reverse order: env -> trainer -> api 467 if run_state.env_process and run_state.env_process.poll() is None: 468 logger.info("[%s] Stopping environment process...", run_state.run_id) 469 run_state.env_process.terminate() 470 try: 471 run_state.env_process.wait(timeout=10) 472 except subprocess.TimeoutExpired: 473 run_state.env_process.kill() 474 475 if run_state.trainer_process and run_state.trainer_process.poll() is None: 476 logger.info("[%s] Stopping trainer process...", run_state.run_id) 477 run_state.trainer_process.terminate() 478 try: 479 run_state.trainer_process.wait(timeout=10) 480 except subprocess.TimeoutExpired: 481 run_state.trainer_process.kill() 482 483 if run_state.api_process and run_state.api_process.poll() is None: 484 logger.info("[%s] Stopping API server...", run_state.run_id) 485 run_state.api_process.terminate() 486 try: 487 run_state.api_process.wait(timeout=10) 488 except subprocess.TimeoutExpired: 489 run_state.api_process.kill() 490 491 if run_state.status == "running": 492 run_state.status = "stopped" 493 494 # Close log file handles that were opened for subprocess stdout. 495 for attr in ("env_log_file", "trainer_log_file", "api_log_file"): 496 fh = getattr(run_state, attr, None) 497 if fh is not None: 498 try: 499 fh.close() 500 except Exception: 501 pass 502 setattr(run_state, attr, None) 503 504 505 # ============================================================================ 506 # Environment Discovery Tools 507 # ============================================================================ 508 509 async def rl_list_environments() -> str: 510 """ 511 List all available RL environments. 512 513 Scans tinker-atropos/tinker_atropos/environments/ for Python files 514 containing classes that inherit from BaseEnv. 515 516 Returns information about each environment including: 517 - name: Environment identifier 518 - class_name: Python class name 519 - file_path: Path to the environment file 520 - description: Brief description if available 521 522 TIP: To create or modify RL environments: 523 1. Use terminal/file tools to inspect existing environments 524 2. Study how they load datasets, define verifiers, and structure rewards 525 3. Inspect HuggingFace datasets to understand data formats 526 4. Copy an existing environment as a template 527 528 Returns: 529 JSON string with list of environments 530 """ 531 _initialize_environments() 532 533 response = { 534 "environments": [ 535 { 536 "name": env.name, 537 "class_name": env.class_name, 538 "file_path": env.file_path, 539 "description": env.description, 540 } 541 for env in _environments 542 ], 543 "count": len(_environments), 544 "tips": [ 545 "Use rl_select_environment(name) to select an environment", 546 "Read the file_path with file tools to understand how each environment works", 547 "Look for load_dataset(), score_answer(), get_next_item() methods", 548 ] 549 } 550 551 return json.dumps(response, indent=2) 552 553 554 async def rl_select_environment(name: str) -> str: 555 """ 556 Select an RL environment for training. 557 558 This loads the environment's configuration fields into memory. 559 After selecting, use rl_get_current_config() to see all configurable options 560 and rl_edit_config() to modify specific fields. 561 562 Args: 563 name: Name of the environment to select (from rl_list_environments) 564 565 Returns: 566 JSON string with selection result, file path, and configurable field count 567 568 TIP: Read the returned file_path to understand how the environment works. 569 """ 570 global _current_env, _current_config 571 572 _initialize_environments() 573 574 env_info = None 575 for env in _environments: 576 if env.name == name: 577 env_info = env 578 break 579 580 if not env_info: 581 return json.dumps({ 582 "error": f"Environment '{name}' not found", 583 "available": [e.name for e in _environments], 584 }, indent=2) 585 586 _current_env = name 587 588 # Dynamically discover config fields 589 config_fields = _get_env_config_fields(env_info.file_path) 590 _env_config_cache[name] = config_fields 591 592 # Initialize current config with defaults for non-locked fields 593 _current_config = {} 594 for field_name, field_info in config_fields.items(): 595 if not field_info.get("locked", False): 596 _current_config[field_name] = field_info.get("default") 597 598 # Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps 599 timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 600 _current_config["wandb_name"] = f"{name}-{timestamp}" 601 602 return json.dumps({ 603 "message": f"Selected environment: {name}", 604 "environment": name, 605 "file_path": env_info.file_path, 606 }, indent=2) 607 608 609 # ============================================================================ 610 # Configuration Tools 611 # ============================================================================ 612 613 async def rl_get_current_config() -> str: 614 """ 615 Get the current environment configuration. 616 617 Returns all configurable fields for the selected environment. 618 Each environment may have different configuration options. 619 620 Fields are divided into: 621 - configurable_fields: Can be changed with rl_edit_config() 622 - locked_fields: Infrastructure settings that cannot be changed 623 624 Returns: 625 JSON string with configurable and locked fields 626 """ 627 if not _current_env: 628 return json.dumps({ 629 "error": "No environment selected. Use rl_select_environment(name) first.", 630 }, indent=2) 631 632 config_fields = _env_config_cache.get(_current_env, {}) 633 634 configurable = [] 635 locked = [] 636 637 for field_name, field_info in config_fields.items(): 638 field_data = { 639 "name": field_name, 640 "type": field_info.get("type", "unknown"), 641 "default": field_info.get("default"), 642 "description": field_info.get("description", ""), 643 "current_value": _current_config.get(field_name, field_info.get("default")), 644 } 645 646 if field_info.get("locked", False): 647 field_data["locked_value"] = LOCKED_FIELDS.get("env", {}).get(field_name) 648 locked.append(field_data) 649 else: 650 configurable.append(field_data) 651 652 return json.dumps({ 653 "environment": _current_env, 654 "configurable_fields": configurable, 655 "locked_fields": locked, 656 "tip": "Use rl_edit_config(field, value) to change any configurable field.", 657 }, indent=2) 658 659 660 async def rl_edit_config(field: str, value: Any) -> str: 661 """ 662 Update a configuration field. 663 664 Use rl_get_current_config() first to see available fields for the 665 selected environment. Each environment has different options. 666 667 Locked fields (infrastructure settings) cannot be changed. 668 669 Args: 670 field: Name of the field to update (from rl_get_current_config) 671 value: New value for the field 672 673 Returns: 674 JSON string with updated config or error message 675 """ 676 if not _current_env: 677 return json.dumps({ 678 "error": "No environment selected. Use rl_select_environment(name) first.", 679 }, indent=2) 680 681 config_fields = _env_config_cache.get(_current_env, {}) 682 683 if field not in config_fields: 684 return json.dumps({ 685 "error": f"Unknown field '{field}'", 686 "available_fields": list(config_fields.keys()), 687 }, indent=2) 688 689 field_info = config_fields[field] 690 if field_info.get("locked", False): 691 return json.dumps({ 692 "error": f"Field '{field}' is locked and cannot be changed", 693 "locked_value": LOCKED_FIELDS.get("env", {}).get(field), 694 }, indent=2) 695 696 _current_config[field] = value 697 698 return json.dumps({ 699 "message": f"Updated {field} = {value}", 700 "field": field, 701 "value": value, 702 "config": _current_config, 703 }, indent=2) 704 705 706 # ============================================================================ 707 # Training Management Tools 708 # ============================================================================ 709 710 async def rl_start_training() -> str: 711 """ 712 Start a new RL training run with the current environment and config. 713 714 Requires an environment to be selected first using rl_select_environment(). 715 Use rl_edit_config() to adjust configuration before starting. 716 717 This spawns three processes: 718 1. run-api (Atropos trajectory API) 719 2. launch_training.py (Tinker trainer + inference server) 720 3. environment.py serve (the selected environment) 721 722 WARNING: Training runs take hours. Use rl_check_status() to monitor 723 progress (recommended: check every 30 minutes at most). 724 725 Returns: 726 JSON string with run_id and initial status 727 """ 728 if not _current_env: 729 return json.dumps({ 730 "error": "No environment selected. Use rl_select_environment(name) first.", 731 }, indent=2) 732 733 # Check API keys 734 if not os.getenv("TINKER_API_KEY"): 735 return json.dumps({ 736 "error": "TINKER_API_KEY not set. Add it to ~/.hermes/.env", 737 }, indent=2) 738 739 # Find environment file 740 env_info = None 741 for env in _environments: 742 if env.name == _current_env: 743 env_info = env 744 break 745 746 if not env_info or not Path(env_info.file_path).exists(): 747 return json.dumps({ 748 "error": f"Environment file not found for '{_current_env}'", 749 }, indent=2) 750 751 # Generate run ID 752 run_id = str(uuid.uuid4())[:8] 753 754 # Create config YAML 755 CONFIGS_DIR.mkdir(exist_ok=True) 756 config_path = CONFIGS_DIR / f"run_{run_id}.yaml" 757 758 # Start with locked config as base 759 import copy 760 run_config = copy.deepcopy(LOCKED_FIELDS) 761 762 if "env" not in run_config: 763 run_config["env"] = {} 764 765 # Apply configurable fields 766 for field_name, value in _current_config.items(): 767 if value is not None and value != "": 768 run_config["env"][field_name] = value 769 770 # Set WandB settings 771 wandb_project = _current_config.get("wandb_project", "atropos-tinker") 772 if "tinker" not in run_config: 773 run_config["tinker"] = {} 774 run_config["tinker"]["wandb_project"] = wandb_project 775 run_config["tinker"]["wandb_run_name"] = f"{_current_env}-{run_id}" 776 777 if "wandb_name" in _current_config and _current_config["wandb_name"]: 778 run_config["env"]["wandb_name"] = _current_config["wandb_name"] 779 780 with open(config_path, "w") as f: 781 yaml.dump(run_config, f, default_flow_style=False) 782 783 # Create run state 784 run_state = RunState( 785 run_id=run_id, 786 environment=_current_env, 787 config=_current_config.copy(), 788 status="starting", 789 wandb_project=wandb_project, 790 wandb_run_name=f"{_current_env}-{run_id}", 791 ) 792 793 _active_runs[run_id] = run_state 794 795 # Start training in background 796 asyncio.create_task(_spawn_training_run(run_state, config_path)) 797 798 return json.dumps({ 799 "run_id": run_id, 800 "status": "starting", 801 "environment": _current_env, 802 "config": _current_config, 803 "wandb_project": wandb_project, 804 "wandb_run_name": f"{_current_env}-{run_id}", 805 "config_path": str(config_path), 806 "logs": { 807 "api": str(LOGS_DIR / f"api_{run_id}.log"), 808 "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), 809 "env": str(LOGS_DIR / f"env_{run_id}.log"), 810 }, 811 "message": "Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes).", 812 }, indent=2) 813 814 815 async def rl_check_status(run_id: str) -> str: 816 """ 817 Get status and metrics for a training run. 818 819 RATE LIMITED: For long-running training, this function enforces a 820 minimum 30-minute interval between checks for the same run_id. 821 822 Args: 823 run_id: The run ID returned by rl_start_training() 824 825 Returns: 826 JSON string with run status and metrics 827 """ 828 # Check rate limiting 829 now = time.time() 830 if run_id in _last_status_check: 831 elapsed = now - _last_status_check[run_id] 832 if elapsed < MIN_STATUS_CHECK_INTERVAL: 833 remaining = MIN_STATUS_CHECK_INTERVAL - elapsed 834 return json.dumps({ 835 "rate_limited": True, 836 "run_id": run_id, 837 "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.", 838 "next_check_in_seconds": remaining, 839 }, indent=2) 840 841 _last_status_check[run_id] = now 842 843 if run_id not in _active_runs: 844 return json.dumps({ 845 "error": f"Run '{run_id}' not found", 846 "active_runs": list(_active_runs.keys()), 847 }, indent=2) 848 849 run_state = _active_runs[run_id] 850 851 # Check process status 852 processes = { 853 "api": run_state.api_process.poll() if run_state.api_process else None, 854 "trainer": run_state.trainer_process.poll() if run_state.trainer_process else None, 855 "env": run_state.env_process.poll() if run_state.env_process else None, 856 } 857 858 running_time = time.time() - run_state.start_time if run_state.start_time else 0 859 860 result = { 861 "run_id": run_id, 862 "status": run_state.status, 863 "environment": run_state.environment, 864 "running_time_minutes": running_time / 60, 865 "processes": { 866 name: "running" if code is None else f"exited ({code})" 867 for name, code in processes.items() 868 }, 869 "wandb_project": run_state.wandb_project, 870 "wandb_run_name": run_state.wandb_run_name, 871 "logs": { 872 "api": str(LOGS_DIR / f"api_{run_id}.log"), 873 "trainer": str(LOGS_DIR / f"trainer_{run_id}.log"), 874 "env": str(LOGS_DIR / f"env_{run_id}.log"), 875 }, 876 } 877 878 if run_state.error_message: 879 result["error"] = run_state.error_message 880 881 # Try to get WandB metrics if available 882 try: 883 import wandb 884 api = wandb.Api() 885 runs = api.runs( 886 f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", 887 filters={"display_name": run_state.wandb_run_name} 888 ) 889 if runs: 890 wandb_run = runs[0] 891 result["wandb_url"] = wandb_run.url 892 result["metrics"] = { 893 "step": wandb_run.summary.get("_step", 0), 894 "reward_mean": wandb_run.summary.get("train/reward_mean"), 895 "percent_correct": wandb_run.summary.get("train/percent_correct"), 896 "eval_percent_correct": wandb_run.summary.get("eval/percent_correct"), 897 } 898 except Exception as e: 899 result["wandb_error"] = str(e) 900 901 return json.dumps(result, indent=2) 902 903 904 async def rl_stop_training(run_id: str) -> str: 905 """ 906 Stop a running training job. 907 908 Args: 909 run_id: The run ID to stop 910 911 Returns: 912 JSON string with stop confirmation 913 """ 914 if run_id not in _active_runs: 915 return json.dumps({ 916 "error": f"Run '{run_id}' not found", 917 "active_runs": list(_active_runs.keys()), 918 }, indent=2) 919 920 run_state = _active_runs[run_id] 921 922 if run_state.status not in ("running", "starting"): 923 return json.dumps({ 924 "message": f"Run '{run_id}' is not running (status: {run_state.status})", 925 }, indent=2) 926 927 _stop_training_run(run_state) 928 929 return json.dumps({ 930 "message": f"Stopped training run '{run_id}'", 931 "run_id": run_id, 932 "status": run_state.status, 933 }, indent=2) 934 935 936 async def rl_get_results(run_id: str) -> str: 937 """ 938 Get final results and metrics for a training run. 939 940 Args: 941 run_id: The run ID to get results for 942 943 Returns: 944 JSON string with final results 945 """ 946 if run_id not in _active_runs: 947 return json.dumps({ 948 "error": f"Run '{run_id}' not found", 949 }, indent=2) 950 951 run_state = _active_runs[run_id] 952 953 result = { 954 "run_id": run_id, 955 "status": run_state.status, 956 "environment": run_state.environment, 957 "wandb_project": run_state.wandb_project, 958 "wandb_run_name": run_state.wandb_run_name, 959 } 960 961 # Get WandB metrics 962 try: 963 import wandb 964 api = wandb.Api() 965 runs = api.runs( 966 f"{os.getenv('WANDB_ENTITY', 'nousresearch')}/{run_state.wandb_project}", 967 filters={"display_name": run_state.wandb_run_name} 968 ) 969 if runs: 970 wandb_run = runs[0] 971 result["wandb_url"] = wandb_run.url 972 result["final_metrics"] = dict(wandb_run.summary) 973 result["history"] = [dict(row) for row in wandb_run.history(samples=10)] 974 except Exception as e: 975 result["wandb_error"] = str(e) 976 977 return json.dumps(result, indent=2) 978 979 980 async def rl_list_runs() -> str: 981 """ 982 List all training runs (active and completed). 983 984 Returns: 985 JSON string with list of runs and their status 986 """ 987 runs = [] 988 for run_id, run_state in _active_runs.items(): 989 runs.append({ 990 "run_id": run_id, 991 "environment": run_state.environment, 992 "status": run_state.status, 993 "wandb_run_name": run_state.wandb_run_name, 994 }) 995 996 return json.dumps({ 997 "runs": runs, 998 "count": len(runs), 999 }, indent=2) 1000 1001 1002 # ============================================================================ 1003 # Inference Testing (via Atropos `process` mode with OpenRouter) 1004 # ============================================================================ 1005 1006 # Test models at different scales for robustness testing 1007 # These are cheap, capable models on OpenRouter for testing parsing/scoring 1008 TEST_MODELS = [ 1009 {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"}, 1010 {"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"}, 1011 {"id": "minimax/minimax-m2.7", "name": "MiniMax M2.7", "scale": "large"}, 1012 ] 1013 1014 # Default test parameters - quick but representative 1015 DEFAULT_NUM_STEPS = 3 # Number of steps (items) to test 1016 DEFAULT_GROUP_SIZE = 16 # Completions per item (like training) 1017 1018 1019 async def rl_test_inference( 1020 num_steps: int = DEFAULT_NUM_STEPS, 1021 group_size: int = DEFAULT_GROUP_SIZE, 1022 models: Optional[List[str]] = None, 1023 ) -> str: 1024 """ 1025 Quick inference test for any environment using Atropos's `process` mode. 1026 1027 Runs a few steps of inference + scoring to validate: 1028 - Environment loads correctly 1029 - Prompt construction works 1030 - Inference parsing is robust (tested with multiple model scales) 1031 - Verifier/scoring logic works 1032 1033 Default: 3 steps × 16 completions = 48 total rollouts per model. 1034 Tests 3 models = 144 total rollouts. Quick sanity check. 1035 1036 Test models (varying intelligence levels for robustness): 1037 - qwen/qwen3-8b (small) 1038 - zhipu-ai/glm-4-flash (medium) 1039 - minimax/minimax-m1 (large) 1040 1041 Args: 1042 num_steps: Steps to run (default: 3, max recommended for testing) 1043 group_size: Completions per step (default: 16, like training) 1044 models: Optional model IDs to test. If None, uses all 3 test models. 1045 1046 Returns: 1047 JSON with results per model: steps_tested, accuracy, scores 1048 """ 1049 if not _current_env: 1050 return json.dumps({ 1051 "error": "No environment selected. Use rl_select_environment(name) first.", 1052 }, indent=2) 1053 1054 api_key = os.getenv("OPENROUTER_API_KEY") 1055 if not api_key: 1056 return json.dumps({ 1057 "error": "OPENROUTER_API_KEY not set. Required for inference testing.", 1058 }, indent=2) 1059 1060 # Find environment info 1061 env_info = None 1062 for env in _environments: 1063 if env.name == _current_env: 1064 env_info = env 1065 break 1066 1067 if not env_info: 1068 return json.dumps({ 1069 "error": f"Environment '{_current_env}' not found", 1070 }, indent=2) 1071 1072 # Determine which models to test 1073 if models: 1074 test_models = [m for m in TEST_MODELS if m["id"] in models] 1075 if not test_models: 1076 test_models = [{"id": m, "name": m, "scale": "custom"} for m in models] 1077 else: 1078 test_models = TEST_MODELS 1079 1080 # Calculate total rollouts for logging 1081 total_rollouts_per_model = num_steps * group_size 1082 total_rollouts = total_rollouts_per_model * len(test_models) 1083 1084 results = { 1085 "environment": _current_env, 1086 "environment_file": env_info.file_path, 1087 "test_config": { 1088 "num_steps": num_steps, 1089 "group_size": group_size, 1090 "rollouts_per_model": total_rollouts_per_model, 1091 "total_rollouts": total_rollouts, 1092 }, 1093 "models_tested": [], 1094 } 1095 1096 # Create output directory for test results 1097 _ensure_logs_dir() 1098 test_output_dir = LOGS_DIR / "inference_tests" 1099 test_output_dir.mkdir(exist_ok=True) 1100 1101 for model_info in test_models: 1102 model_id = model_info["id"] 1103 model_safe_name = model_id.replace("/", "_") 1104 1105 print(f"\n{'='*60}") 1106 print(f"Testing with {model_info['name']} ({model_id})") 1107 print(f"{'='*60}") 1108 1109 # Output file for this test run 1110 output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl" 1111 1112 # Generate unique run ID for wandb 1113 test_run_id = str(uuid.uuid4())[:8] 1114 wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}" 1115 1116 # Build the process command using Atropos's built-in CLI 1117 # This runs the environment's actual code with OpenRouter as the inference backend 1118 # We pass our locked settings + test-specific overrides via CLI args 1119 cmd = [ 1120 sys.executable, env_info.file_path, "process", 1121 # Test-specific overrides 1122 "--env.total_steps", str(num_steps), 1123 "--env.group_size", str(group_size), 1124 "--env.use_wandb", "true", # Enable wandb for test tracking 1125 "--env.wandb_name", wandb_run_name, 1126 "--env.data_path_to_save_groups", str(output_file), 1127 # Use locked settings from our config 1128 "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"], 1129 "--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]), 1130 "--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]), 1131 "--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]), 1132 # OpenRouter config for inference testing 1133 # IMPORTANT: Use server_type=openai for OpenRouter (not sglang) 1134 # sglang is only for actual training with Tinker's inference server 1135 "--openai.base_url", "https://openrouter.ai/api/v1", 1136 "--openai.api_key", api_key, 1137 "--openai.model_name", model_id, 1138 "--openai.server_type", "openai", # OpenRouter is OpenAI-compatible 1139 "--openai.health_check", "false", # OpenRouter doesn't have health endpoint 1140 ] 1141 1142 # Debug: Print the full command 1143 cmd_str = " ".join(str(c) for c in cmd) 1144 # Hide API key in printed output 1145 cmd_display = cmd_str.replace(api_key, "***API_KEY***") 1146 print(f"Command: {cmd_display}") 1147 print(f"Working dir: {TINKER_ATROPOS_ROOT}") 1148 print(f"WandB run: {wandb_run_name}") 1149 print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts") 1150 1151 model_results = { 1152 "model": model_id, 1153 "name": model_info["name"], 1154 "scale": model_info["scale"], 1155 "wandb_run": wandb_run_name, 1156 "output_file": str(output_file), 1157 "steps": [], 1158 "steps_tested": 0, 1159 "total_completions": 0, 1160 "correct_completions": 0, 1161 } 1162 1163 try: 1164 # Run the process command with real-time output streaming 1165 process = await asyncio.create_subprocess_exec( 1166 *cmd, 1167 stdout=asyncio.subprocess.PIPE, 1168 stderr=asyncio.subprocess.PIPE, 1169 cwd=str(TINKER_ATROPOS_ROOT), 1170 ) 1171 1172 # Stream output in real-time while collecting for logs 1173 stdout_lines = [] 1174 stderr_lines = [] 1175 log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log" 1176 1177 async def read_stream(stream, lines_list, prefix=""): 1178 """Read stream line by line and print in real-time.""" 1179 while True: 1180 line = await stream.readline() 1181 if not line: 1182 break 1183 decoded = line.decode().rstrip() 1184 lines_list.append(decoded) 1185 # Print progress-related lines in real-time 1186 if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']): 1187 print(f" {prefix}{decoded}") 1188 1189 # Read both streams concurrently with timeout 1190 try: 1191 await asyncio.wait_for( 1192 asyncio.gather( 1193 read_stream(process.stdout, stdout_lines, "📊 "), 1194 read_stream(process.stderr, stderr_lines, "⚠️ "), 1195 ), 1196 timeout=600, # 10 minute timeout per model 1197 ) 1198 except asyncio.TimeoutError: 1199 process.kill() 1200 raise 1201 1202 await process.wait() 1203 1204 # Combine output for logging 1205 stdout_text = "\n".join(stdout_lines) 1206 stderr_text = "\n".join(stderr_lines) 1207 1208 # Write logs to files for inspection outside CLI 1209 with open(log_file, "w") as f: 1210 f.write(f"Command: {cmd_display}\n") 1211 f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n") 1212 f.write(f"Return code: {process.returncode}\n") 1213 f.write(f"\n{'='*60}\n") 1214 f.write(f"STDOUT:\n{'='*60}\n") 1215 f.write(stdout_text or "(empty)\n") 1216 f.write(f"\n{'='*60}\n") 1217 f.write(f"STDERR:\n{'='*60}\n") 1218 f.write(stderr_text or "(empty)\n") 1219 1220 print(f" Log file: {log_file}") 1221 1222 if process.returncode != 0: 1223 model_results["error"] = f"Process exited with code {process.returncode}" 1224 model_results["stderr"] = stderr_text[-1000:] 1225 model_results["stdout"] = stdout_text[-1000:] 1226 model_results["log_file"] = str(log_file) 1227 print(f"\n ❌ Error: {model_results['error']}") 1228 # Print last few lines of stderr for debugging 1229 if stderr_lines: 1230 print(" Last errors:") 1231 for line in stderr_lines[-5:]: 1232 print(f" {line}") 1233 else: 1234 print("\n ✅ Process completed successfully") 1235 print(f" Output file: {output_file}") 1236 print(f" File exists: {output_file.exists()}") 1237 1238 # Parse the output JSONL file 1239 if output_file.exists(): 1240 # Read JSONL file (one JSON object per line = one step) 1241 with open(output_file, "r") as f: 1242 for line in f: 1243 line = line.strip() 1244 if not line: 1245 continue 1246 try: 1247 item = json.loads(line) 1248 scores = item.get("scores", []) 1249 model_results["steps_tested"] += 1 1250 model_results["total_completions"] += len(scores) 1251 correct = sum(1 for s in scores if s > 0) 1252 model_results["correct_completions"] += correct 1253 1254 model_results["steps"].append({ 1255 "step": model_results["steps_tested"], 1256 "completions": len(scores), 1257 "correct": correct, 1258 "scores": scores, 1259 }) 1260 except json.JSONDecodeError: 1261 continue 1262 1263 print(f" Completed {model_results['steps_tested']} steps") 1264 else: 1265 model_results["error"] = f"Output file not created: {output_file}" 1266 1267 except asyncio.TimeoutError: 1268 model_results["error"] = "Process timed out after 10 minutes" 1269 print(" Timeout!") 1270 except Exception as e: 1271 model_results["error"] = str(e) 1272 print(f" Error: {e}") 1273 1274 # Calculate stats 1275 if model_results["total_completions"] > 0: 1276 model_results["accuracy"] = round( 1277 model_results["correct_completions"] / model_results["total_completions"], 3 1278 ) 1279 else: 1280 model_results["accuracy"] = 0 1281 1282 if model_results["steps_tested"] > 0: 1283 steps_with_correct = sum(1 for s in model_results["steps"] if s.get("correct", 0) > 0) 1284 model_results["steps_with_correct"] = steps_with_correct 1285 model_results["step_success_rate"] = round( 1286 steps_with_correct / model_results["steps_tested"], 3 1287 ) 1288 else: 1289 model_results["steps_with_correct"] = 0 1290 model_results["step_success_rate"] = 0 1291 1292 print(f" Results: {model_results['correct_completions']}/{model_results['total_completions']} correct") 1293 print(f" Accuracy: {model_results['accuracy']:.1%}") 1294 1295 results["models_tested"].append(model_results) 1296 1297 # Overall summary 1298 working_models = [m for m in results["models_tested"] if m.get("steps_tested", 0) > 0] 1299 1300 results["summary"] = { 1301 "steps_requested": num_steps, 1302 "models_tested": len(test_models), 1303 "models_succeeded": len(working_models), 1304 "best_model": max(working_models, key=lambda x: x.get("accuracy", 0))["model"] if working_models else None, 1305 "avg_accuracy": round( 1306 sum(m.get("accuracy", 0) for m in working_models) / len(working_models), 3 1307 ) if working_models else 0, 1308 "environment_working": bool(working_models), 1309 "output_directory": str(test_output_dir), 1310 } 1311 1312 return json.dumps(results, indent=2) 1313 1314 1315 # ============================================================================ 1316 # Requirements Check 1317 # ============================================================================ 1318 1319 def check_rl_python_version() -> bool: 1320 """ 1321 Check if Python version meets the minimum for RL tools. 1322 1323 tinker-atropos depends on the 'tinker' package which requires Python >= 3.11. 1324 """ 1325 return sys.version_info >= (3, 11) 1326 1327 1328 def check_rl_api_keys() -> bool: 1329 """ 1330 Check if required API keys and Python version are available. 1331 1332 RL training requires: 1333 - Python >= 3.11 (tinker package requirement) 1334 - TINKER_API_KEY for the Tinker training API 1335 - WANDB_API_KEY for Weights & Biases metrics 1336 """ 1337 if not check_rl_python_version(): 1338 return False 1339 tinker_key = os.getenv("TINKER_API_KEY") 1340 wandb_key = os.getenv("WANDB_API_KEY") 1341 return bool(tinker_key) and bool(wandb_key) 1342 1343 1344 def get_missing_keys() -> List[str]: 1345 """ 1346 Get list of missing requirements for RL tools (API keys and Python version). 1347 """ 1348 missing = [] 1349 if not check_rl_python_version(): 1350 missing.append(f"Python >= 3.11 (current: {sys.version_info.major}.{sys.version_info.minor})") 1351 if not os.getenv("TINKER_API_KEY"): 1352 missing.append("TINKER_API_KEY") 1353 if not os.getenv("WANDB_API_KEY"): 1354 missing.append("WANDB_API_KEY") 1355 return missing 1356 1357 1358 # --------------------------------------------------------------------------- 1359 # Schemas + Registry 1360 # --------------------------------------------------------------------------- 1361 from tools.registry import registry 1362 1363 RL_LIST_ENVIRONMENTS_SCHEMA = {"name": "rl_list_environments", "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", "parameters": {"type": "object", "properties": {}, "required": []}} 1364 RL_SELECT_ENVIRONMENT_SCHEMA = {"name": "rl_select_environment", "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", "parameters": {"type": "object", "properties": {"name": {"type": "string", "description": "Name of the environment to select (from rl_list_environments)"}}, "required": ["name"]}} 1365 RL_GET_CURRENT_CONFIG_SCHEMA = {"name": "rl_get_current_config", "description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.", "parameters": {"type": "object", "properties": {}, "required": []}} 1366 RL_EDIT_CONFIG_SCHEMA = {"name": "rl_edit_config", "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.", "parameters": {"type": "object", "properties": {"field": {"type": "string", "description": "Name of the field to update (get available fields from rl_get_current_config)"}, "value": {"description": "New value for the field"}}, "required": ["field", "value"]}} 1367 RL_START_TRAINING_SCHEMA = {"name": "rl_start_training", "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.", "parameters": {"type": "object", "properties": {}, "required": []}} 1368 RL_CHECK_STATUS_SCHEMA = {"name": "rl_check_status", "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID from rl_start_training()"}}, "required": ["run_id"]}} 1369 RL_STOP_TRAINING_SCHEMA = {"name": "rl_stop_training", "description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to stop"}}, "required": ["run_id"]}} 1370 RL_GET_RESULTS_SCHEMA = {"name": "rl_get_results", "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", "parameters": {"type": "object", "properties": {"run_id": {"type": "string", "description": "The run ID to get results for"}}, "required": ["run_id"]}} 1371 RL_LIST_RUNS_SCHEMA = {"name": "rl_list_runs", "description": "List all training runs (active and completed) with their status.", "parameters": {"type": "object", "properties": {}, "required": []}} 1372 RL_TEST_INFERENCE_SCHEMA = {"name": "rl_test_inference", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.", "parameters": {"type": "object", "properties": {"num_steps": {"type": "integer", "description": "Number of steps to run (default: 3, recommended max for testing)", "default": 3}, "group_size": {"type": "integer", "description": "Completions per step (default: 16, like training)", "default": 16}, "models": {"type": "array", "items": {"type": "string"}, "description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.7"}}, "required": []}} 1373 1374 _rl_env = ["TINKER_API_KEY", "WANDB_API_KEY"] 1375 1376 registry.register(name="rl_list_environments", emoji="🧪", toolset="rl", schema=RL_LIST_ENVIRONMENTS_SCHEMA, 1377 handler=lambda args, **kw: rl_list_environments(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1378 registry.register(name="rl_select_environment", emoji="🧪", toolset="rl", schema=RL_SELECT_ENVIRONMENT_SCHEMA, 1379 handler=lambda args, **kw: rl_select_environment(name=args.get("name", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1380 registry.register(name="rl_get_current_config", emoji="🧪", toolset="rl", schema=RL_GET_CURRENT_CONFIG_SCHEMA, 1381 handler=lambda args, **kw: rl_get_current_config(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1382 registry.register(name="rl_edit_config", emoji="🧪", toolset="rl", schema=RL_EDIT_CONFIG_SCHEMA, 1383 handler=lambda args, **kw: rl_edit_config(field=args.get("field", ""), value=args.get("value")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1384 registry.register(name="rl_start_training", emoji="🧪", toolset="rl", schema=RL_START_TRAINING_SCHEMA, 1385 handler=lambda args, **kw: rl_start_training(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1386 registry.register(name="rl_check_status", emoji="🧪", toolset="rl", schema=RL_CHECK_STATUS_SCHEMA, 1387 handler=lambda args, **kw: rl_check_status(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1388 registry.register(name="rl_stop_training", emoji="🧪", toolset="rl", schema=RL_STOP_TRAINING_SCHEMA, 1389 handler=lambda args, **kw: rl_stop_training(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1390 registry.register(name="rl_get_results", emoji="🧪", toolset="rl", schema=RL_GET_RESULTS_SCHEMA, 1391 handler=lambda args, **kw: rl_get_results(run_id=args.get("run_id", "")), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1392 registry.register(name="rl_list_runs", emoji="🧪", toolset="rl", schema=RL_LIST_RUNS_SCHEMA, 1393 handler=lambda args, **kw: rl_list_runs(), check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True) 1394 registry.register(name="rl_test_inference", emoji="🧪", toolset="rl", schema=RL_TEST_INFERENCE_SCHEMA, 1395 handler=lambda args, **kw: rl_test_inference(num_steps=args.get("num_steps", 3), group_size=args.get("group_size", 16), models=args.get("models")), 1396 check_fn=check_rl_api_keys, requires_env=_rl_env, is_async=True)