# environments/hermes_base_env.py
"""
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos

Provides the Atropos integration plumbing that all hermes-agent environments share:
- Two-mode operation (OpenAI server for Phase 1, VLLM/SGLang ManagedServer for Phase 2)
- Per-group toolset/distribution resolution
- Agent loop orchestration via HermesAgentLoop
- ToolContext creation for reward functions
- Per-rollout ScoredDataItem construction from ManagedServer state
  (or placeholder tokens in Phase 1)

Subclasses only need to implement:
    setup()           -- Load dataset, initialize state
    get_next_item()   -- Return the next item from the dataset
    format_prompt()   -- Convert a dataset item into the user message
    compute_reward()  -- Score the rollout (has full ToolContext access)
    evaluate()        -- Periodic evaluation
"""

import asyncio
import contextvars
import json
import logging
import os
import sys
import uuid
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

# Ensure the hermes-agent repo root is on sys.path so that imports like
# `from model_tools import ...` and `from environments.X import ...` work
# regardless of where the script is invoked from.
# NOTE: everything from here down to `logger` is order-sensitive. The project
# imports further below (environments.*, tools.*, model_tools,
# toolset_distributions) need this sys.path entry when the script is launched
# from outside the repo root, so do not hoist them above this point.
# _repo_root is the parent of this file's directory (i.e. the directory that
# contains environments/).
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from dotenv import load_dotenv
from pydantic import Field

# Load API keys from hermes-agent/.env so all environments can access them.
# This runs at import time; only the repo-root .env is read, and a missing
# file is silently skipped.
# NOTE(review): python-dotenv does not override variables that are already set
# in the process environment by default -- confirm that precedence is intended.
_env_path = _repo_root / ".env"
if _env_path.exists():
    load_dotenv(dotenv_path=_env_path)

# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
# This patches SwerexModalEnvironment to use a background thread instead of
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
# NOTE(review): this runs before the atroposlib / tool imports below, presumably
# so the patched classes are in place before anything else imports them --
# confirm before moving it.
from environments.patches import apply_patches
apply_patches()

from atroposlib.envs.base import (
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
    ScoredDataItem,
)
from atroposlib.envs.server_handling.server_manager import (
    APIServerConfig,
    ServerBaseline,
    ServerManager,
)
from atroposlib.type_definitions import Item

from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext
# Tool-result budget defaults; HermesAgentEnvConfig uses these as its field
# defaults so the numbers live in one place (tools.budget_config).
from tools.budget_config import (
    DEFAULT_RESULT_SIZE_CHARS,
    DEFAULT_TURN_BUDGET_CHARS,
    DEFAULT_PREVIEW_SIZE_CHARS,
)

# Import hermes-agent toolset infrastructure
from model_tools import get_tool_definitions
from toolset_distributions import sample_toolsets_from_distribution

logger = logging.getLogger(__name__)


 78  class HermesAgentEnvConfig(BaseEnvConfig):
 79      """
 80      Configuration for hermes-agent Atropos environments.
 81  
 82      Extends BaseEnvConfig with agent-specific settings for toolsets,
 83      terminal backend, dataset loading, and tool call parsing.
 84      """
 85  
 86      # --- Toolset configuration ---
 87      # Mutually exclusive: use either enabled_toolsets OR distribution
 88      enabled_toolsets: Optional[List[str]] = Field(
 89          default=None,
 90          description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
 91          "If None and distribution is also None, all available toolsets are enabled.",
 92      )
 93      disabled_toolsets: Optional[List[str]] = Field(
 94          default=None,
 95          description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
 96      )
 97      distribution: Optional[str] = Field(
 98          default=None,
 99          description="Name of a toolset distribution from toolset_distributions.py "
100          "(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
101          "Mutually exclusive with enabled_toolsets.",
102      )
103  
104      # --- Agent loop configuration ---
105      max_agent_turns: int = Field(
106          default=30,
107          description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
108      )
109      system_prompt: Optional[str] = Field(
110          default=None,
111          description="System prompt for the agent. Tools are handled via the tools= parameter, "
112          "not embedded in the prompt text.",
113      )
114      agent_temperature: float = Field(
115          default=1.0,
116          description="Sampling temperature for agent generation during rollouts.",
117      )
118  
119      # --- Terminal backend ---
120      terminal_backend: str = Field(
121          default="local",
122          description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
123          "Modal or Daytona recommended for production RL (cloud isolation per rollout).",
124      )
125      terminal_timeout: int = Field(
126          default=120,
127          description="Per-command timeout in seconds for terminal tool calls. "
128          "Commands exceeding this are killed. Increase for tasks with long-running "
129          "commands (compilation, pip install, etc.).",
130      )
131      terminal_lifetime: int = Field(
132          default=3600,
133          description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
134          "sandboxes that have been idle longer than this. Must be longer than "
135          "the longest gap between tool calls (e.g., waiting for LLM response).",
136      )
137  
138      # --- Dataset ---
139      dataset_name: Optional[str] = Field(
140          default=None,
141          description="HuggingFace dataset name. Optional if tasks are defined inline.",
142      )
143      dataset_split: str = Field(
144          default="train",
145          description="Dataset split to use.",
146      )
147      prompt_field: str = Field(
148          default="prompt",
149          description="Which field in the dataset contains the prompt.",
150      )
151  
152      # --- Thread pool ---
153      tool_pool_size: int = Field(
154          default=128,
155          description="Thread pool size for tool execution. Each concurrent task needs a "
156          "thread for tool calls. Must be large enough for parallel evaluation. "
157          "Too small = thread pool starvation.",
158      )
159  
160      # --- Phase 2: Tool call parsing ---
161      tool_call_parser: str = Field(
162          default="hermes",
163          description="Tool call parser name for Phase 2 (VLLM server type). "
164          "Ignored in Phase 1 (OpenAI server type where VLLM parses natively). "
165          "Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
166      )
167  
168      # --- Tool result budget ---
169      # Defaults imported from tools.budget_config (single source of truth).
170      default_result_size_chars: int = Field(
171          default=DEFAULT_RESULT_SIZE_CHARS,
172          description="Default per-tool threshold (chars) for persisting large results "
173          "to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
174          "and replaced with a preview. Per-tool registry values take precedence "
175          "unless overridden via tool_result_overrides.",
176      )
177      turn_budget_chars: int = Field(
178          default=DEFAULT_TURN_BUDGET_CHARS,
179          description="Aggregate char budget per assistant turn. If all tool results "
180          "in a single turn exceed this, the largest are persisted to disk first.",
181      )
182      preview_size_chars: int = Field(
183          default=DEFAULT_PREVIEW_SIZE_CHARS,
184          description="Size of the inline preview shown after a tool result is persisted.",
185      )
186      tool_result_overrides: Optional[Dict[str, int]] = Field(
187          default=None,
188          description="Per-tool threshold overrides (chars). Keys are tool names, "
189          "values are char thresholds. Overrides both the default and registry "
190          "per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
191          "Note: read_file is pinned to infinity and cannot be overridden.",
192      )
193  
194      # --- Provider-specific parameters ---
195      # Passed as extra_body to the OpenAI client's chat.completions.create() call.
196      # Useful for OpenRouter provider preferences, transforms, route settings, etc.
197      # Example YAML:
198      #   extra_body:
199      #     provider:
200      #       ignore: ["DeepInfra", "Fireworks"]
201      #       order: ["Together"]
202      #     transforms: ["middle-out"]
203      extra_body: Optional[Dict[str, Any]] = Field(
204          default=None,
205          description="Extra body parameters passed to the OpenAI client's "
206          "chat.completions.create(). Used for OpenRouter provider preferences, "
207          "transforms, and other provider-specific settings.",
208      )
209  
210      def build_budget_config(self):
211          """Build a BudgetConfig from env config fields."""
212          from tools.budget_config import BudgetConfig
213          return BudgetConfig(
214              default_result_size=self.default_result_size_chars,
215              turn_budget=self.turn_budget_chars,
216              preview_size=self.preview_size_chars,
217              tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
218          )
219  
220  
class HermesAgentBaseEnv(BaseEnv):
    """
    Abstract base environment for hermes-agent Atropos integration.

    Handles two modes of operation:
    - Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
      The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
      and reasoning extraction natively. DummyManagedServer provides placeholder
      tokens. Good for SFT data gen, verifier testing, evaluation.

    - Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
      via /generate. Client-side tool call parser reconstructs structured tool_calls
      from raw output. Full RL training capability.

    Tools are resolved once per group in collect_trajectories() and published
    through a ContextVar, so groups that are collected concurrently each see
    their own (possibly differently sampled) toolset.

    Subclasses must implement:
        setup()           -- Load dataset, initialize state
        get_next_item()   -- Return the next item to roll out
        format_prompt()   -- Convert a dataset item into the user message string
        compute_reward()  -- Score the rollout using ToolContext
        evaluate()        -- Periodic evaluation
    """

    name: Optional[str] = "hermes-agent"
    env_config_cls = HermesAgentEnvConfig

    # (tool_schemas, valid_tool_names) for the group whose rollouts run in the
    # current asyncio context. collect_trajectories() sets it inside its own
    # task; asyncio tasks created from that task start with a copy of its
    # context, so every collect_trajectory() fanned out for the group reads
    # this group's tools even while other groups are being collected. (A plain
    # instance attribute is shared by all concurrent groups and can be
    # overwritten before a group's rollouts get to read it.) Created once at
    # class definition rather than per instance, as contextvars recommends.
    _group_tools_ctx = contextvars.ContextVar(
        "hermes_agent_group_tools", default=None
    )

    def __init__(
        self,
        config: HermesAgentEnvConfig,
        server_configs: Union[ServerBaseline, List[APIServerConfig]],
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)

        # Set terminal environment variables so hermes tools pick them up.
        # These can all be overridden per-environment via config fields instead
        # of requiring users to set shell env vars.
        if config.terminal_backend:
            os.environ["TERMINAL_ENV"] = config.terminal_backend
        os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
        os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
        print(
            f"🖥️  Terminal: backend={config.terminal_backend}, "
            f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
        )

        # distribution and enabled_toolsets are documented as mutually
        # exclusive; _resolve_tools_for_group() uses distribution and drops
        # enabled_toolsets when both are given. Make that visible once.
        if config.distribution and config.enabled_toolsets:
            logger.warning(
                "Both distribution=%r and enabled_toolsets=%s are set; they are "
                "mutually exclusive and enabled_toolsets will be ignored.",
                config.distribution,
                config.enabled_toolsets,
            )

        # Resize the agent loop's thread pool for tool execution.
        # This must be large enough for the number of concurrent tasks
        # (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
        from environments.agent_loop import resize_tool_pool
        resize_tool_pool(config.tool_pool_size)

        # Set tool_parser on the ServerManager so ManagedServer uses it
        # for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
        if hasattr(self.server, 'tool_parser'):
            self.server.tool_parser = config.tool_call_parser
            print(f"🔧 Tool parser: {config.tool_call_parser}")

        # Most recently resolved group tools (set in collect_trajectories).
        # Kept for backward compatibility / introspection. Rollouts prefer
        # _group_tools_ctx, which stays correct when groups run concurrently.
        self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None

        # Tool error tracking for wandb logging
        self._tool_error_buffer: List[Dict[str, Any]] = []

    # =========================================================================
    # Toolset resolution (per-group)
    # =========================================================================

    def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
        """
        Resolve toolsets for a group. Called once in collect_trajectories(),
        then shared by all collect_trajectory() calls in the group.

        If distribution is set, samples probabilistically.
        If enabled_toolsets is set, uses that explicit list.
        disabled_toolsets is applied as a filter on top.

        Returns:
            (tool_schemas, valid_tool_names) tuple
        """
        config = self.config

        if config.distribution:
            group_toolsets = sample_toolsets_from_distribution(config.distribution)
            logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
        else:
            group_toolsets = config.enabled_toolsets  # None means "all available"
            if group_toolsets is None:
                logger.warning(
                    "enabled_toolsets is None -- loading ALL tools including messaging. "
                    "Set explicit enabled_toolsets for RL training."
                )

        tools = get_tool_definitions(
            enabled_toolsets=group_toolsets,
            disabled_toolsets=config.disabled_toolsets,
            quiet_mode=True,
        )

        valid_names = {t["function"]["name"] for t in tools} if tools else set()
        logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
        return tools, valid_names

    # =========================================================================
    # Server mode detection
    # =========================================================================

    def _use_managed_server(self) -> bool:
        """
        Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).

        Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
        which go through the /generate endpoint for exact token tracking.

        Phase 1 (direct server) is used for 'openai' server type, which uses
        /v1/chat/completions with native tool call parsing.
        """
        if not self.server.servers:
            return False

        server = self.server.servers[0]
        # If the server is an OpenAI server (not VLLM/SGLang), use direct mode
        from atroposlib.envs.server_handling.openai_server import OpenAIServer
        return not isinstance(server, OpenAIServer)

    # =========================================================================
    # Core Atropos integration
    # =========================================================================

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
        List[Item],
    ]:
        """
        Resolve toolsets once for this group, then delegate to BaseEnv's
        group-level collection (which calls collect_trajectory() group_size
        times).

        The resolved tools are published through ``_group_tools_ctx`` for the
        duration of this call, so the group's rollouts read this group's tools
        even if other groups are being collected concurrently. The ContextVar
        is restored on exit, including on error or cancellation.
        """
        group_tools = self._resolve_tools_for_group()
        # Backward-compatible mirror; only reliable when groups are not
        # collected concurrently.
        self._current_group_tools = group_tools

        token = self._group_tools_ctx.set(group_tools)
        try:
            # The default implementation fans out collect_trajectory() calls;
            # tasks it creates inherit this task's context (and thus the tools).
            return await super().collect_trajectories(item)
        finally:
            self._group_tools_ctx.reset(token)

    # =========================================================================
    # Wandb rollout display -- format trajectories nicely
    # =========================================================================

    @staticmethod
    def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
        """
        Format a conversation's messages into a readable trajectory string
        for wandb rollout tables. Shows tool calls, tool results, and reasoning
        in a structured way instead of raw token decoding.

        Tolerates OpenAI-style messages whose ``content``/``tool_calls`` are
        ``None`` (assistant turns that only call tools) and non-string content
        or tool arguments, which are rendered as JSON.

        Truncation limits: reasoning 300 chars, tool arguments 200 chars,
        tool results 500 chars ("..." appended when cut).
        """

        def as_text(value: Any) -> str:
            # None -> "", str unchanged, anything structured -> JSON text.
            if value is None:
                return ""
            if isinstance(value, str):
                return value
            try:
                return json.dumps(value, ensure_ascii=False)
            except (TypeError, ValueError):
                return str(value)

        def clip(text: str, limit: int) -> str:
            return text[:limit] + "..." if len(text) > limit else text

        parts: List[str] = []
        for msg in messages:
            role = msg.get("role", "unknown")
            content = as_text(msg.get("content"))

            if role == "system":
                parts.append(f"[SYSTEM]\n{content}")

            elif role == "user":
                parts.append(f"[USER]\n{content}")

            elif role == "assistant":
                # Show (truncated) reasoning if present
                reasoning = as_text(msg.get("reasoning_content"))
                if reasoning:
                    parts.append(f"[ASSISTANT thinking]\n{clip(reasoning, 300)}")

                if content:
                    parts.append(f"[ASSISTANT]\n{content}")

                # ``tool_calls`` may be present but None on plain replies.
                for tc in msg.get("tool_calls") or []:
                    func = tc.get("function") or {}
                    name = func.get("name", "?")
                    args = as_text(func.get("arguments", "{}"))
                    parts.append(f"[TOOL CALL] {name}({clip(args, 200)})")

            elif role == "tool":
                parts.append(f"[TOOL RESULT] {clip(content, 500)}")

        return "\n\n".join(parts)

    async def add_rollouts_for_wandb(
        self,
        scored_data,
        item=None,
    ):
        """
        Override to show formatted trajectories with tool calls visible,
        instead of raw token decoding which loses all structure.

        Keeps up to ``num_rollouts_per_group_for_logging`` rollouts (-1 means
        the whole group) and at most ``num_rollouts_to_keep`` groups, dropping
        the oldest group first.
        """
        num_keep = self.config.num_rollouts_per_group_for_logging
        if num_keep == -1:
            num_keep = self.config.group_size

        # Fields may be missing or explicitly None on a ScoredDataGroup.
        scores = scored_data.get("scores") or []
        all_messages = scored_data.get("messages") or []
        all_tokens = scored_data.get("tokens") or []

        group = []
        for i in range(min(num_keep, len(scores))):
            # Prefer structured messages for rich display
            messages = all_messages[i] if i < len(all_messages) else None

            if messages:
                text = self._format_trajectory_for_display(messages)
            elif i < len(all_tokens):
                text = self.tokenizer.decode(all_tokens[i])
            else:
                text = "(no data)"

            group.append((text, scores[i]))

        self.rollouts_for_wandb.append(group)
        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log base metrics including tool errors to wandb, then drain the error buffer."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Log tool error stats
        if self._tool_error_buffer:
            wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)

            # Log error details as a summary string (tables can crash wandb on tmp cleanup)
            error_summaries = []
            for err in self._tool_error_buffer:
                error_summaries.append(
                    f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
                )
            wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)

            # Also print to stdout for immediate visibility
            for summary in error_summaries:
                print(f"  Tool Error: {summary}")

            self._tool_error_buffer = []
        else:
            wandb_metrics["train/tool_errors_count"] = 0

        await super().wandb_log(wandb_metrics)

    async def _run_agent_loop(
        self,
        server: Any,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        valid_names: Set[str],
        task_id: str,
    ) -> AgentResult:
        """Build a HermesAgentLoop bound to ``server`` from config and run it on ``messages``."""
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=tools,
            valid_tool_names=valid_names,
            max_turns=self.config.max_agent_turns,
            task_id=task_id,
            temperature=self.config.agent_temperature,
            max_tokens=self.config.max_token_length,
            extra_body=self.config.extra_body,
            budget_config=self.config.build_budget_config(),
        )
        return await agent.run(messages)

    async def _run_rollout(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        valid_names: Set[str],
        task_id: str,
    ) -> AgentResult:
        """Run one agent rollout in Phase 2 (ManagedServer) or Phase 1 (direct server)."""
        if not self._use_managed_server():
            # Phase 1: OpenAI server -- native tool_calls, placeholder tokens
            return await self._run_agent_loop(
                self.server, messages, tools, valid_names, task_id
            )

        # Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs.
        # tool_parser is set on ServerManager in __init__ and passed through
        # to ManagedServer, which uses ToolCallTranslator for bidirectional
        # translation between raw text and OpenAI tool_calls.
        entered = False
        try:
            async with self.server.managed_server(
                tokenizer=self.tokenizer,
                preserve_think_blocks=bool(self.config.thinking_mode),
            ) as managed:
                entered = True
                return await self._run_agent_loop(
                    managed, messages, tools, valid_names, task_id
                )
        except NotImplementedError:
            # Only fall back when the ManagedServer itself could not be
            # created. A NotImplementedError raised mid-rollout is a real error;
            # re-running would repeat side effects in the same sandbox (task_id).
            if entered:
                raise
            # DummyManagedServer not allowed -- fall back to Phase 1
            logger.warning(
                "ManagedServer not available (OpenAI server?). "
                "Falling back to direct server mode."
            )
        return await self._run_agent_loop(
            self.server, messages, tools, valid_names, task_id
        )

    def _build_scored_item(self, result: AgentResult, reward: float) -> Dict[str, Any]:
        """Turn an AgentResult into ScoredDataItem fields (tokens, masks, scores)."""
        # Phase 2: real tokens/masks/logprobs from SequenceNodes
        # Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
        nodes = (result.managed_state or {}).get("nodes", [])

        if nodes:
            # Phase 2 (or DummyManagedServer): use actual node data
            node = nodes[-1]  # Final sequence node = full trajectory
            scored_item: Dict[str, Any] = {
                "tokens": node.tokens,
                "masks": node.masked_tokens,
                "scores": reward,
            }

            # When the node carries logprobs (Phase 2), add placeholder
            # advantage / ref-logprob fields; the trainer computes them.
            if getattr(node, "logprobs", None):
                scored_item["advantages"] = None
                scored_item["ref_logprobs"] = None
            return scored_item

        # Phase 1 with no managed state: create placeholder tokens
        # so the data pipeline doesn't break. These are NOT suitable
        # for training but allow process mode (SFT data gen) to work.
        # Tokenize the full conversation to get approximate tokens.
        full_text = "\n".join(
            msg.get("content", "") for msg in result.messages if msg.get("content")
        )
        if self.tokenizer:
            tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
        else:
            tokens = list(range(min(len(full_text) // 4, 128)))

        # Mask the first token as prompt. Masks must stay the same length as
        # tokens, so an empty token list gets an empty mask list.
        masks = [-100] + tokens[1:] if tokens else []
        return {
            "tokens": tokens,
            "masks": masks,
            "scores": reward,
        }

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
        """
        Run a single rollout: agent loop + reward computation.

        This is called group_size times in parallel by collect_trajectories().
        Each call gets its own task_id for terminal/browser session isolation.

        Returns:
            (scored_item, []) where scored_item holds tokens, masks, scores and
            the full message list (for wandb display and data logging).
        """
        task_id = str(uuid.uuid4())

        # Group-level tools: prefer the value published for THIS group's
        # context by collect_trajectories(); otherwise (called outside it, or a
        # subclass bypassed it) fall back to the instance attribute, and
        # finally resolve per-trajectory.
        group_tools = self._group_tools_ctx.get()
        if group_tools is None:
            group_tools = self._current_group_tools
        if group_tools is None:
            group_tools = self._resolve_tools_for_group()
        tools, valid_names = group_tools

        # Build initial messages
        messages: List[Dict[str, Any]] = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": self.format_prompt(item)})

        result = await self._run_rollout(messages, tools, valid_names, task_id)

        # Skip reward computation if the agent loop produced no meaningful work
        # (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
        # just to verify files that were never created.
        only_system_and_user = all(
            msg.get("role") in ("system", "user") for msg in result.messages
        )
        if result.turns_used == 0 or only_system_and_user:
            logger.warning(
                "Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
                result.turns_used, len(result.messages),
            )
            reward = 0.0
        else:
            # Compute reward using ToolContext (gives verifier full tool access)
            ctx = ToolContext(task_id)
            try:
                reward = await self.compute_reward(item, result, ctx)
            except Exception as e:
                # Best effort: a broken verifier scores 0 instead of killing
                # the group, but keep the traceback for debugging.
                logger.exception("compute_reward failed: %s", e)
                reward = 0.0
            finally:
                ctx.cleanup()

        # Track tool errors for wandb logging
        for err in result.tool_errors or []:
            self._tool_error_buffer.append({
                "turn": err.turn,
                "tool": err.tool_name,
                "args": err.arguments[:150],
                "error": err.error[:300],
                "result": err.tool_result[:300],
            })

        scored_item = self._build_scored_item(result, reward)

        # Always include messages for wandb rollout display and data logging
        scored_item["messages"] = result.messages

        return scored_item, []

    # =========================================================================
    # Abstract methods -- subclasses must implement
    # =========================================================================

    @abstractmethod
    async def setup(self):
        """
        Load dataset, initialize state.

        Called once when the environment starts. Typical implementation:
            self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
            self.iter = 0
        """
        raise NotImplementedError

    @abstractmethod
    async def get_next_item(self) -> Item:
        """
        Return the next item from the dataset for rollout.

        Called by the base env's main loop to get items for workers.
        Should cycle through the dataset.
        """
        raise NotImplementedError

    @abstractmethod
    def format_prompt(self, item: Item) -> str:
        """
        Convert a dataset item into the user message for the agent.

        Args:
            item: Dataset item (dict, tuple, etc.)

        Returns:
            The prompt string to send to the agent
        """
        raise NotImplementedError

    @abstractmethod
    async def compute_reward(
        self, item: Item, result: AgentResult, ctx: ToolContext
    ) -> float:
        """
        Score the rollout. Has full access to:
        - item: the original dataset item (ground truth, test commands, etc.)
        - result: AgentResult with full messages, turn count, reasoning, etc.
        - ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
               browser, vision...) scoped to this rollout's sandbox. Nothing
               is off-limits.

        Args:
            item: The dataset item that was rolled out
            result: The agent's rollout result
            ctx: ToolContext with full tool access for verification

        Returns:
            Reward float (typically 0.0 to 1.0, but any float is valid)
        """
        raise NotImplementedError

    @abstractmethod
    async def evaluate(self, *args, **kwargs):
        """
        Periodic evaluation. Called every steps_per_eval steps.

        Typical implementation runs the agent on a held-out eval set
        and logs metrics via wandb/evaluate_log.
        """
        raise NotImplementedError
713          """
714          raise NotImplementedError