   1  """
   2  AgenticOPDEnv — On-Policy Distillation for Agentic Tool-Calling Tasks
   3  =====================================================================
   4  
   5  First Atropos environment to populate the distill_token_ids / distill_logprobs
   6  fields on ScoredDataGroup, enabling on-policy distillation (OPD) training.
   7  
   8  Key idea (from OpenClaw-RL, Princeton 2026):
   9    Every time an agent receives a next-state signal (tool result, error trace,
  10    test verdict), that signal contains hindsight information about how the
  11    agent's PREVIOUS response could have been better. This environment:
  12  
  13    1. Runs standard agentic rollouts (tool-calling agent loop)
  14    2. Walks the conversation to find (assistant_turn, next_state) pairs
  15    3. Uses an LLM judge to extract "hints" from next-state signals
  16    4. Builds an enhanced prompt (original context + hint)
  17    5. Scores the student's response tokens under the enhanced distribution
  18       using VLLM's prompt_logprobs (via Atropos's get_logprobs API)
  19    6. Packages the teacher's top-K predictions as distill_token_ids /
  20       distill_logprobs on the ScoredDataGroup
  21  
  22  The trainer then computes per-token advantages:
  23    A_t = teacher_logprob(token_t) - student_logprob(token_t)
  24    Positive → teacher approves this token (upweight)
  25    Negative → teacher disapproves (downweight)
  26  
  27  This gives dense, token-level training signal from every tool interaction,
  28  instead of just a scalar reward at the end of the trajectory.
  29  
  30  Task: Coding tasks with test verification (rich next-state signals from
  31  test results, error messages, terminal output). Falls back to built-in
  32  coding problems if no HuggingFace dataset is configured.
  33  
  34  Requirements:
  35    - VLLM backend (server_type: vllm) — needed for prompt logprob scoring
  36    - Phase 2 mode (ManagedServer) — needed for token-level tracking
  37  
  38  Usage:
  39      # Process mode (offline data generation with OPD)
  40      python environments/agentic_opd_env.py process \\
  41          --env.total_steps 10 --env.group_size 2 \\
  42          --env.data_path_to_save_groups output.jsonl \\
  43          --openai.base_url http://localhost:8000/v1 \\
  44          --openai.model_name Qwen/Qwen3-4B
  45  
  46      # Serve mode (connected to Atropos trainer)
  47      python environments/agentic_opd_env.py serve \\
  48          --openai.base_url http://localhost:8000/v1 \\
  49          --openai.model_name Qwen/Qwen3-4B
  50  
  51      # Evaluate mode
  52      python environments/agentic_opd_env.py evaluate \\
  53          --env.eval_size 10 \\
  54          --openai.base_url http://localhost:8000/v1 \\
  55          --openai.model_name Qwen/Qwen3-4B
  56  
  57  Reference: Wang et al., "OpenClaw-RL: Train Any Agent Simply by Talking"
  58             arXiv:2603.10165, March 2026
  59  """

from __future__ import annotations

import asyncio
import copy
import logging
import random
import re
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import Field

# Ensure hermes-agent root is on path
_repo_root = Path(__file__).resolve().parent.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

from atroposlib.envs.base import ScoredDataGroup
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from atroposlib.type_definitions import Item

from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult, HermesAgentLoop
from environments.tool_context import ToolContext

logger = logging.getLogger(__name__)


# ═══════════════════════════════════════════════════════════════════════
# Built-in coding tasks (fallback when no HF dataset is configured)
# ═══════════════════════════════════════════════════════════════════════

BUILTIN_CODING_TASKS = [
    {
        "task": "Write a Python function `fizzbuzz(n)` that returns a list of strings from 1 to n. "
        "For multiples of 3 return 'Fizz', for multiples of 5 return 'Buzz', "
        "for multiples of both return 'FizzBuzz', otherwise the number as a string.",
        "test_code": (
            "from solution import fizzbuzz\n"
            "assert fizzbuzz(15) == ['1','2','Fizz','4','Buzz','Fizz','7','8','Fizz','Buzz','11','Fizz','13','14','FizzBuzz']\n"
            "assert fizzbuzz(1) == ['1']\n"
            "assert fizzbuzz(0) == []\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `is_palindrome(s)` that checks if a string is a palindrome, "
        "ignoring case and non-alphanumeric characters. Return True or False.",
        "test_code": (
            "from solution import is_palindrome\n"
            "assert is_palindrome('A man, a plan, a canal: Panama') == True\n"
            "assert is_palindrome('race a car') == False\n"
            "assert is_palindrome('') == True\n"
            "assert is_palindrome('Was it a car or a cat I saw?') == True\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `two_sum(nums, target)` that returns the indices of the two "
        "numbers in `nums` that add up to `target`. Assume exactly one solution exists. "
        "Return a list of two indices [i, j] where i < j.",
        "test_code": (
            "from solution import two_sum\n"
            "assert two_sum([2, 7, 11, 15], 9) == [0, 1]\n"
            "assert two_sum([3, 2, 4], 6) == [1, 2]\n"
            "assert two_sum([3, 3], 6) == [0, 1]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `flatten(lst)` that takes an arbitrarily nested list and "
        "returns a flat list of all elements. For example, flatten([1, [2, [3, 4], 5]]) "
        "should return [1, 2, 3, 4, 5].",
        "test_code": (
            "from solution import flatten\n"
            "assert flatten([1, [2, [3, 4], 5]]) == [1, 2, 3, 4, 5]\n"
            "assert flatten([]) == []\n"
            "assert flatten([1, 2, 3]) == [1, 2, 3]\n"
            "assert flatten([[[[1]]]]) == [1]\n"
            "assert flatten([1, [2], [[3]], [[[4]]]]) == [1, 2, 3, 4]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
    {
        "task": "Write a Python function `longest_common_prefix(strs)` that finds the longest "
        "common prefix string amongst a list of strings. If there is no common prefix, "
        "return an empty string.",
        "test_code": (
            "from solution import longest_common_prefix\n"
            "assert longest_common_prefix(['flower', 'flow', 'flight']) == 'fl'\n"
            "assert longest_common_prefix(['dog', 'racecar', 'car']) == ''\n"
            "assert longest_common_prefix(['interspecies', 'interstellar', 'interstate']) == 'inters'\n"
            "assert longest_common_prefix(['a']) == 'a'\n"
            "assert longest_common_prefix([]) == ''\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `group_anagrams(strs)` that groups anagrams together. "
        "Return a list of lists, where each inner list contains strings that are anagrams of "
        "each other. The order of groups and strings within groups does not matter.",
        "test_code": (
            "from solution import group_anagrams\n"
            "result = group_anagrams(['eat', 'tea', 'tan', 'ate', 'nat', 'bat'])\n"
            "result_sorted = sorted([sorted(g) for g in result])\n"
            "assert result_sorted == [['ate', 'eat', 'tea'], ['bat'], ['nat', 'tan']]\n"
            "assert group_anagrams([]) == []\n"
            "assert group_anagrams(['a']) == [['a']]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
    {
        "task": "Write a Python function `valid_parentheses(s)` that determines if a string "
        "containing just '(', ')', '{', '}', '[' and ']' is valid. A string is valid if "
        "open brackets are closed by the same type and in the correct order.",
        "test_code": (
            "from solution import valid_parentheses\n"
            "assert valid_parentheses('()') == True\n"
            "assert valid_parentheses('()[]{}') == True\n"
            "assert valid_parentheses('(]') == False\n"
            "assert valid_parentheses('([)]') == False\n"
            "assert valid_parentheses('{[]}') == True\n"
            "assert valid_parentheses('') == True\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "easy",
    },
    {
        "task": "Write a Python function `merge_intervals(intervals)` that merges overlapping "
        "intervals. Each interval is a list [start, end]. Return the merged intervals sorted "
        "by start time.",
        "test_code": (
            "from solution import merge_intervals\n"
            "assert merge_intervals([[1,3],[2,6],[8,10],[15,18]]) == [[1,6],[8,10],[15,18]]\n"
            "assert merge_intervals([[1,4],[4,5]]) == [[1,5]]\n"
            "assert merge_intervals([[1,4],[0,4]]) == [[0,4]]\n"
            "assert merge_intervals([]) == []\n"
            "assert merge_intervals([[1,2]]) == [[1,2]]\n"
            "print('All tests passed!')\n"
        ),
        "difficulty": "medium",
    },
]


# ═══════════════════════════════════════════════════════════════════════
# Hint extraction prompts (adapted from OpenClaw-RL)
# ═══════════════════════════════════════════════════════════════════════

_HINT_JUDGE_SYSTEM = (
    "You are a process reward model used for hindsight hint extraction.\n"
    "You are given:\n"
    "1) The assistant response at turn t.\n"
    "2) The next state at turn t+1, along with its **role**.\n\n"
    "## Understanding the next state's role\n"
    "- role='user': A reply from the user (follow-up, correction, new request, etc.).\n"
    "- role='tool': The return value of a tool the assistant invoked. "
    "This content was NOT available before the assistant's action — "
    "it exists BECAUSE the assistant called the tool. "
    "A successful, non-error tool output generally means the assistant's "
    "action was appropriate; do NOT treat it as information the assistant "
    "should have already known.\n\n"
    "Your goal is to decide whether the next state reveals useful hindsight information\n"
    "that could have helped improve the assistant response at turn t.\n\n"
    "Output format rules (strict):\n"
    "- You MUST include exactly one final decision token: \\boxed{1} or \\boxed{-1}.\n"
    "- If and only if decision is \\boxed{1}, provide a concise, information-dense hint in 1-3 sentences,\n"
    "  wrapped between [HINT_START] and [HINT_END].\n"
    "- If decision is \\boxed{-1}, do not provide a hint block.\n"
    "- Hint must be concrete and actionable for improving the previous response."
)

_BOXED_RE = re.compile(r"\\boxed\{(-?\d+)\}")
_HINT_RE = re.compile(r"\[HINT_START\](.*?)\[HINT_END\]", re.DOTALL)


def _build_hint_judge_messages(
    response_text: str, next_state_text: str, next_state_role: str = "tool"
) -> list[dict]:
    """Build messages for the hint extraction judge."""
    user = (
        f"## Assistant response (turn t)\n{response_text}\n\n"
        f"## Next state (turn t+1) [role: {next_state_role}]\n{next_state_text}\n\n"
        "Now output your decision and (if positive) the hint in the required format."
    )
    return [
        {"role": "system", "content": _HINT_JUDGE_SYSTEM},
        {"role": "user", "content": user},
    ]


def _parse_hint_result(text: str) -> tuple[int | None, str]:
    """Parse the judge's boxed decision and hint text."""
    boxed = _BOXED_RE.findall(text)
    score = int(boxed[-1]) if boxed else None
    if score not in (1, -1):
        score = None
    hint_matches = _HINT_RE.findall(text)
    hint = hint_matches[-1].strip() if hint_matches else ""
    return score, hint

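# Example of the judge-output format _parse_hint_result expects (illustrative):
#
#     >>> _parse_hint_result(
#     ...     "The traceback shows a missing import. \\boxed{1} "
#     ...     "[HINT_START]Import `re` before calling re.sub.[HINT_END]"
#     ... )
#     (1, 'Import `re` before calling re.sub.')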

def _select_best_hint(votes: list[dict]) -> dict | None:
    """Select the best hint from the majority-voted judge results.

    A hint is accepted only when positive votes form a strict majority of
    all judge queries; among those, the longest non-trivial hint wins.
    """
    positive = [v for v in votes if v.get("score") == 1]
    # Majority vote: more than half of the judge queries must approve
    if len(positive) * 2 <= len(votes):
        return None
    good = [
        v
        for v in positive
        if isinstance(v.get("hint"), str) and len(v["hint"].strip()) > 10
    ]
    if not good:
        return None
    return max(good, key=lambda v: len(v["hint"].strip()))


def _append_hint_to_messages(messages: list[dict], hint: str) -> list[dict]:
    """Clone messages and append hint to the last user message."""
    cloned = copy.deepcopy(messages)
    if not cloned:
        return [{"role": "user", "content": f"[user's hint / instruction]\n{hint}"}]

    # Find last user message
    target_idx = None
    for i in range(len(cloned) - 1, -1, -1):
        if cloned[i].get("role") == "user":
            target_idx = i
            break
    if target_idx is None:
        target_idx = len(cloned) - 1

    content = cloned[target_idx].get("content", "")
    if isinstance(content, list):
        content = " ".join(
            c.get("text", "") if isinstance(c, dict) else str(c) for c in content
        )
    suffix = f"\n\n[user's hint / instruction]\n{hint.strip()}"
    cloned[target_idx]["content"] = (content + suffix).strip()
    return cloned


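# Illustrative behavior of _append_hint_to_messages (not executed here):
#
#     >>> _append_hint_to_messages(
#     ...     [{"role": "user", "content": "Fix the bug."}], "Check the loop bound."
#     ... )
#     [{'role': 'user', 'content': "Fix the bug.\n\n[user's hint / instruction]\nCheck the loop bound."}]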
# ═══════════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════════


class AgenticOPDConfig(HermesAgentEnvConfig):
    """Configuration for the agentic OPD environment."""

    # --- OPD settings ---
    opd_enabled: bool = Field(
        default=True,
        description="Enable on-policy distillation pipeline. When disabled, "
        "the environment behaves like a standard agentic env (no distill fields).",
    )
    distill_topk: int = Field(
        default=50,
        description="Number of top-K teacher logprobs per position for distillation.",
    )
    prm_votes: int = Field(
        default=3,
        description="Number of independent judge queries for majority-voted hint extraction.",
    )
    hint_max_next_state_chars: int = Field(
        default=4000,
        description="Maximum characters of next-state text to include in the hint judge prompt. "
        "Tool results can be very long — truncating prevents judge context overflow.",
    )

    # --- Reward settings ---
    correctness_weight: float = Field(
        default=0.7,
        description="Weight for test pass/fail in reward.",
    )
    efficiency_weight: float = Field(
        default=0.15,
        description="Weight for efficiency (fewer turns = better).",
    )
    tool_usage_weight: float = Field(
        default=0.15,
        description="Weight for appropriate tool usage signal.",
    )

    # --- Dataset ---
    dataset_name: Optional[str] = Field(
        default=None,
        description="HuggingFace dataset with coding tasks. "
        "Expected fields: 'task' (problem description) and 'test_code' (pytest/assert tests). "
        "Falls back to built-in tasks if not set or unavailable.",
    )

    # --- Eval ---
    eval_size: int = Field(
        default=10,
        description="Number of held-out items for evaluation.",
    )
    eval_split_ratio: float = Field(
        default=0.15,
        description="Fraction of dataset to hold out for evaluation.",
    )
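
# Example CLI overrides for the OPD settings above, using the same
# `--env.<field>` convention as the usage examples in the module docstring
# (a sketch; adjust the model and URL to your deployment):
#
#     python environments/agentic_opd_env.py serve \
#         --env.opd_enabled True --env.distill_topk 20 --env.prm_votes 5 \
#         --openai.base_url http://localhost:8000/v1 \
#         --openai.model_name Qwen/Qwen3-4B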


# ═══════════════════════════════════════════════════════════════════════
# Environment
# ═══════════════════════════════════════════════════════════════════════


class AgenticOPDEnv(HermesAgentBaseEnv):
    """
    RL environment with on-policy distillation from next-state signals.

    Runs coding tasks where the agent writes code and runs tests.
    Tool results (test pass/fail, error traces) serve as next-state signals
    for hint extraction and teacher logprob scoring.

    This is the first Atropos environment to populate distill_token_ids
    and distill_logprobs on ScoredDataGroup for OPD training.
    """

    name = "agentic-opd"
    env_config_cls = AgenticOPDConfig

    # Default toolsets: terminal for running code, file for writing it
    default_toolsets = ["terminal", "file"]

    @classmethod
    def config_init(cls) -> Tuple[AgenticOPDConfig, List[APIServerConfig]]:
        """Default configuration."""
        env_config = AgenticOPDConfig(
            # Toolsets
            enabled_toolsets=["terminal", "file"],
            # Agent loop
            max_agent_turns=15,
            agent_temperature=1.0,
            system_prompt=(
                "You are a skilled Python programmer. When given a coding task:\n"
                "1. Write the solution to a file called 'solution.py'\n"
                "2. Write the test code to a file called 'test_solution.py'\n"
                "3. Run the tests with: python test_solution.py\n"
                "4. If tests fail, read the error output carefully, fix your code, and re-run\n"
                "5. Once all tests pass, report success\n\n"
                "Be efficient — write clean code and fix errors methodically."
            ),
            # OPD
            opd_enabled=True,
            distill_topk=50,
            prm_votes=3,
            # Training
            group_size=4,
            total_steps=500,
            steps_per_eval=50,
            use_wandb=True,
            wandb_name="agentic-opd",
        )

        server_configs = [
            APIServerConfig(
                base_url="http://localhost:8000/v1",
                model_name="Qwen/Qwen3-4B",
                server_type="vllm",
            )
        ]

        return env_config, server_configs

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._items: list[dict] = []
        self._eval_items: list[dict] = []
        self._index: int = 0

        # Metric buffers
        self._reward_buffer: list[float] = []
        self._correctness_buffer: list[float] = []
        self._efficiency_buffer: list[float] = []
        self._tool_usage_buffer: list[float] = []
        self._hints_extracted_buffer: list[int] = []
        self._opd_turns_scored_buffer: list[int] = []

    # ═══════════════════════════════════════════════════════════════════
    # 1. setup — load dataset
    # ═══════════════════════════════════════════════════════════════════

    async def setup(self) -> None:
        """Load coding tasks from HuggingFace or use built-in set."""
        if self.config.dataset_name:
            try:
                from datasets import load_dataset

                logger.info(
                    "Loading dataset '%s'...", self.config.dataset_name
                )
                ds = load_dataset(
                    self.config.dataset_name, split=self.config.dataset_split
                )
                task_field = self.config.prompt_field
                self._items = [
                    {
                        "task": row.get(task_field, row.get("task", "")),
                        "test_code": row.get("test_code", row.get("tests", "")),
                        "difficulty": row.get("difficulty", "unknown"),
                    }
                    for row in ds
                    if row.get(task_field, row.get("task", ""))
                ]
                if self._items:
                    random.shuffle(self._items)
                    eval_size = max(
                        self.config.eval_size,
                        int(len(self._items) * self.config.eval_split_ratio),
                    )
                    self._eval_items = self._items[:eval_size]
                    self._items = self._items[eval_size:]
                    logger.info(
                        "Loaded %d train / %d eval items from '%s'",
                        len(self._items),
                        len(self._eval_items),
                        self.config.dataset_name,
                    )
                    return
            except Exception as e:
                logger.warning(
                    "Could not load dataset '%s': %s. Using built-in tasks.",
                    self.config.dataset_name,
                    e,
                )

        # Fallback to built-in tasks
        items = copy.deepcopy(BUILTIN_CODING_TASKS)
        random.shuffle(items)
        split = max(1, len(items) * 85 // 100)
        self._items = items[:split]
        self._eval_items = items[split:]
        logger.info(
            "Using built-in coding tasks: %d train / %d eval items",
            len(self._items),
            len(self._eval_items),
        )

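    # Expected row shape for a custom HuggingFace dataset (illustrative; the
    # field names follow the fallbacks used in setup() above):
    #
    #     {"task": "Write a function ...",
    #      "test_code": "from solution import ...\nassert ...",
    #      "difficulty": "easy"}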
    # ═══════════════════════════════════════════════════════════════════
    # 2. get_next_item
    # ═══════════════════════════════════════════════════════════════════

    async def get_next_item(self) -> dict:
        """Return the next coding task, cycling through the dataset."""
        if not self._items:
            raise RuntimeError("Dataset is empty. Did you call setup()?")
        item = self._items[self._index % len(self._items)]
        self._index += 1
        return item

    # ═══════════════════════════════════════════════════════════════════
    # 3. format_prompt
    # ═══════════════════════════════════════════════════════════════════

    def format_prompt(self, item: dict) -> str:
        """Format the coding task as a user prompt."""
        prompt = (
            "Solve the following coding task.\n\n"
            f"## Task\n{item['task']}\n\n"
        )
        if item.get("test_code"):
            prompt += (
                "## Tests\nThe following test code will be used to verify your solution:\n"
                f"```python\n{item['test_code'].rstrip()}\n```\n\n"
            )
        prompt += (
            "## Instructions\n"
            "1. Write your solution to `solution.py`\n"
            "2. Write the test code to `test_solution.py`\n"
            "3. Run `python test_solution.py` to verify\n"
            "4. Fix any failures and re-run until all tests pass\n"
        )
        return prompt

    # ═══════════════════════════════════════════════════════════════════
    # 4. compute_reward
    # ═══════════════════════════════════════════════════════════════════

    async def compute_reward(
        self,
        item: dict,
        result: AgentResult,
        ctx: ToolContext,
    ) -> float:
        """
        Multi-signal reward:
          - correctness (0.7): Did the tests pass?
          - efficiency (0.15): Fewer turns = better
          - tool_usage (0.15): Did the agent actually write + run code?
        """
        cfg = self.config

        # ---- Signal 1: Test correctness ----
        # Check if test_solution.py exists and passes in the agent's sandbox
        correctness = 0.0
        try:
            test_result = ctx.terminal("python test_solution.py 2>&1", timeout=30)
            output = test_result.get("output", "")
            exit_code = test_result.get("exit_code", 1)
            if exit_code == 0 and "passed" in output.lower():
                correctness = 1.0
            elif exit_code == 0:
                correctness = 0.8  # Ran without error but no explicit "passed"
            elif "assert" in output.lower() and "error" in output.lower():
                correctness = 0.2  # Partial — code runs but assertions fail
            else:
                correctness = 0.1  # Code errors out entirely
        except Exception as e:
            logger.debug("Test execution failed in reward: %s", e)
            correctness = 0.0

        # ---- Signal 2: Efficiency ----
        max_turns = cfg.max_agent_turns
        turns_used = result.turns_used
        if turns_used <= 3:
            efficiency = 1.0
        elif turns_used <= max_turns // 2:
            efficiency = 0.8
        elif turns_used <= max_turns * 3 // 4:
            efficiency = 0.5
        else:
            efficiency = 0.2

        # ---- Signal 3: Tool usage ----
        tools_used = set()
        for msg in result.messages:
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                for tc in msg["tool_calls"]:
                    fn = tc.get("function", {}) if isinstance(tc, dict) else {}
                    name = fn.get("name", "")
                    if name:
                        tools_used.add(name)

        # Good: used both terminal and file tools
        if "terminal" in tools_used and ("write_file" in tools_used or "patch" in tools_used):
            tool_usage = 1.0
        elif "terminal" in tools_used:
            tool_usage = 0.6
        elif tools_used:
            tool_usage = 0.3
        else:
            tool_usage = 0.0

        # ---- Combine ----
        reward = (
            cfg.correctness_weight * correctness
            + cfg.efficiency_weight * efficiency
            + cfg.tool_usage_weight * tool_usage
        )
        reward = min(1.0, max(0.0, reward))

        # Track metrics
        self._reward_buffer.append(reward)
        self._correctness_buffer.append(correctness)
        self._efficiency_buffer.append(efficiency)
        self._tool_usage_buffer.append(tool_usage)

        logger.debug(
            "Reward: correctness=%.2f, efficiency=%.2f, tool_usage=%.2f → %.3f",
            correctness,
            efficiency,
            tool_usage,
            reward,
        )
        return reward

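    # Worked example with the default weights (illustrative): correctness=1.0,
    # efficiency=0.8, tool_usage=1.0 gives
    #
    #     0.7 * 1.0 + 0.15 * 0.8 + 0.15 * 1.0 = 0.97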
    # ═══════════════════════════════════════════════════════════════════
    # 5. collect_trajectories — OPD pipeline
    # ═══════════════════════════════════════════════════════════════════

    async def collect_trajectories(
        self, item: Item
    ) -> Tuple[
        Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
        List[Item],
    ]:
        """
        Override collect_trajectories to add the OPD pipeline.

        1. Run standard rollouts via super() → ScoredDataGroup with tokens/masks/scores
        2. For each rollout, extract hints from next-state signals
        3. Score student tokens under enhanced (hint-augmented) distribution
        4. Add distill_token_ids / distill_logprobs to the ScoredDataGroup
        """
        # Step 1: Run standard rollouts
        scored_group, backlog = await super().collect_trajectories(item)

        # Step 2: OPD pipeline (only if enabled and a vLLM-backed managed
        # server is available)
        if (
            self.config.opd_enabled
            and scored_group is not None
            and isinstance(scored_group, dict)
            and self._use_managed_server()
        ):
            await self._apply_opd_pipeline(scored_group)

        return scored_group, backlog

    async def _apply_opd_pipeline(self, group: ScoredDataGroup) -> None:
        """
        Apply on-policy distillation to each rollout in the group.

        For each rollout's messages:
        1. Find (assistant, next_state) turn pairs
        2. Extract hints via LLM judge with majority voting
        3. Build enhanced prompt (original + hint)
        4. Score student tokens under enhanced distribution via get_logprobs
        5. Add distill_token_ids / distill_logprobs to the group
        """
        messages_list = group.get("messages", [])
        tokens_list = group.get("tokens", [])

        if not messages_list or not tokens_list:
            logger.debug("OPD: No messages or tokens to process")
            return

        all_distill_token_ids: List[Optional[List[List[int]]]] = []
        all_distill_logprobs: List[Optional[List[List[float]]]] = []

        for seq_idx, (messages, student_tokens) in enumerate(
            zip(messages_list, tokens_list)
        ):
            try:
                distill_ids, distill_lps = await self._opd_for_sequence(
                    messages, student_tokens
                )
                all_distill_token_ids.append(distill_ids)
                all_distill_logprobs.append(distill_lps)
            except Exception as e:
                logger.warning(
                    "OPD failed for sequence %d: %s", seq_idx, e
                )
                all_distill_token_ids.append(None)
                all_distill_logprobs.append(None)

        # Only set distill fields if at least one sequence succeeded.
        # Count successes BEFORE padding, so the log below is accurate.
        n_succeeded = sum(1 for d in all_distill_token_ids if d is not None)
        if n_succeeded:
            # Replace None entries with zero-padded arrays matching token length.
            # Build fresh inner lists (no shared references), so mutating one
            # row downstream cannot corrupt the others.
            for i in range(len(all_distill_token_ids)):
                if all_distill_token_ids[i] is None and i < len(tokens_list):
                    seq_len = len(tokens_list[i])
                    k = self.config.distill_topk
                    all_distill_token_ids[i] = [[0] * k for _ in range(seq_len)]
                    all_distill_logprobs[i] = [[0.0] * k for _ in range(seq_len)]

            group["distill_token_ids"] = all_distill_token_ids
            group["distill_logprobs"] = all_distill_logprobs
            logger.info(
                "OPD: Set distill fields on %d/%d sequences",
                n_succeeded,
                len(all_distill_token_ids),
            )

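    # Shape contract (illustrative): for a rollout of seq_len tokens and
    # distill_topk = k, each sequence's distill arrays look like
    #
    #     distill_token_ids[i] → [[t1, ..., tk], ...]  # seq_len rows of k ids
    #     distill_logprobs[i]  → [[l1, ..., lk], ...]  # seq_len rows of k logprobs
    #
    # Rows stay all-zero (neutral) until an assistant turn covering that
    # position is scored by the teacher.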
    async def _opd_for_sequence(
        self, messages: List[Dict], student_tokens: List[int]
    ) -> Tuple[List[List[int]], List[List[float]]]:
        """
        Run OPD for a single rollout sequence.

        1. Walk conversation to find (assistant, next_state) pairs
        2. Extract hints from next-state signals
        3. For each hint-augmented turn, score student tokens via get_logprobs
        4. Merge per-turn teacher logprobs into a full-sequence distill array

        Returns:
            (distill_token_ids, distill_logprobs) each of shape [seq_len][top_k]
        """
        k = self.config.distill_topk
        seq_len = len(student_tokens)

        # Initialize with zeros (no distill info = neutral)
        distill_token_ids: List[List[int]] = [[0] * k for _ in range(seq_len)]
        distill_logprobs: List[List[float]] = [[0.0] * k for _ in range(seq_len)]

        # Find (assistant, next_state) turn pairs
        turn_pairs = self._extract_turn_pairs(messages)
        if not turn_pairs:
            return distill_token_ids, distill_logprobs

        hints_extracted = 0
        turns_scored = 0

        for pair in turn_pairs:
            try:
                hint = await self._extract_hint(
                    pair["assistant_text"],
                    pair["next_state_text"],
                    pair["next_state_role"],
                )
                if not hint:
                    continue

                hints_extracted += 1

                # Build enhanced prompt with hint
                enhanced_messages = _append_hint_to_messages(
                    pair["context_messages"], hint
                )

                # Tokenize the enhanced prompt
                if not self.tokenizer:
                    logger.warning("OPD: No tokenizer available, skipping scoring")
                    continue

                enhanced_prompt = self.tokenizer.apply_chat_template(
                    enhanced_messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )

                # Tokenize the assistant response to score
                response_text = pair["assistant_text"]
                enhanced_full_text = enhanced_prompt + response_text
                enhanced_ids = self.tokenizer(
                    enhanced_full_text, add_special_tokens=False
                )["input_ids"]

                response_ids = self.tokenizer(
                    response_text, add_special_tokens=False
                )["input_ids"]
                response_len = len(response_ids)

                if response_len == 0:
                    continue

                # Score via get_logprobs — teacher scoring the student's tokens
                # under the enhanced (hint-augmented) distribution
                try:
                    logprob_result = await self.server.get_logprobs(
                        input_ids=enhanced_ids,
                        top_k=k,
                        split="eval",  # Use eval semaphore to not block training
                    )
                except Exception as e:
                    logger.debug("get_logprobs failed: %s", e)
                    continue

                teacher_topk_ids = logprob_result.get("prompt_topk_token_ids", [])
                teacher_topk_lps = logprob_result.get("prompt_topk_logprobs", [])

                if not teacher_topk_ids:
                    continue

                # Extract only the response positions (last response_len entries)
                if len(teacher_topk_ids) >= response_len:
                    resp_topk_ids = teacher_topk_ids[-response_len:]
                    resp_topk_lps = teacher_topk_lps[-response_len:]
                else:
                    # Pad from the left if the response was shorter than expected
                    pad_len = response_len - len(teacher_topk_ids)
                    resp_topk_ids = [[0] * k for _ in range(pad_len)] + teacher_topk_ids
                    resp_topk_lps = [[0.0] * k for _ in range(pad_len)] + teacher_topk_lps

                # Map these back to the student's full sequence positions:
                # find where this assistant turn's tokens appear in the full sequence
                turn_start = self._find_token_span(student_tokens, response_ids)
                if turn_start is not None:
                    for j in range(min(response_len, seq_len - turn_start)):
                        pos = turn_start + j
                        if pos < seq_len and j < len(resp_topk_ids):
                            # Truncate to k entries, then pad each list
                            # independently so ids and lps stay aligned
                            ids = resp_topk_ids[j][:k]
                            lps = resp_topk_lps[j][:k]
                            ids = ids + [0] * (k - len(ids))
                            lps = lps + [0.0] * (k - len(lps))
                            distill_token_ids[pos] = ids
                            distill_logprobs[pos] = lps
                    turns_scored += 1

            except Exception as e:
                logger.debug("OPD turn processing failed: %s", e)
                continue

        # Track OPD metrics
        self._hints_extracted_buffer.append(hints_extracted)
        self._opd_turns_scored_buffer.append(turns_scored)

        logger.debug(
            "OPD sequence: %d turn pairs, %d hints extracted, %d turns scored",
            len(turn_pairs),
            hints_extracted,
            turns_scored,
        )
        return distill_token_ids, distill_logprobs

    def _extract_turn_pairs(
        self, messages: List[Dict]
    ) -> List[Dict[str, Any]]:
        """
        Walk conversation messages to find (assistant, next_state) pairs.

        A "turn pair" is an assistant message with content (the response)
        followed by one or more tool results or a user reply (the next state).

        Returns list of dicts:
          {
            "context_messages": messages up to (not including) the assistant turn,
            "assistant_text": the assistant's response text,
            "next_state_text": the next state content (tool result or user reply),
            "next_state_role": "tool" or "user",
          }
        """
        pairs = []
        i = 0
        while i < len(messages):
            msg = messages[i]
            if msg.get("role") == "assistant" and msg.get("content"):
                # Found an assistant message with content
                assistant_text = msg["content"]
                context = messages[:i]  # Everything before this turn

                # Look ahead for next state: collect consecutive tool results;
                # a user reply also counts as a next state and ends the scan,
                # while any other role stops it
                j = i + 1
                next_states = []
                while j < len(messages):
                    next_msg = messages[j]
                    if next_msg.get("role") == "tool":
                        next_states.append(next_msg)
                        j += 1
                    elif next_msg.get("role") == "user":
                        next_states.append(next_msg)
                        break
                    else:
                        break

                if next_states:
                    # Combine all next-state content
                    next_text_parts = []
                    next_role = next_states[0].get("role", "tool")
                    for ns in next_states:
                        content = ns.get("content", "") or ""
                        if not isinstance(content, str):
                            content = str(content)  # Tolerate structured content
                        if content:
                            # Truncate very long tool outputs
                            max_chars = self.config.hint_max_next_state_chars
                            if len(content) > max_chars:
                                content = content[:max_chars] + "\n...[truncated]"
                            next_text_parts.append(content)

                    next_text = "\n---\n".join(next_text_parts)
                    if next_text.strip():
                        pairs.append(
                            {
                                "context_messages": context,
                                "assistant_text": assistant_text,
                                "next_state_text": next_text,
                                "next_state_role": next_role,
                            }
                        )
            i += 1
        return pairs

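    # Illustrative walk: for the message sequence
    #
    #     [user("task"), assistant("run tests", tool_calls=[...]),
    #      tool("AssertionError: ..."), assistant("All tests pass now.")]
    #
    # one pair is produced: assistant_text="run tests",
    # next_state_text="AssertionError: ...", next_state_role="tool", and
    # context_messages=[user("task")]. The final assistant turn has no
    # following next state, so it yields no pair.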
    async def _extract_hint(
        self,
        assistant_text: str,
        next_state_text: str,
        next_state_role: str,
    ) -> Optional[str]:
        """
        Extract a hindsight hint from a next-state signal using a majority-voted LLM judge.

        Returns the hint string if the judges vote positively, None otherwise.
        """
        judge_messages = _build_hint_judge_messages(
            response_text=assistant_text,
            next_state_text=next_state_text,
            next_state_role=next_state_role,
        )

        # Majority voting across multiple judge queries
        votes = []
        tasks = []
        for _ in range(self.config.prm_votes):
            tasks.append(
                self.server.chat_completion(
                    messages=judge_messages,
                    n=1,
                    max_tokens=500,
                    temperature=0.7,
                    split="eval",
                )
            )

        results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in results:
            if isinstance(result, Exception):
                logger.debug("Hint judge call failed: %s", result)
                votes.append({"score": None, "hint": ""})
                continue
            try:
                text = result.choices[0].message.content or ""
                score, hint = _parse_hint_result(text)
                votes.append({"score": score, "hint": hint})
            except Exception as e:
                logger.debug("Hint parse failed: %s", e)
                votes.append({"score": None, "hint": ""})

        selected = _select_best_hint(votes)
        if selected is None:
            return None
        return selected["hint"]

    @staticmethod
    def _find_token_span(
        full_tokens: List[int], sub_tokens: List[int]
    ) -> Optional[int]:
        """
        Find where sub_tokens appears in full_tokens.
        Returns the start index, or None if not found.

        Uses a sliding window search. For long sequences, searches
        from the end since assistant responses are typically at the end.
        """
        if not sub_tokens or not full_tokens:
            return None
        sub_len = len(sub_tokens)
        full_len = len(full_tokens)
        if sub_len > full_len:
            return None

        # Search backwards (assistant responses are usually near the end)
        for i in range(full_len - sub_len, -1, -1):
            if full_tokens[i : i + sub_len] == sub_tokens:
                return i
        return None

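    # Example (illustrative):
    #
    #     _find_token_span([5, 1, 2, 3, 1, 2], [1, 2])  # → 4
    #
    # The backwards scan prefers the LAST occurrence, matching the assumption
    # that the scored assistant turn sits near the end of the rollout.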
    # ═══════════════════════════════════════════════════════════════════
    # 6. evaluate
    # ═══════════════════════════════════════════════════════════════════

    async def evaluate(self, *args, **kwargs) -> None:
        """
        Evaluate on held-out coding tasks using the full agent loop.
        No OPD during eval — just standard agentic evaluation.
        """
        if not self._eval_items:
            logger.warning("No eval items available.")
            return

        eval_size = min(self.config.eval_size, len(self._eval_items))
        eval_items = self._eval_items[:eval_size]

        logger.info("Running eval on %d coding tasks...", len(eval_items))
        start_time = time.time()
        samples = []

        tools, valid_names = self._resolve_tools_for_group()

        for i, item in enumerate(eval_items):
            task_id = str(uuid.uuid4())
            logger.info(
                "Eval [%d/%d]: %s...", i + 1, len(eval_items), item["task"][:60]
            )

            try:
                messages: List[Dict[str, Any]] = []
                if self.config.system_prompt:
                    messages.append(
                        {"role": "system", "content": self.config.system_prompt}
                    )
                messages.append(
                    {"role": "user", "content": self.format_prompt(item)}
                )

                agent = HermesAgentLoop(
                    server=self.server,
                    tool_schemas=tools,
                    valid_tool_names=valid_names,
                    max_turns=self.config.max_agent_turns,
                    task_id=task_id,
                    temperature=0.0,
                    max_tokens=self.config.max_token_length,
                    extra_body=self.config.extra_body,
                    budget_config=self.config.build_budget_config(),
                )
                result = await agent.run(messages)

                # Compute reward; snapshot the buffer length first so the eval
                # rollout can be rolled back without polluting training metrics
                buf_len = len(self._correctness_buffer)
                ctx = ToolContext(task_id)
                try:
                    reward = await self.compute_reward(item, result, ctx)
                finally:
                    ctx.cleanup()

                # Extract correctness, then roll back the training buffers that
                # compute_reward appended to. The OPD buffers are only touched
                # by _opd_for_sequence, which never runs during eval, so they
                # need no rollback here (comparing their lengths against the
                # reward-buffer baseline could pop legitimate training data).
                correctness = (
                    self._correctness_buffer[buf_len]
                    if len(self._correctness_buffer) > buf_len
                    else 0.0
                )
                for buf in (
                    self._reward_buffer,
                    self._correctness_buffer,
                    self._efficiency_buffer,
                    self._tool_usage_buffer,
                ):
                    if len(buf) > buf_len:
                        buf.pop()

                # Extract final response
                final_response = ""
                for msg in reversed(result.messages):
                    if msg.get("role") == "assistant" and msg.get("content"):
                        final_response = msg["content"]
                        break

                samples.append(
                    {
                        "prompt": item["task"][:200],
                        "response": final_response[:500],
                        "correctness": correctness,
                        "reward": reward,
                        "turns": result.turns_used,
                    }
                )

                logger.info(
                    "  → correctness=%.2f, reward=%.3f, turns=%d",
                    correctness,
                    reward,
                    result.turns_used,
                )

            except Exception as e:
                logger.error("Eval error: %s", e)
                samples.append(
                    {
                        "prompt": item["task"][:200],
                        "response": f"ERROR: {e}",
                        "correctness": 0.0,
                        "reward": 0.0,
                        "turns": 0,
                    }
                )

        end_time = time.time()

        correctness_scores = [s["correctness"] for s in samples]
        rewards = [s["reward"] for s in samples]
        n = len(samples)

        eval_metrics = {
            "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
            "eval/mean_reward": sum(rewards) / n if n else 0.0,
            "eval/pass_rate": (
                sum(1 for c in correctness_scores if c >= 0.8) / n if n else 0.0
            ),
            "eval/n_items": n,
        }

        logger.info(
            "Eval complete — correctness=%.3f, reward=%.3f, pass_rate=%.0f%%",
            eval_metrics["eval/mean_correctness"],
            eval_metrics["eval/mean_reward"],
            eval_metrics["eval/pass_rate"] * 100,
        )

        await self.evaluate_log(
            metrics=eval_metrics,
            samples=samples,
            start_time=start_time,
            end_time=end_time,
        )

    # ═══════════════════════════════════════════════════════════════════
    # 7. wandb_log — custom OPD metrics
    # ═══════════════════════════════════════════════════════════════════

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
        """Log reward breakdown and OPD-specific metrics to wandb."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self._reward_buffer:
            n = len(self._reward_buffer)
            wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
            wandb_metrics["train/mean_correctness"] = (
                sum(self._correctness_buffer) / n
            )
            wandb_metrics["train/mean_efficiency"] = (
                sum(self._efficiency_buffer) / n
            )
            wandb_metrics["train/mean_tool_usage"] = (
                sum(self._tool_usage_buffer) / n
            )
            wandb_metrics["train/pass_rate"] = (
                sum(1 for c in self._correctness_buffer if c >= 0.8) / n
            )
            wandb_metrics["train/total_rollouts"] = n

            self._reward_buffer.clear()
            self._correctness_buffer.clear()
            self._efficiency_buffer.clear()
            self._tool_usage_buffer.clear()

        # OPD-specific metrics
        if self._hints_extracted_buffer:
            n = len(self._hints_extracted_buffer)
            wandb_metrics["opd/mean_hints_per_rollout"] = (
                sum(self._hints_extracted_buffer) / n
            )
            wandb_metrics["opd/mean_turns_scored"] = (
                sum(self._opd_turns_scored_buffer) / n
            )
            wandb_metrics["opd/hint_rate"] = (
                sum(1 for h in self._hints_extracted_buffer if h > 0) / n
            )
            wandb_metrics["opd/total_hints"] = sum(self._hints_extracted_buffer)
            wandb_metrics["opd/total_scored_turns"] = sum(
                self._opd_turns_scored_buffer
            )

            self._hints_extracted_buffer.clear()
            self._opd_turns_scored_buffer.clear()

        await super().wandb_log(wandb_metrics)


# ═══════════════════════════════════════════════════════════════════════
# Entry point
# ═══════════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    AgenticOPDEnv.cli()