/ environments / hermes_base_env.py
hermes_base_env.py
1 """ 2 HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos 3 4 Provides the Atropos integration plumbing that all hermes-agent environments share: 5 - Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2) 6 - Per-group toolset/distribution resolution 7 - Agent loop orchestration via HermesAgentLoop 8 - ToolContext creation for reward functions 9 - ScoredDataGroup construction from ManagedServer state 10 11 Subclasses only need to implement: 12 setup() -- Load dataset, initialize state 13 get_next_item() -- Return the next item from the dataset 14 format_prompt() -- Convert a dataset item into the user message 15 compute_reward() -- Score the rollout (has full ToolContext access) 16 evaluate() -- Periodic evaluation 17 """ 18 19 import asyncio 20 import json 21 import logging 22 import os 23 import sys 24 import uuid 25 from abc import abstractmethod 26 from pathlib import Path 27 from typing import Any, Dict, List, Optional, Set, Tuple, Union 28 29 # Ensure the hermes-agent repo root is on sys.path so that imports like 30 # `from model_tools import ...` and `from environments.X import ...` work 31 # regardless of where the script is invoked from. 32 _repo_root = Path(__file__).resolve().parent.parent 33 if str(_repo_root) not in sys.path: 34 sys.path.insert(0, str(_repo_root)) 35 36 from dotenv import load_dotenv 37 from pydantic import Field 38 39 # Load API keys from hermes-agent/.env so all environments can access them 40 _env_path = _repo_root / ".env" 41 if _env_path.exists(): 42 load_dotenv(dotenv_path=_env_path) 43 44 # Apply monkey patches for async-safe tool operation inside Atropos's event loop. 45 # This patches SwerexModalEnvironment to use a background thread instead of 46 # asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too. 47 from environments.patches import apply_patches 48 apply_patches() 49 50 from atroposlib.envs.base import ( 51 BaseEnv, 52 BaseEnvConfig, 53 ScoredDataGroup, 54 ScoredDataItem, 55 ) 56 from atroposlib.envs.server_handling.server_manager import ( 57 APIServerConfig, 58 ServerBaseline, 59 ServerManager, 60 ) 61 from atroposlib.type_definitions import Item 62 63 from environments.agent_loop import AgentResult, HermesAgentLoop 64 from environments.tool_context import ToolContext 65 from tools.budget_config import ( 66 DEFAULT_RESULT_SIZE_CHARS, 67 DEFAULT_TURN_BUDGET_CHARS, 68 DEFAULT_PREVIEW_SIZE_CHARS, 69 ) 70 71 # Import hermes-agent toolset infrastructure 72 from model_tools import get_tool_definitions 73 from toolset_distributions import sample_toolsets_from_distribution 74 75 logger = logging.getLogger(__name__) 76 77 78 class HermesAgentEnvConfig(BaseEnvConfig): 79 """ 80 Configuration for hermes-agent Atropos environments. 81 82 Extends BaseEnvConfig with agent-specific settings for toolsets, 83 terminal backend, dataset loading, and tool call parsing. 84 """ 85 86 # --- Toolset configuration --- 87 # Mutually exclusive: use either enabled_toolsets OR distribution 88 enabled_toolsets: Optional[List[str]] = Field( 89 default=None, 90 description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). " 91 "If None and distribution is also None, all available toolsets are enabled.", 92 ) 93 disabled_toolsets: Optional[List[str]] = Field( 94 default=None, 95 description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.", 96 ) 97 distribution: Optional[str] = Field( 98 default=None, 99 description="Name of a toolset distribution from toolset_distributions.py " 100 "(e.g., 'development', 'terminal_tasks'). Sampled once per group. " 101 "Mutually exclusive with enabled_toolsets.", 102 ) 103 104 # --- Agent loop configuration --- 105 max_agent_turns: int = Field( 106 default=30, 107 description="Maximum number of LLM calls (tool-calling iterations) per rollout.", 108 ) 109 system_prompt: Optional[str] = Field( 110 default=None, 111 description="System prompt for the agent. Tools are handled via the tools= parameter, " 112 "not embedded in the prompt text.", 113 ) 114 agent_temperature: float = Field( 115 default=1.0, 116 description="Sampling temperature for agent generation during rollouts.", 117 ) 118 119 # --- Terminal backend --- 120 terminal_backend: str = Field( 121 default="local", 122 description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. " 123 "Modal or Daytona recommended for production RL (cloud isolation per rollout).", 124 ) 125 terminal_timeout: int = Field( 126 default=120, 127 description="Per-command timeout in seconds for terminal tool calls. " 128 "Commands exceeding this are killed. Increase for tasks with long-running " 129 "commands (compilation, pip install, etc.).", 130 ) 131 terminal_lifetime: int = Field( 132 default=3600, 133 description="Sandbox inactivity lifetime in seconds. The cleanup thread kills " 134 "sandboxes that have been idle longer than this. Must be longer than " 135 "the longest gap between tool calls (e.g., waiting for LLM response).", 136 ) 137 138 # --- Dataset --- 139 dataset_name: Optional[str] = Field( 140 default=None, 141 description="HuggingFace dataset name. Optional if tasks are defined inline.", 142 ) 143 dataset_split: str = Field( 144 default="train", 145 description="Dataset split to use.", 146 ) 147 prompt_field: str = Field( 148 default="prompt", 149 description="Which field in the dataset contains the prompt.", 150 ) 151 152 # --- Thread pool --- 153 tool_pool_size: int = Field( 154 default=128, 155 description="Thread pool size for tool execution. Each concurrent task needs a " 156 "thread for tool calls. Must be large enough for parallel evaluation. " 157 "Too small = thread pool starvation.", 158 ) 159 160 # --- Phase 2: Tool call parsing --- 161 tool_call_parser: str = Field( 162 default="hermes", 163 description="Tool call parser name for Phase 2 (VLLM server type). " 164 "Ignored in Phase 1 (OpenAI server type where VLLM parses natively). " 165 "Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.", 166 ) 167 168 # --- Tool result budget --- 169 # Defaults imported from tools.budget_config (single source of truth). 170 default_result_size_chars: int = Field( 171 default=DEFAULT_RESULT_SIZE_CHARS, 172 description="Default per-tool threshold (chars) for persisting large results " 173 "to sandbox. Results exceeding this are written to /tmp/hermes-results/ " 174 "and replaced with a preview. Per-tool registry values take precedence " 175 "unless overridden via tool_result_overrides.", 176 ) 177 turn_budget_chars: int = Field( 178 default=DEFAULT_TURN_BUDGET_CHARS, 179 description="Aggregate char budget per assistant turn. If all tool results " 180 "in a single turn exceed this, the largest are persisted to disk first.", 181 ) 182 preview_size_chars: int = Field( 183 default=DEFAULT_PREVIEW_SIZE_CHARS, 184 description="Size of the inline preview shown after a tool result is persisted.", 185 ) 186 tool_result_overrides: Optional[Dict[str, int]] = Field( 187 default=None, 188 description="Per-tool threshold overrides (chars). Keys are tool names, " 189 "values are char thresholds. Overrides both the default and registry " 190 "per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. " 191 "Note: read_file is pinned to infinity and cannot be overridden.", 192 ) 193 194 # --- Provider-specific parameters --- 195 # Passed as extra_body to the OpenAI client's chat.completions.create() call. 196 # Useful for OpenRouter provider preferences, transforms, route settings, etc. 197 # Example YAML: 198 # extra_body: 199 # provider: 200 # ignore: ["DeepInfra", "Fireworks"] 201 # order: ["Together"] 202 # transforms: ["middle-out"] 203 extra_body: Optional[Dict[str, Any]] = Field( 204 default=None, 205 description="Extra body parameters passed to the OpenAI client's " 206 "chat.completions.create(). Used for OpenRouter provider preferences, " 207 "transforms, and other provider-specific settings.", 208 ) 209 210 def build_budget_config(self): 211 """Build a BudgetConfig from env config fields.""" 212 from tools.budget_config import BudgetConfig 213 return BudgetConfig( 214 default_result_size=self.default_result_size_chars, 215 turn_budget=self.turn_budget_chars, 216 preview_size=self.preview_size_chars, 217 tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {}, 218 ) 219 220 221 class HermesAgentBaseEnv(BaseEnv): 222 """ 223 Abstract base environment for hermes-agent Atropos integration. 224 225 Handles two modes of operation: 226 - Phase 1 (OpenAI server type): Uses server.chat_completion() directly. 227 The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing 228 and reasoning extraction natively. DummyManagedServer provides placeholder 229 tokens. Good for SFT data gen, verifier testing, evaluation. 230 231 - Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs 232 via /generate. Client-side tool call parser reconstructs structured tool_calls 233 from raw output. Full RL training capability. 234 235 Subclasses must implement: 236 setup() -- Load dataset, initialize state 237 get_next_item() -- Return the next item to roll out 238 format_prompt() -- Convert a dataset item into the user message string 239 compute_reward() -- Score the rollout using ToolContext 240 evaluate() -- Periodic evaluation 241 """ 242 243 name: Optional[str] = "hermes-agent" 244 env_config_cls = HermesAgentEnvConfig 245 246 def __init__( 247 self, 248 config: HermesAgentEnvConfig, 249 server_configs: Union[ServerBaseline, List[APIServerConfig]], 250 slurm=False, 251 testing=False, 252 ): 253 super().__init__(config, server_configs, slurm, testing) 254 255 # Set terminal environment variables so hermes tools pick them up. 256 # These can all be overridden per-environment via config fields instead 257 # of requiring users to set shell env vars. 258 if config.terminal_backend: 259 os.environ["TERMINAL_ENV"] = config.terminal_backend 260 os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout) 261 os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime) 262 print( 263 f"🖥️ Terminal: backend={config.terminal_backend}, " 264 f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s" 265 ) 266 267 # Resize the agent loop's thread pool for tool execution. 268 # This must be large enough for the number of concurrent tasks 269 # (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls). 270 from environments.agent_loop import resize_tool_pool 271 resize_tool_pool(config.tool_pool_size) 272 273 # Set tool_parser on the ServerManager so ManagedServer uses it 274 # for bidirectional tool call translation (raw text ↔ OpenAI tool_calls). 275 if hasattr(self.server, 'tool_parser'): 276 self.server.tool_parser = config.tool_call_parser 277 print(f"🔧 Tool parser: {config.tool_call_parser}") 278 279 # Current group's resolved tools (set in collect_trajectories) 280 self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None 281 282 # Tool error tracking for wandb logging 283 self._tool_error_buffer: List[Dict[str, Any]] = [] 284 285 # ========================================================================= 286 # Toolset resolution (per-group) 287 # ========================================================================= 288 289 def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]: 290 """ 291 Resolve toolsets for a group. Called once in collect_trajectories(), 292 then shared by all collect_trajectory() calls in the group. 293 294 If distribution is set, samples probabilistically. 295 If enabled_toolsets is set, uses that explicit list. 296 disabled_toolsets is applied as a filter on top. 297 298 Returns: 299 (tool_schemas, valid_tool_names) tuple 300 """ 301 config = self.config 302 303 if config.distribution: 304 group_toolsets = sample_toolsets_from_distribution(config.distribution) 305 logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets) 306 else: 307 group_toolsets = config.enabled_toolsets # None means "all available" 308 if group_toolsets is None: 309 logger.warning( 310 "enabled_toolsets is None -- loading ALL tools including messaging. " 311 "Set explicit enabled_toolsets for RL training." 312 ) 313 314 tools = get_tool_definitions( 315 enabled_toolsets=group_toolsets, 316 disabled_toolsets=config.disabled_toolsets, 317 quiet_mode=True, 318 ) 319 320 valid_names = {t["function"]["name"] for t in tools} if tools else set() 321 logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names)) 322 return tools, valid_names 323 324 # ========================================================================= 325 # Server mode detection 326 # ========================================================================= 327 328 def _use_managed_server(self) -> bool: 329 """ 330 Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1). 331 332 Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang', 333 which go through the /generate endpoint for exact token tracking. 334 335 Phase 1 (direct server) is used for 'openai' server type, which uses 336 /v1/chat/completions with native tool call parsing. 337 """ 338 if not self.server.servers: 339 return False 340 341 server = self.server.servers[0] 342 # If the server is an OpenAI server (not VLLM/SGLang), use direct mode 343 from atroposlib.envs.server_handling.openai_server import OpenAIServer 344 return not isinstance(server, OpenAIServer) 345 346 # ========================================================================= 347 # Core Atropos integration 348 # ========================================================================= 349 350 async def collect_trajectories( 351 self, item: Item 352 ) -> Tuple[ 353 Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]], 354 List[Item], 355 ]: 356 """ 357 Override collect_trajectories to resolve toolsets once per group, 358 then delegate to the standard group-level collection. 359 360 The default BaseEnv.collect_trajectories() calls collect_trajectory() 361 group_size times in parallel. We resolve tools once here and store 362 them for all those calls to use. 363 """ 364 # Resolve toolsets for this group (shared by all rollouts in the group) 365 self._current_group_tools = self._resolve_tools_for_group() 366 367 # Delegate to the default implementation which calls collect_trajectory() 368 # group_size times via asyncio.gather 369 return await super().collect_trajectories(item) 370 371 # ========================================================================= 372 # Wandb rollout display -- format trajectories nicely 373 # ========================================================================= 374 375 @staticmethod 376 def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str: 377 """ 378 Format a conversation's messages into a readable trajectory string 379 for wandb rollout tables. Shows tool calls, tool results, and reasoning 380 in a structured way instead of raw token decoding. 381 """ 382 parts = [] 383 for msg in messages: 384 role = msg.get("role", "unknown") 385 content = msg.get("content", "") 386 387 if role == "system": 388 parts.append(f"[SYSTEM]\n{content}") 389 390 elif role == "user": 391 parts.append(f"[USER]\n{content}") 392 393 elif role == "assistant": 394 # Show reasoning if present 395 reasoning = msg.get("reasoning_content", "") 396 if reasoning: 397 # Truncate long reasoning for display 398 if len(reasoning) > 300: 399 reasoning = reasoning[:300] + "..." 400 parts.append(f"[ASSISTANT thinking]\n{reasoning}") 401 402 # Show content 403 if content: 404 parts.append(f"[ASSISTANT]\n{content}") 405 406 # Show tool calls 407 tool_calls = msg.get("tool_calls", []) 408 for tc in tool_calls: 409 func = tc.get("function", {}) 410 name = func.get("name", "?") 411 args = func.get("arguments", "{}") 412 # Truncate long arguments for display 413 if len(args) > 200: 414 args = args[:200] + "..." 415 parts.append(f"[TOOL CALL] {name}({args})") 416 417 elif role == "tool": 418 tool_id = msg.get("tool_call_id", "") 419 result = content 420 # Truncate long tool results for display 421 if len(result) > 500: 422 result = result[:500] + "..." 423 parts.append(f"[TOOL RESULT] {result}") 424 425 return "\n\n".join(parts) 426 427 async def add_rollouts_for_wandb( 428 self, 429 scored_data, 430 item=None, 431 ): 432 """ 433 Override to show formatted trajectories with tool calls visible, 434 instead of raw token decoding which loses all structure. 435 """ 436 num_keep = self.config.num_rollouts_per_group_for_logging 437 if num_keep == -1: 438 num_keep = self.config.group_size 439 440 group = [] 441 for i in range(min(num_keep, len(scored_data.get("scores", [])))): 442 score = scored_data["scores"][i] 443 444 # Use messages if available for rich display 445 messages = None 446 if scored_data.get("messages") and i < len(scored_data["messages"]): 447 messages = scored_data["messages"][i] 448 449 if messages: 450 text = self._format_trajectory_for_display(messages) 451 elif scored_data.get("tokens") and i < len(scored_data["tokens"]): 452 text = self.tokenizer.decode(scored_data["tokens"][i]) 453 else: 454 text = "(no data)" 455 456 group.append((text, score)) 457 458 self.rollouts_for_wandb.append(group) 459 if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: 460 self.rollouts_for_wandb.pop(0) 461 462 async def wandb_log(self, wandb_metrics: Optional[Dict] = None): 463 """Log base metrics including tool errors to wandb.""" 464 if wandb_metrics is None: 465 wandb_metrics = {} 466 467 # Log tool error stats 468 if self._tool_error_buffer: 469 wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer) 470 471 # Log error details as a summary string (tables can crash wandb on tmp cleanup) 472 error_summaries = [] 473 for err in self._tool_error_buffer: 474 error_summaries.append( 475 f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}" 476 ) 477 wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries) 478 479 # Also print to stdout for immediate visibility 480 for summary in error_summaries: 481 print(f" Tool Error: {summary}") 482 483 self._tool_error_buffer = [] 484 else: 485 wandb_metrics["train/tool_errors_count"] = 0 486 487 await super().wandb_log(wandb_metrics) 488 489 async def collect_trajectory( 490 self, item: Item 491 ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]: 492 """ 493 Run a single rollout: agent loop + reward computation. 494 495 This is called group_size times in parallel by collect_trajectories(). 496 Each call gets its own task_id for terminal/browser session isolation. 497 """ 498 task_id = str(uuid.uuid4()) 499 500 # Get group-level tools (resolved once in collect_trajectories) 501 if self._current_group_tools is None: 502 # Fallback: resolve per-trajectory if called outside collect_trajectories 503 tools, valid_names = self._resolve_tools_for_group() 504 else: 505 tools, valid_names = self._current_group_tools 506 507 # Build initial messages 508 messages: List[Dict[str, Any]] = [] 509 if self.config.system_prompt: 510 messages.append({"role": "system", "content": self.config.system_prompt}) 511 messages.append({"role": "user", "content": self.format_prompt(item)}) 512 513 # Run the agent loop 514 result: AgentResult 515 if self._use_managed_server(): 516 # Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs 517 # tool_parser is set on ServerManager in __init__ and passed through 518 # to ManagedServer, which uses ToolCallTranslator for bidirectional 519 # translation between raw text and OpenAI tool_calls. 520 try: 521 async with self.server.managed_server( 522 tokenizer=self.tokenizer, 523 preserve_think_blocks=bool(self.config.thinking_mode), 524 ) as managed: 525 agent = HermesAgentLoop( 526 server=managed, 527 tool_schemas=tools, 528 valid_tool_names=valid_names, 529 max_turns=self.config.max_agent_turns, 530 task_id=task_id, 531 temperature=self.config.agent_temperature, 532 max_tokens=self.config.max_token_length, 533 extra_body=self.config.extra_body, 534 budget_config=self.config.build_budget_config(), 535 ) 536 result = await agent.run(messages) 537 except NotImplementedError: 538 # DummyManagedServer not allowed -- fall back to Phase 1 539 logger.warning( 540 "ManagedServer not available (OpenAI server?). " 541 "Falling back to direct server mode." 542 ) 543 agent = HermesAgentLoop( 544 server=self.server, 545 tool_schemas=tools, 546 valid_tool_names=valid_names, 547 max_turns=self.config.max_agent_turns, 548 task_id=task_id, 549 temperature=self.config.agent_temperature, 550 max_tokens=self.config.max_token_length, 551 extra_body=self.config.extra_body, 552 budget_config=self.config.build_budget_config(), 553 ) 554 result = await agent.run(messages) 555 else: 556 # Phase 1: OpenAI server -- native tool_calls, placeholder tokens 557 agent = HermesAgentLoop( 558 server=self.server, 559 tool_schemas=tools, 560 valid_tool_names=valid_names, 561 max_turns=self.config.max_agent_turns, 562 task_id=task_id, 563 temperature=self.config.agent_temperature, 564 max_tokens=self.config.max_token_length, 565 extra_body=self.config.extra_body, 566 budget_config=self.config.build_budget_config(), 567 ) 568 result = await agent.run(messages) 569 570 # Skip reward computation if the agent loop produced no meaningful work 571 # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox 572 # just to verify files that were never created. 573 only_system_and_user = all( 574 msg.get("role") in ("system", "user") for msg in result.messages 575 ) 576 if result.turns_used == 0 or only_system_and_user: 577 logger.warning( 578 "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.", 579 result.turns_used, len(result.messages), 580 ) 581 reward = 0.0 582 else: 583 # Compute reward using ToolContext (gives verifier full tool access) 584 ctx = ToolContext(task_id) 585 try: 586 reward = await self.compute_reward(item, result, ctx) 587 except Exception as e: 588 logger.error("compute_reward failed: %s", e) 589 reward = 0.0 590 finally: 591 ctx.cleanup() 592 593 # Track tool errors for wandb logging 594 if result.tool_errors: 595 for err in result.tool_errors: 596 self._tool_error_buffer.append({ 597 "turn": err.turn, 598 "tool": err.tool_name, 599 "args": err.arguments[:150], 600 "error": err.error[:300], 601 "result": err.tool_result[:300], 602 }) 603 604 # Build ScoredDataItem from ManagedServer state 605 # Phase 2: real tokens/masks/logprobs from SequenceNodes 606 # Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline) 607 nodes = (result.managed_state or {}).get("nodes", []) 608 609 if nodes: 610 # Phase 2 (or DummyManagedServer): use actual node data 611 node = nodes[-1] # Final sequence node = full trajectory 612 scored_item: Dict[str, Any] = { 613 "tokens": node.tokens, 614 "masks": node.masked_tokens, 615 "scores": reward, 616 } 617 618 # Include logprobs if available (Phase 2) 619 if hasattr(node, "logprobs") and node.logprobs: 620 scored_item["advantages"] = None # Computed by trainer 621 scored_item["ref_logprobs"] = None 622 else: 623 # Phase 1 with no managed state: create placeholder tokens 624 # so the data pipeline doesn't break. These are NOT suitable 625 # for training but allow process mode (SFT data gen) to work. 626 # Tokenize the full conversation to get approximate tokens. 627 full_text = "\n".join( 628 msg.get("content", "") for msg in result.messages if msg.get("content") 629 ) 630 if self.tokenizer: 631 tokens = self.tokenizer.encode(full_text, add_special_tokens=True) 632 else: 633 tokens = list(range(min(len(full_text) // 4, 128))) 634 635 scored_item = { 636 "tokens": tokens, 637 "masks": [-100] + tokens[1:], # Mask first token as prompt 638 "scores": reward, 639 } 640 641 # Always include messages for wandb rollout display and data logging 642 scored_item["messages"] = result.messages 643 644 return scored_item, [] 645 646 # ========================================================================= 647 # Abstract methods -- subclasses must implement 648 # ========================================================================= 649 650 @abstractmethod 651 async def setup(self): 652 """ 653 Load dataset, initialize state. 654 655 Called once when the environment starts. Typical implementation: 656 self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split) 657 self.iter = 0 658 """ 659 raise NotImplementedError 660 661 @abstractmethod 662 async def get_next_item(self) -> Item: 663 """ 664 Return the next item from the dataset for rollout. 665 666 Called by the base env's main loop to get items for workers. 667 Should cycle through the dataset. 668 """ 669 raise NotImplementedError 670 671 @abstractmethod 672 def format_prompt(self, item: Item) -> str: 673 """ 674 Convert a dataset item into the user message for the agent. 675 676 Args: 677 item: Dataset item (dict, tuple, etc.) 678 679 Returns: 680 The prompt string to send to the agent 681 """ 682 raise NotImplementedError 683 684 @abstractmethod 685 async def compute_reward( 686 self, item: Item, result: AgentResult, ctx: ToolContext 687 ) -> float: 688 """ 689 Score the rollout. Has full access to: 690 - item: the original dataset item (ground truth, test commands, etc.) 691 - result: AgentResult with full messages, turn count, reasoning, etc. 692 - ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web, 693 browser, vision...) scoped to this rollout's sandbox. Nothing 694 is off-limits. 695 696 Args: 697 item: The dataset item that was rolled out 698 result: The agent's rollout result 699 ctx: ToolContext with full tool access for verification 700 701 Returns: 702 Reward float (typically 0.0 to 1.0, but any float is valid) 703 """ 704 raise NotImplementedError 705 706 @abstractmethod 707 async def evaluate(self, *args, **kwargs): 708 """ 709 Periodic evaluation. Called every steps_per_eval steps. 710 711 Typical implementation runs the agent on a held-out eval set 712 and logs metrics via wandb/evaluate_log. 713 """ 714 raise NotImplementedError