/ environments / agentic_opd_env.py
agentic_opd_env.py
1 """ 2 AgenticOPDEnv — On-Policy Distillation for Agentic Tool-Calling Tasks 3 ===================================================================== 4 5 First Atropos environment to populate the distill_token_ids / distill_logprobs 6 fields on ScoredDataGroup, enabling on-policy distillation (OPD) training. 7 8 Key idea (from OpenClaw-RL, Princeton 2026): 9 Every time an agent receives a next-state signal (tool result, error trace, 10 test verdict), that signal contains hindsight information about how the 11 agent's PREVIOUS response could have been better. This environment: 12 13 1. Runs standard agentic rollouts (tool-calling agent loop) 14 2. Walks the conversation to find (assistant_turn, next_state) pairs 15 3. Uses an LLM judge to extract "hints" from next-state signals 16 4. Builds an enhanced prompt (original context + hint) 17 5. Scores the student's response tokens under the enhanced distribution 18 using VLLM's prompt_logprobs (via Atropos's get_logprobs API) 19 6. Packages the teacher's top-K predictions as distill_token_ids / 20 distill_logprobs on the ScoredDataGroup 21 22 The trainer then computes per-token advantages: 23 A_t = teacher_logprob(token_t) - student_logprob(token_t) 24 Positive → teacher approves this token (upweight) 25 Negative → teacher disapproves (downweight) 26 27 This gives dense, token-level training signal from every tool interaction, 28 instead of just a scalar reward at the end of the trajectory. 29 30 Task: Coding tasks with test verification (rich next-state signals from 31 test results, error messages, terminal output). Falls back to built-in 32 coding problems if no HuggingFace dataset is configured. 33 34 Requirements: 35 - VLLM backend (server_type: vllm) — needed for prompt logprob scoring 36 - Phase 2 mode (ManagedServer) — needed for token-level tracking 37 38 Usage: 39 # Process mode (offline data generation with OPD) 40 python environments/agentic_opd_env.py process \\ 41 --env.total_steps 10 --env.group_size 2 \\ 42 --env.data_path_to_save_groups output.jsonl \\ 43 --openai.base_url http://localhost:8000/v1 \\ 44 --openai.model_name Qwen/Qwen3-4B 45 46 # Serve mode (connected to Atropos trainer) 47 python environments/agentic_opd_env.py serve \\ 48 --openai.base_url http://localhost:8000/v1 \\ 49 --openai.model_name Qwen/Qwen3-4B 50 51 # Evaluate mode 52 python environments/agentic_opd_env.py evaluate \\ 53 --env.eval_size 10 \\ 54 --openai.base_url http://localhost:8000/v1 \\ 55 --openai.model_name Qwen/Qwen3-4B 56 57 Reference: Wang et al., "OpenClaw-RL: Train Any Agent Simply by Talking" 58 arXiv:2603.10165, March 2026 59 """ 60 61 from __future__ import annotations 62 63 import asyncio 64 import copy 65 import json 66 import logging 67 import os 68 import random 69 import re 70 import sys 71 import time 72 import uuid 73 from pathlib import Path 74 from typing import Any, Dict, List, Optional, Set, Tuple, Union 75 76 from pydantic import Field 77 78 # Ensure hermes-agent root is on path 79 _repo_root = Path(__file__).resolve().parent.parent 80 if str(_repo_root) not in sys.path: 81 sys.path.insert(0, str(_repo_root)) 82 83 from atroposlib.envs.base import ScoredDataGroup, ScoredDataItem 84 from atroposlib.envs.server_handling.server_manager import APIServerConfig 85 from atroposlib.type_definitions import Item 86 87 from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig 88 from environments.agent_loop import AgentResult, HermesAgentLoop 89 from environments.tool_context import ToolContext 90 91 logger = logging.getLogger(__name__) 92 93 94 # ═══════════════════════════════════════════════════════════════════════ 95 # Built-in coding tasks (fallback when no HF dataset is configured) 96 # ═══════════════════════════════════════════════════════════════════════ 97 98 BUILTIN_CODING_TASKS = [ 99 { 100 "task": "Write a Python function `fizzbuzz(n)` that returns a list of strings from 1 to n. " 101 "For multiples of 3 return 'Fizz', for multiples of 5 return 'Buzz', " 102 "for multiples of both return 'FizzBuzz', otherwise the number as a string.", 103 "test_code": ( 104 "from solution import fizzbuzz\n" 105 "assert fizzbuzz(15) == ['1','2','Fizz','4','Buzz','Fizz','7','8','Fizz','Buzz','11','Fizz','13','14','FizzBuzz']\n" 106 "assert fizzbuzz(1) == ['1']\n" 107 "assert fizzbuzz(0) == []\n" 108 "print('All tests passed!')\n" 109 ), 110 "difficulty": "easy", 111 }, 112 { 113 "task": "Write a Python function `is_palindrome(s)` that checks if a string is a palindrome, " 114 "ignoring case and non-alphanumeric characters. Return True or False.", 115 "test_code": ( 116 "from solution import is_palindrome\n" 117 "assert is_palindrome('A man, a plan, a canal: Panama') == True\n" 118 "assert is_palindrome('race a car') == False\n" 119 "assert is_palindrome('') == True\n" 120 "assert is_palindrome('Was it a car or a cat I saw?') == True\n" 121 "print('All tests passed!')\n" 122 ), 123 "difficulty": "easy", 124 }, 125 { 126 "task": "Write a Python function `two_sum(nums, target)` that returns the indices of the two " 127 "numbers in `nums` that add up to `target`. Assume exactly one solution exists. " 128 "Return a list of two indices [i, j] where i < j.", 129 "test_code": ( 130 "from solution import two_sum\n" 131 "assert two_sum([2, 7, 11, 15], 9) == [0, 1]\n" 132 "assert two_sum([3, 2, 4], 6) == [1, 2]\n" 133 "assert two_sum([3, 3], 6) == [0, 1]\n" 134 "print('All tests passed!')\n" 135 ), 136 "difficulty": "easy", 137 }, 138 { 139 "task": "Write a Python function `flatten(lst)` that takes an arbitrarily nested list and " 140 "returns a flat list of all elements. For example, flatten([1, [2, [3, 4], 5]]) " 141 "should return [1, 2, 3, 4, 5].", 142 "test_code": ( 143 "from solution import flatten\n" 144 "assert flatten([1, [2, [3, 4], 5]]) == [1, 2, 3, 4, 5]\n" 145 "assert flatten([]) == []\n" 146 "assert flatten([1, 2, 3]) == [1, 2, 3]\n" 147 "assert flatten([[[[1]]]]) == [1]\n" 148 "assert flatten([1, [2], [[3]], [[[4]]]]) == [1, 2, 3, 4]\n" 149 "print('All tests passed!')\n" 150 ), 151 "difficulty": "medium", 152 }, 153 { 154 "task": "Write a Python function `longest_common_prefix(strs)` that finds the longest " 155 "common prefix string amongst a list of strings. If there is no common prefix, " 156 "return an empty string.", 157 "test_code": ( 158 "from solution import longest_common_prefix\n" 159 "assert longest_common_prefix(['flower', 'flow', 'flight']) == 'fl'\n" 160 "assert longest_common_prefix(['dog', 'racecar', 'car']) == ''\n" 161 "assert longest_common_prefix(['interspecies', 'interstellar', 'interstate']) == 'inters'\n" 162 "assert longest_common_prefix(['a']) == 'a'\n" 163 "assert longest_common_prefix([]) == ''\n" 164 "print('All tests passed!')\n" 165 ), 166 "difficulty": "easy", 167 }, 168 { 169 "task": "Write a Python function `group_anagrams(strs)` that groups anagrams together. " 170 "Return a list of lists, where each inner list contains strings that are anagrams of " 171 "each other. The order of groups and strings within groups does not matter.", 172 "test_code": ( 173 "from solution import group_anagrams\n" 174 "result = group_anagrams(['eat', 'tea', 'tan', 'ate', 'nat', 'bat'])\n" 175 "result_sorted = sorted([sorted(g) for g in result])\n" 176 "assert result_sorted == [['ate', 'eat', 'tea'], ['bat'], ['nat', 'tan']]\n" 177 "assert group_anagrams([]) == []\n" 178 "assert group_anagrams(['a']) == [['a']]\n" 179 "print('All tests passed!')\n" 180 ), 181 "difficulty": "medium", 182 }, 183 { 184 "task": "Write a Python function `valid_parentheses(s)` that determines if a string " 185 "containing just '(', ')', '{', '}', '[' and ']' is valid. A string is valid if " 186 "open brackets are closed by the same type and in the correct order.", 187 "test_code": ( 188 "from solution import valid_parentheses\n" 189 "assert valid_parentheses('()') == True\n" 190 "assert valid_parentheses('()[]{}') == True\n" 191 "assert valid_parentheses('(]') == False\n" 192 "assert valid_parentheses('([)]') == False\n" 193 "assert valid_parentheses('{[]}') == True\n" 194 "assert valid_parentheses('') == True\n" 195 "print('All tests passed!')\n" 196 ), 197 "difficulty": "easy", 198 }, 199 { 200 "task": "Write a Python function `merge_intervals(intervals)` that merges overlapping " 201 "intervals. Each interval is a list [start, end]. Return the merged intervals sorted " 202 "by start time.", 203 "test_code": ( 204 "from solution import merge_intervals\n" 205 "assert merge_intervals([[1,3],[2,6],[8,10],[15,18]]) == [[1,6],[8,10],[15,18]]\n" 206 "assert merge_intervals([[1,4],[4,5]]) == [[1,5]]\n" 207 "assert merge_intervals([[1,4],[0,4]]) == [[0,4]]\n" 208 "assert merge_intervals([]) == []\n" 209 "assert merge_intervals([[1,2]]) == [[1,2]]\n" 210 "print('All tests passed!')\n" 211 ), 212 "difficulty": "medium", 213 }, 214 ] 215 216 217 # ═══════════════════════════════════════════════════════════════════════ 218 # Hint extraction prompts (adapted from OpenClaw-RL) 219 # ═══════════════════════════════════════════════════════════════════════ 220 221 _HINT_JUDGE_SYSTEM = ( 222 "You are a process reward model used for hindsight hint extraction.\n" 223 "You are given:\n" 224 "1) The assistant response at turn t.\n" 225 "2) The next state at turn t+1, along with its **role**.\n\n" 226 "## Understanding the next state's role\n" 227 "- role='user': A reply from the user (follow-up, correction, new request, etc.).\n" 228 "- role='tool': The return value of a tool the assistant invoked. " 229 "This content was NOT available before the assistant's action — " 230 "it exists BECAUSE the assistant called the tool. " 231 "A successful, non-error tool output generally means the assistant's " 232 "action was appropriate; do NOT treat it as information the assistant " 233 "should have already known.\n\n" 234 "Your goal is to decide whether the next state reveals useful hindsight information\n" 235 "that could have helped improve the assistant response at turn t.\n\n" 236 "Output format rules (strict):\n" 237 "- You MUST include exactly one final decision token: \\boxed{1} or \\boxed{-1}.\n" 238 "- If and only if decision is \\boxed{1}, provide a concise, information-dense hint in 1-3 sentences,\n" 239 " wrapped between [HINT_START] and [HINT_END].\n" 240 "- If decision is \\boxed{-1}, do not provide a hint block.\n" 241 "- Hint must be concrete and actionable for improving the previous response." 242 ) 243 244 _BOXED_RE = re.compile(r"\\boxed\{(-?\d+)\}") 245 _HINT_RE = re.compile(r"\[HINT_START\](.*?)\[HINT_END\]", re.DOTALL) 246 247 248 def _build_hint_judge_messages( 249 response_text: str, next_state_text: str, next_state_role: str = "tool" 250 ) -> list[dict]: 251 """Build messages for the hint extraction judge.""" 252 user = ( 253 f"## Assistant response (turn t)\n{response_text}\n\n" 254 f"## Next state (turn t+1) [role: {next_state_role}]\n{next_state_text}\n\n" 255 "Now output your decision and (if positive) the hint in the required format." 256 ) 257 return [ 258 {"role": "system", "content": _HINT_JUDGE_SYSTEM}, 259 {"role": "user", "content": user}, 260 ] 261 262 263 def _parse_hint_result(text: str) -> tuple[int | None, str]: 264 """Parse the judge's boxed decision and hint text.""" 265 boxed = _BOXED_RE.findall(text) 266 score = int(boxed[-1]) if boxed else None 267 if score not in (1, -1): 268 score = None 269 hint_matches = _HINT_RE.findall(text) 270 hint = hint_matches[-1].strip() if hint_matches else "" 271 return score, hint 272 273 274 def _select_best_hint(votes: list[dict]) -> dict | None: 275 """Select the best hint from majority-voted judge results.""" 276 good = [ 277 v 278 for v in votes 279 if v.get("score") == 1 280 and isinstance(v.get("hint"), str) 281 and len(v["hint"].strip()) > 10 282 ] 283 if not good: 284 return None 285 return max(good, key=lambda v: len(v["hint"].strip())) 286 287 288 def _append_hint_to_messages(messages: list[dict], hint: str) -> list[dict]: 289 """Clone messages and append hint to the last user message.""" 290 cloned = copy.deepcopy(messages) 291 if not cloned: 292 return [{"role": "user", "content": f"[user's hint / instruction]\n{hint}"}] 293 294 # Find last user message 295 target_idx = None 296 for i in range(len(cloned) - 1, -1, -1): 297 if cloned[i].get("role") == "user": 298 target_idx = i 299 break 300 if target_idx is None: 301 target_idx = len(cloned) - 1 302 303 content = cloned[target_idx].get("content", "") 304 if isinstance(content, list): 305 content = " ".join( 306 c.get("text", "") if isinstance(c, dict) else str(c) for c in content 307 ) 308 suffix = f"\n\n[user's hint / instruction]\n{hint.strip()}" 309 cloned[target_idx]["content"] = (content + suffix).strip() 310 return cloned 311 312 313 # ═══════════════════════════════════════════════════════════════════════ 314 # Configuration 315 # ═══════════════════════════════════════════════════════════════════════ 316 317 318 class AgenticOPDConfig(HermesAgentEnvConfig): 319 """Configuration for the agentic OPD environment.""" 320 321 # --- OPD settings --- 322 opd_enabled: bool = Field( 323 default=True, 324 description="Enable on-policy distillation pipeline. When disabled, " 325 "the environment behaves like a standard agentic env (no distill fields).", 326 ) 327 distill_topk: int = Field( 328 default=50, 329 description="Number of top-K teacher logprobs per position for distillation.", 330 ) 331 prm_votes: int = Field( 332 default=3, 333 description="Number of independent judge queries for majority-voted hint extraction.", 334 ) 335 hint_max_next_state_chars: int = Field( 336 default=4000, 337 description="Maximum characters of next-state text to include in the hint judge prompt. " 338 "Tool results can be very long — truncating prevents judge context overflow.", 339 ) 340 341 # --- Reward settings --- 342 correctness_weight: float = Field( 343 default=0.7, 344 description="Weight for test pass/fail in reward.", 345 ) 346 efficiency_weight: float = Field( 347 default=0.15, 348 description="Weight for efficiency (fewer turns = better).", 349 ) 350 tool_usage_weight: float = Field( 351 default=0.15, 352 description="Weight for appropriate tool usage signal.", 353 ) 354 355 # --- Dataset --- 356 dataset_name: Optional[str] = Field( 357 default=None, 358 description="HuggingFace dataset with coding tasks. " 359 "Expected fields: 'task' (problem description) and 'test_code' (pytest/assert tests). " 360 "Falls back to built-in tasks if not set or unavailable.", 361 ) 362 363 # --- Eval --- 364 eval_size: int = Field( 365 default=10, 366 description="Number of held-out items for evaluation.", 367 ) 368 eval_split_ratio: float = Field( 369 default=0.15, 370 description="Fraction of dataset to hold out for evaluation.", 371 ) 372 373 374 # ═══════════════════════════════════════════════════════════════════════ 375 # Environment 376 # ═══════════════════════════════════════════════════════════════════════ 377 378 379 class AgenticOPDEnv(HermesAgentBaseEnv): 380 """ 381 RL environment with on-policy distillation from next-state signals. 382 383 Runs coding tasks where the agent writes code and runs tests. 384 Tool results (test pass/fail, error traces) serve as next-state signals 385 for hint extraction and teacher logprob scoring. 386 387 This is the first Atropos environment to populate distill_token_ids 388 and distill_logprobs on ScoredDataGroup for OPD training. 389 """ 390 391 name = "agentic-opd" 392 env_config_cls = AgenticOPDConfig 393 394 # Default toolsets: terminal for running code, file for writing it 395 default_toolsets = ["terminal", "file"] 396 397 @classmethod 398 def config_init(cls) -> Tuple[AgenticOPDConfig, List[APIServerConfig]]: 399 """Default configuration.""" 400 env_config = AgenticOPDConfig( 401 # Toolsets 402 enabled_toolsets=["terminal", "file"], 403 # Agent loop 404 max_agent_turns=15, 405 agent_temperature=1.0, 406 system_prompt=( 407 "You are a skilled Python programmer. When given a coding task:\n" 408 "1. Write the solution to a file called 'solution.py'\n" 409 "2. Write the test code to a file called 'test_solution.py'\n" 410 "3. Run the tests with: python test_solution.py\n" 411 "4. If tests fail, read the error output carefully, fix your code, and re-run\n" 412 "5. Once all tests pass, report success\n\n" 413 "Be efficient — write clean code and fix errors methodically." 414 ), 415 # OPD 416 opd_enabled=True, 417 distill_topk=50, 418 prm_votes=3, 419 # Training 420 group_size=4, 421 total_steps=500, 422 steps_per_eval=50, 423 use_wandb=True, 424 wandb_name="agentic-opd", 425 ) 426 427 server_configs = [ 428 APIServerConfig( 429 base_url="http://localhost:8000/v1", 430 model_name="Qwen/Qwen3-4B", 431 server_type="vllm", 432 ) 433 ] 434 435 return env_config, server_configs 436 437 def __init__(self, *args, **kwargs): 438 super().__init__(*args, **kwargs) 439 self._items: list[dict] = [] 440 self._eval_items: list[dict] = [] 441 self._index: int = 0 442 443 # Metric buffers 444 self._reward_buffer: list[float] = [] 445 self._correctness_buffer: list[float] = [] 446 self._efficiency_buffer: list[float] = [] 447 self._tool_usage_buffer: list[float] = [] 448 self._hints_extracted_buffer: list[int] = [] 449 self._opd_turns_scored_buffer: list[int] = [] 450 451 # ═══════════════════════════════════════════════════════════════════ 452 # 1. setup — load dataset 453 # ═══════════════════════════════════════════════════════════════════ 454 455 async def setup(self) -> None: 456 """Load coding tasks from HuggingFace or use built-in set.""" 457 if self.config.dataset_name: 458 try: 459 from datasets import load_dataset 460 461 logger.info( 462 "Loading dataset '%s'...", self.config.dataset_name 463 ) 464 ds = load_dataset( 465 self.config.dataset_name, split=self.config.dataset_split 466 ) 467 task_field = self.config.prompt_field 468 self._items = [ 469 { 470 "task": row.get(task_field, row.get("task", "")), 471 "test_code": row.get("test_code", row.get("tests", "")), 472 "difficulty": row.get("difficulty", "unknown"), 473 } 474 for row in ds 475 if row.get(task_field, row.get("task", "")) 476 ] 477 if self._items: 478 random.shuffle(self._items) 479 eval_size = max( 480 self.config.eval_size, 481 int(len(self._items) * self.config.eval_split_ratio), 482 ) 483 self._eval_items = self._items[:eval_size] 484 self._items = self._items[eval_size:] 485 logger.info( 486 "Loaded %d train / %d eval items from '%s'", 487 len(self._items), 488 len(self._eval_items), 489 self.config.dataset_name, 490 ) 491 return 492 except Exception as e: 493 logger.warning( 494 "Could not load dataset '%s': %s. Using built-in tasks.", 495 self.config.dataset_name, 496 e, 497 ) 498 499 # Fallback to built-in tasks 500 items = copy.deepcopy(BUILTIN_CODING_TASKS) 501 random.shuffle(items) 502 split = max(1, len(items) * 85 // 100) 503 self._items = items[:split] 504 self._eval_items = items[split:] 505 logger.info( 506 "Using built-in coding tasks: %d train / %d eval items", 507 len(self._items), 508 len(self._eval_items), 509 ) 510 511 # ═══════════════════════════════════════════════════════════════════ 512 # 2. get_next_item 513 # ═══════════════════════════════════════════════════════════════════ 514 515 async def get_next_item(self) -> dict: 516 """Return the next coding task, cycling through the dataset.""" 517 if not self._items: 518 raise RuntimeError("Dataset is empty. Did you call setup()?") 519 item = self._items[self._index % len(self._items)] 520 self._index += 1 521 return item 522 523 # ═══════════════════════════════════════════════════════════════════ 524 # 3. format_prompt 525 # ═══════════════════════════════════════════════════════════════════ 526 527 def format_prompt(self, item: dict) -> str: 528 """Format the coding task as a user prompt.""" 529 prompt = ( 530 f"Solve the following coding task.\n\n" 531 f"## Task\n{item['task']}\n\n" 532 ) 533 if item.get("test_code"): 534 prompt += ( 535 f"## Tests\nThe following test code will be used to verify your solution:\n" 536 f"```python\n{item['test_code']}```\n\n" 537 ) 538 prompt += ( 539 "## Instructions\n" 540 "1. Write your solution to `solution.py`\n" 541 "2. Write the test code to `test_solution.py`\n" 542 "3. Run `python test_solution.py` to verify\n" 543 "4. Fix any failures and re-run until all tests pass\n" 544 ) 545 return prompt 546 547 # ═══════════════════════════════════════════════════════════════════ 548 # 4. compute_reward 549 # ═══════════════════════════════════════════════════════════════════ 550 551 async def compute_reward( 552 self, 553 item: dict, 554 result: AgentResult, 555 ctx: ToolContext, 556 ) -> float: 557 """ 558 Multi-signal reward: 559 - correctness (0.7): Did the tests pass? 560 - efficiency (0.15): Fewer turns = better 561 - tool_usage (0.15): Did the agent actually write + run code? 562 """ 563 cfg = self.config 564 565 # ---- Signal 1: Test correctness ---- 566 # Check if test_solution.py exists and passes in the agent's sandbox 567 correctness = 0.0 568 try: 569 test_result = ctx.terminal("python test_solution.py 2>&1", timeout=30) 570 output = test_result.get("output", "") 571 exit_code = test_result.get("exit_code", 1) 572 if exit_code == 0 and "passed" in output.lower(): 573 correctness = 1.0 574 elif exit_code == 0: 575 correctness = 0.8 # Ran without error but no explicit "passed" 576 elif "assert" in output.lower() and "error" in output.lower(): 577 correctness = 0.2 # Partial — code runs but assertions fail 578 else: 579 correctness = 0.1 # Code errors out entirely 580 except Exception as e: 581 logger.debug("Test execution failed in reward: %s", e) 582 correctness = 0.0 583 584 # ---- Signal 2: Efficiency ---- 585 max_turns = cfg.max_agent_turns 586 turns_used = result.turns_used 587 if turns_used <= 3: 588 efficiency = 1.0 589 elif turns_used <= max_turns // 2: 590 efficiency = 0.8 591 elif turns_used <= max_turns * 3 // 4: 592 efficiency = 0.5 593 else: 594 efficiency = 0.2 595 596 # ---- Signal 3: Tool usage ---- 597 tools_used = set() 598 for msg in result.messages: 599 if msg.get("role") == "assistant" and msg.get("tool_calls"): 600 for tc in msg["tool_calls"]: 601 fn = tc.get("function", {}) if isinstance(tc, dict) else {} 602 name = fn.get("name", "") 603 if name: 604 tools_used.add(name) 605 606 # Good: used both terminal and file tools 607 if "terminal" in tools_used and ("write_file" in tools_used or "patch" in tools_used): 608 tool_usage = 1.0 609 elif "terminal" in tools_used: 610 tool_usage = 0.6 611 elif tools_used: 612 tool_usage = 0.3 613 else: 614 tool_usage = 0.0 615 616 # ---- Combine ---- 617 reward = ( 618 cfg.correctness_weight * correctness 619 + cfg.efficiency_weight * efficiency 620 + cfg.tool_usage_weight * tool_usage 621 ) 622 reward = min(1.0, max(0.0, reward)) 623 624 # Track metrics 625 self._reward_buffer.append(reward) 626 self._correctness_buffer.append(correctness) 627 self._efficiency_buffer.append(efficiency) 628 self._tool_usage_buffer.append(tool_usage) 629 630 logger.debug( 631 "Reward: correctness=%.2f, efficiency=%.2f, tool_usage=%.2f → %.3f", 632 correctness, 633 efficiency, 634 tool_usage, 635 reward, 636 ) 637 return reward 638 639 # ═══════════════════════════════════════════════════════════════════ 640 # 5. collect_trajectories — OPD pipeline 641 # ═══════════════════════════════════════════════════════════════════ 642 643 async def collect_trajectories( 644 self, item: Item 645 ) -> Tuple[ 646 Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]], 647 List[Item], 648 ]: 649 """ 650 Override collect_trajectories to add the OPD pipeline. 651 652 1. Run standard rollouts via super() → ScoredDataGroup with tokens/masks/scores 653 2. For each rollout, extract hints from next-state signals 654 3. Score student tokens under enhanced (hint-augmented) distribution 655 4. Add distill_token_ids / distill_logprobs to the ScoredDataGroup 656 """ 657 # Step 1: Run standard rollouts 658 scored_group, backlog = await super().collect_trajectories(item) 659 660 # Step 2: OPD pipeline (only if enabled and we have VLLM server) 661 if ( 662 self.config.opd_enabled 663 and scored_group is not None 664 and isinstance(scored_group, dict) 665 and self._use_managed_server() 666 ): 667 await self._apply_opd_pipeline(scored_group) 668 669 return scored_group, backlog 670 671 async def _apply_opd_pipeline(self, group: ScoredDataGroup) -> None: 672 """ 673 Apply on-policy distillation to each rollout in the group. 674 675 For each rollout's messages: 676 1. Find (assistant, next_state) turn pairs 677 2. Extract hints via LLM judge with majority voting 678 3. Build enhanced prompt (original + hint) 679 4. Score student tokens under enhanced distribution via get_logprobs 680 5. Add distill_token_ids / distill_logprobs to the group 681 """ 682 messages_list = group.get("messages", []) 683 tokens_list = group.get("tokens", []) 684 685 if not messages_list or not tokens_list: 686 logger.debug("OPD: No messages or tokens to process") 687 return 688 689 all_distill_token_ids: List[Optional[List[List[int]]]] = [] 690 all_distill_logprobs: List[Optional[List[List[float]]]] = [] 691 692 for seq_idx, (messages, student_tokens) in enumerate( 693 zip(messages_list, tokens_list) 694 ): 695 try: 696 distill_ids, distill_lps = await self._opd_for_sequence( 697 messages, student_tokens 698 ) 699 all_distill_token_ids.append(distill_ids) 700 all_distill_logprobs.append(distill_lps) 701 except Exception as e: 702 logger.warning( 703 "OPD failed for sequence %d: %s", seq_idx, e 704 ) 705 all_distill_token_ids.append(None) 706 all_distill_logprobs.append(None) 707 708 # Only set distill fields if at least one sequence succeeded 709 any_succeeded = any(d is not None for d in all_distill_token_ids) 710 if any_succeeded: 711 # Replace None entries with zero-padded arrays matching token length 712 for i in range(len(all_distill_token_ids)): 713 if all_distill_token_ids[i] is None and i < len(tokens_list): 714 seq_len = len(tokens_list[i]) 715 k = self.config.distill_topk 716 all_distill_token_ids[i] = [[0] * k] * seq_len 717 all_distill_logprobs[i] = [[0.0] * k] * seq_len 718 719 group["distill_token_ids"] = all_distill_token_ids 720 group["distill_logprobs"] = all_distill_logprobs 721 logger.info( 722 "OPD: Set distill fields on %d/%d sequences", 723 sum(1 for d in all_distill_token_ids if d is not None), 724 len(all_distill_token_ids), 725 ) 726 727 async def _opd_for_sequence( 728 self, messages: List[Dict], student_tokens: List[int] 729 ) -> Tuple[List[List[int]], List[List[float]]]: 730 """ 731 Run OPD for a single rollout sequence. 732 733 1. Walk conversation to find (assistant, next_state) pairs 734 2. Extract hints from next-state signals 735 3. For each hint-augmented turn, score student tokens via get_logprobs 736 4. Merge per-turn teacher logprobs into a full-sequence distill array 737 738 Returns: 739 (distill_token_ids, distill_logprobs) each of shape [seq_len][top_k] 740 """ 741 k = self.config.distill_topk 742 seq_len = len(student_tokens) 743 744 # Initialize with zeros (no distill info = neutral) 745 distill_token_ids: List[List[int]] = [[0] * k for _ in range(seq_len)] 746 distill_logprobs: List[List[float]] = [[0.0] * k for _ in range(seq_len)] 747 748 # Find (assistant, next_state) turn pairs 749 turn_pairs = self._extract_turn_pairs(messages) 750 if not turn_pairs: 751 return distill_token_ids, distill_logprobs 752 753 hints_extracted = 0 754 turns_scored = 0 755 756 for pair in turn_pairs: 757 try: 758 hint = await self._extract_hint( 759 pair["assistant_text"], 760 pair["next_state_text"], 761 pair["next_state_role"], 762 ) 763 if not hint: 764 continue 765 766 hints_extracted += 1 767 768 # Build enhanced prompt with hint 769 enhanced_messages = _append_hint_to_messages( 770 pair["context_messages"], hint 771 ) 772 773 # Tokenize the enhanced prompt 774 if not self.tokenizer: 775 logger.warning("OPD: No tokenizer available, skipping scoring") 776 continue 777 778 enhanced_prompt = self.tokenizer.apply_chat_template( 779 enhanced_messages, 780 tokenize=False, 781 add_generation_prompt=True, 782 ) 783 784 # Tokenize the assistant response to score 785 response_text = pair["assistant_text"] 786 enhanced_full_text = enhanced_prompt + response_text 787 enhanced_ids = self.tokenizer( 788 enhanced_full_text, add_special_tokens=False 789 )["input_ids"] 790 791 response_ids = self.tokenizer( 792 response_text, add_special_tokens=False 793 )["input_ids"] 794 response_len = len(response_ids) 795 796 if response_len == 0: 797 continue 798 799 # Score via get_logprobs — teacher scoring the student's tokens 800 # under the enhanced (hint-augmented) distribution 801 try: 802 logprob_result = await self.server.get_logprobs( 803 input_ids=enhanced_ids, 804 top_k=k, 805 split="eval", # Use eval semaphore to not block training 806 ) 807 except Exception as e: 808 logger.debug("get_logprobs failed: %s", e) 809 continue 810 811 teacher_topk_ids = logprob_result.get("prompt_topk_token_ids", []) 812 teacher_topk_lps = logprob_result.get("prompt_topk_logprobs", []) 813 814 if not teacher_topk_ids: 815 continue 816 817 # Extract only the response positions (last response_len entries) 818 if len(teacher_topk_ids) >= response_len: 819 resp_topk_ids = teacher_topk_ids[-response_len:] 820 resp_topk_lps = teacher_topk_lps[-response_len:] 821 else: 822 # Pad from the left if the response was shorter than expected 823 pad_len = response_len - len(teacher_topk_ids) 824 resp_topk_ids = [[0] * k] * pad_len + teacher_topk_ids 825 resp_topk_lps = [[0.0] * k] * pad_len + teacher_topk_lps 826 827 # Map these back to the student's full sequence positions 828 # Find where this assistant turn's tokens appear in the full sequence 829 turn_start = self._find_token_span( 830 student_tokens, response_ids 831 ) 832 if turn_start is not None: 833 for j in range(min(response_len, seq_len - turn_start)): 834 pos = turn_start + j 835 if pos < seq_len and j < len(resp_topk_ids): 836 # Pad/truncate to exactly k entries 837 ids = resp_topk_ids[j][:k] 838 lps = resp_topk_lps[j][:k] 839 while len(ids) < k: 840 ids.append(0) 841 lps.append(0.0) 842 distill_token_ids[pos] = ids 843 distill_logprobs[pos] = lps 844 turns_scored += 1 845 846 except Exception as e: 847 logger.debug("OPD turn processing failed: %s", e) 848 continue 849 850 # Track OPD metrics 851 self._hints_extracted_buffer.append(hints_extracted) 852 self._opd_turns_scored_buffer.append(turns_scored) 853 854 logger.debug( 855 "OPD sequence: %d turn pairs, %d hints extracted, %d turns scored", 856 len(turn_pairs), 857 hints_extracted, 858 turns_scored, 859 ) 860 return distill_token_ids, distill_logprobs 861 862 def _extract_turn_pairs( 863 self, messages: List[Dict] 864 ) -> List[Dict[str, Any]]: 865 """ 866 Walk conversation messages to find (assistant, next_state) pairs. 867 868 A "turn pair" is an assistant message with content (the response) 869 followed by one or more tool results or a user reply (the next state). 870 871 Returns list of dicts: 872 { 873 "context_messages": messages up to (not including) the assistant turn, 874 "assistant_text": the assistant's response text, 875 "next_state_text": the next state content (tool result or user reply), 876 "next_state_role": "tool" or "user", 877 } 878 """ 879 pairs = [] 880 i = 0 881 while i < len(messages): 882 msg = messages[i] 883 if msg.get("role") == "assistant" and msg.get("content"): 884 # Found an assistant message with content 885 assistant_text = msg["content"] 886 context = messages[:i] # Everything before this turn 887 888 # Look ahead for next state 889 j = i + 1 890 # Skip tool_calls-only assistant messages and collect tool results 891 next_states = [] 892 while j < len(messages): 893 next_msg = messages[j] 894 if next_msg.get("role") == "tool": 895 next_states.append(next_msg) 896 j += 1 897 elif next_msg.get("role") == "user": 898 next_states.append(next_msg) 899 break 900 else: 901 break 902 903 if next_states: 904 # Combine all next-state content 905 next_text_parts = [] 906 next_role = next_states[0].get("role", "tool") 907 for ns in next_states: 908 content = ns.get("content", "") 909 if content: 910 # Truncate very long tool outputs 911 max_chars = self.config.hint_max_next_state_chars 912 if len(content) > max_chars: 913 content = content[:max_chars] + "\n...[truncated]" 914 next_text_parts.append(content) 915 916 next_text = "\n---\n".join(next_text_parts) 917 if next_text.strip(): 918 pairs.append( 919 { 920 "context_messages": context, 921 "assistant_text": assistant_text, 922 "next_state_text": next_text, 923 "next_state_role": next_role, 924 } 925 ) 926 i += 1 927 return pairs 928 929 async def _extract_hint( 930 self, 931 assistant_text: str, 932 next_state_text: str, 933 next_state_role: str, 934 ) -> Optional[str]: 935 """ 936 Extract a hindsight hint from a next-state signal using majority-voted LLM judge. 937 938 Returns the hint string if the judge votes positively, None otherwise. 939 """ 940 judge_messages = _build_hint_judge_messages( 941 response_text=assistant_text, 942 next_state_text=next_state_text, 943 next_state_role=next_state_role, 944 ) 945 946 # Majority voting across multiple judge queries 947 votes = [] 948 tasks = [] 949 for _ in range(self.config.prm_votes): 950 tasks.append( 951 self.server.chat_completion( 952 messages=judge_messages, 953 n=1, 954 max_tokens=500, 955 temperature=0.7, 956 split="eval", 957 ) 958 ) 959 960 results = await asyncio.gather(*tasks, return_exceptions=True) 961 962 for result in results: 963 if isinstance(result, Exception): 964 logger.debug("Hint judge call failed: %s", result) 965 votes.append({"score": None, "hint": ""}) 966 continue 967 try: 968 text = result.choices[0].message.content or "" 969 score, hint = _parse_hint_result(text) 970 votes.append({"score": score, "hint": hint}) 971 except Exception as e: 972 logger.debug("Hint parse failed: %s", e) 973 votes.append({"score": None, "hint": ""}) 974 975 selected = _select_best_hint(votes) 976 if selected is None: 977 return None 978 return selected["hint"] 979 980 @staticmethod 981 def _find_token_span( 982 full_tokens: List[int], sub_tokens: List[int] 983 ) -> Optional[int]: 984 """ 985 Find where sub_tokens appears in full_tokens. 986 Returns the start index, or None if not found. 987 988 Uses a sliding window search. For long sequences, searches 989 from the end since assistant responses are typically at the end. 990 """ 991 if not sub_tokens or not full_tokens: 992 return None 993 sub_len = len(sub_tokens) 994 full_len = len(full_tokens) 995 if sub_len > full_len: 996 return None 997 998 # Search backwards (assistant responses are usually near the end) 999 for i in range(full_len - sub_len, -1, -1): 1000 if full_tokens[i : i + sub_len] == sub_tokens: 1001 return i 1002 return None 1003 1004 # ═══════════════════════════════════════════════════════════════════ 1005 # 6. evaluate 1006 # ═══════════════════════════════════════════════════════════════════ 1007 1008 async def evaluate(self, *args, **kwargs) -> None: 1009 """ 1010 Evaluate on held-out coding tasks using the full agent loop. 1011 No OPD during eval — just standard agentic evaluation. 1012 """ 1013 if not self._eval_items: 1014 logger.warning("No eval items available.") 1015 return 1016 1017 eval_size = min(self.config.eval_size, len(self._eval_items)) 1018 eval_items = self._eval_items[:eval_size] 1019 1020 logger.info("Running eval on %d coding tasks...", len(eval_items)) 1021 start_time = time.time() 1022 samples = [] 1023 1024 tools, valid_names = self._resolve_tools_for_group() 1025 1026 for i, item in enumerate(eval_items): 1027 task_id = str(uuid.uuid4()) 1028 logger.info( 1029 "Eval [%d/%d]: %s...", i + 1, len(eval_items), item["task"][:60] 1030 ) 1031 1032 try: 1033 messages: List[Dict[str, Any]] = [] 1034 if self.config.system_prompt: 1035 messages.append( 1036 {"role": "system", "content": self.config.system_prompt} 1037 ) 1038 messages.append( 1039 {"role": "user", "content": self.format_prompt(item)} 1040 ) 1041 1042 agent = HermesAgentLoop( 1043 server=self.server, 1044 tool_schemas=tools, 1045 valid_tool_names=valid_names, 1046 max_turns=self.config.max_agent_turns, 1047 task_id=task_id, 1048 temperature=0.0, 1049 max_tokens=self.config.max_token_length, 1050 extra_body=self.config.extra_body, 1051 budget_config=self.config.build_budget_config(), 1052 ) 1053 result = await agent.run(messages) 1054 1055 # Compute reward (track buffer lengths to rollback eval pollution) 1056 buf_len = len(self._correctness_buffer) 1057 ctx = ToolContext(task_id) 1058 try: 1059 reward = await self.compute_reward(item, result, ctx) 1060 finally: 1061 ctx.cleanup() 1062 1063 # Extract correctness and rollback training buffers 1064 correctness = ( 1065 self._correctness_buffer[buf_len] 1066 if len(self._correctness_buffer) > buf_len 1067 else 0.0 1068 ) 1069 for buf in ( 1070 self._reward_buffer, 1071 self._correctness_buffer, 1072 self._efficiency_buffer, 1073 self._tool_usage_buffer, 1074 ): 1075 if len(buf) > buf_len: 1076 buf.pop() 1077 1078 # Also rollback OPD buffers if they were touched 1079 for buf in ( 1080 self._hints_extracted_buffer, 1081 self._opd_turns_scored_buffer, 1082 ): 1083 if len(buf) > buf_len: 1084 buf.pop() 1085 1086 # Extract final response 1087 final_response = "" 1088 for msg in reversed(result.messages): 1089 if ( 1090 msg.get("role") == "assistant" 1091 and msg.get("content") 1092 and not final_response 1093 ): 1094 final_response = msg["content"] 1095 break 1096 1097 samples.append( 1098 { 1099 "prompt": item["task"][:200], 1100 "response": final_response[:500], 1101 "correctness": correctness, 1102 "reward": reward, 1103 "turns": result.turns_used, 1104 } 1105 ) 1106 1107 logger.info( 1108 " → correctness=%.2f, reward=%.3f, turns=%d", 1109 correctness, 1110 reward, 1111 result.turns_used, 1112 ) 1113 1114 except Exception as e: 1115 logger.error("Eval error: %s", e) 1116 samples.append( 1117 { 1118 "prompt": item["task"][:200], 1119 "response": f"ERROR: {e}", 1120 "correctness": 0.0, 1121 "reward": 0.0, 1122 "turns": 0, 1123 } 1124 ) 1125 1126 end_time = time.time() 1127 1128 correctness_scores = [s["correctness"] for s in samples] 1129 rewards = [s["reward"] for s in samples] 1130 n = len(samples) 1131 1132 eval_metrics = { 1133 "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0, 1134 "eval/mean_reward": sum(rewards) / n if n else 0.0, 1135 "eval/pass_rate": ( 1136 sum(1 for c in correctness_scores if c >= 0.8) / n if n else 0.0 1137 ), 1138 "eval/n_items": n, 1139 } 1140 1141 logger.info( 1142 "Eval complete — correctness=%.3f, reward=%.3f, pass_rate=%.0f%%", 1143 eval_metrics["eval/mean_correctness"], 1144 eval_metrics["eval/mean_reward"], 1145 eval_metrics["eval/pass_rate"] * 100, 1146 ) 1147 1148 await self.evaluate_log( 1149 metrics=eval_metrics, 1150 samples=samples, 1151 start_time=start_time, 1152 end_time=end_time, 1153 ) 1154 1155 # ═══════════════════════════════════════════════════════════════════ 1156 # 7. wandb_log — custom OPD metrics 1157 # ═══════════════════════════════════════════════════════════════════ 1158 1159 async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None: 1160 """Log reward breakdown and OPD-specific metrics to wandb.""" 1161 if wandb_metrics is None: 1162 wandb_metrics = {} 1163 1164 if self._reward_buffer: 1165 n = len(self._reward_buffer) 1166 wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n 1167 wandb_metrics["train/mean_correctness"] = ( 1168 sum(self._correctness_buffer) / n 1169 ) 1170 wandb_metrics["train/mean_efficiency"] = ( 1171 sum(self._efficiency_buffer) / n 1172 ) 1173 wandb_metrics["train/mean_tool_usage"] = ( 1174 sum(self._tool_usage_buffer) / n 1175 ) 1176 wandb_metrics["train/pass_rate"] = ( 1177 sum(1 for c in self._correctness_buffer if c >= 0.8) / n 1178 ) 1179 wandb_metrics["train/total_rollouts"] = n 1180 1181 self._reward_buffer.clear() 1182 self._correctness_buffer.clear() 1183 self._efficiency_buffer.clear() 1184 self._tool_usage_buffer.clear() 1185 1186 # OPD-specific metrics 1187 if self._hints_extracted_buffer: 1188 n = len(self._hints_extracted_buffer) 1189 wandb_metrics["opd/mean_hints_per_rollout"] = ( 1190 sum(self._hints_extracted_buffer) / n 1191 ) 1192 wandb_metrics["opd/mean_turns_scored"] = ( 1193 sum(self._opd_turns_scored_buffer) / n 1194 ) 1195 wandb_metrics["opd/hint_rate"] = ( 1196 sum(1 for h in self._hints_extracted_buffer if h > 0) / n 1197 ) 1198 wandb_metrics["opd/total_hints"] = sum(self._hints_extracted_buffer) 1199 wandb_metrics["opd/total_scored_turns"] = sum( 1200 self._opd_turns_scored_buffer 1201 ) 1202 1203 self._hints_extracted_buffer.clear() 1204 self._opd_turns_scored_buffer.clear() 1205 1206 await super().wandb_log(wandb_metrics) 1207 1208 1209 # ═══════════════════════════════════════════════════════════════════════ 1210 # Entry point 1211 # ═══════════════════════════════════════════════════════════════════════ 1212 1213 if __name__ == "__main__": 1214 AgenticOPDEnv.cli()