# environments/tool_context.py
1 """ 2 ToolContext -- Unrestricted Tool Access for Reward Functions 3 4 A per-rollout handle that gives reward/verification functions direct access to 5 ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means 6 the terminal/browser session is the SAME one the model used during its rollout -- 7 all state (files, processes, browser tabs) is preserved. 8 9 The verifier author decides which tools to use. Nothing is hardcoded or gated. 10 11 Example usage in a compute_reward(): 12 async def compute_reward(self, item, result, ctx): 13 # Run tests in the model's terminal sandbox 14 test = ctx.terminal("pytest -v") 15 if test["exit_code"] == 0: 16 return 1.0 17 18 # Check if a file was created 19 content = ctx.read_file("/workspace/solution.py") 20 if content.get("content"): 21 return 0.5 22 23 return 0.0 24 """ 25 26 import json 27 import logging 28 import os 29 from typing import Any, Dict, List, Optional 30 31 import asyncio 32 import concurrent.futures 33 34 from model_tools import handle_function_call 35 from tools.terminal_tool import cleanup_vm 36 from tools.browser_tool import cleanup_browser 37 38 logger = logging.getLogger(__name__) 39 40 # Thread pool for running sync tool calls that internally use asyncio.run() 41 _tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4) 42 43 44 def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str: 45 """ 46 Run a tool call in a thread pool executor so backends that use asyncio.run() 47 internally (modal, docker, daytona) get a clean event loop. 48 49 If we're already in an async context, executes handle_function_call() in a 50 disposable worker thread and blocks for the result. 51 If not (e.g., called from sync code), runs directly. 
52 """ 53 try: 54 loop = asyncio.get_running_loop() 55 # We're in an async context -- need to run in thread 56 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: 57 future = pool.submit( 58 handle_function_call, tool_name, arguments, task_id 59 ) 60 return future.result(timeout=300) 61 except RuntimeError: 62 # No running event loop -- safe to call directly 63 return handle_function_call(tool_name, arguments, task_id) 64 65 66 class ToolContext: 67 """ 68 Open-ended access to all hermes-agent tools for a specific rollout. 69 70 Passed to compute_reward() so verifiers can use any tool they need: 71 terminal commands, file reads/writes, web searches, browser automation, etc. 72 All calls share the rollout's task_id for session isolation. 73 """ 74 75 def __init__(self, task_id: str): 76 self.task_id = task_id 77 78 # ------------------------------------------------------------------------- 79 # Terminal tools 80 # ------------------------------------------------------------------------- 81 82 def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]: 83 """ 84 Run a command in the rollout's terminal session. 
85 86 Args: 87 command: Shell command to execute 88 timeout: Command timeout in seconds 89 90 Returns: 91 Dict with 'exit_code' (int) and 'output' (str) 92 """ 93 import os 94 backend = os.getenv("TERMINAL_ENV", "local") 95 logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100]) 96 97 # Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock 98 result = _run_tool_in_thread( 99 "terminal", 100 {"command": command, "timeout": timeout}, 101 self.task_id, 102 ) 103 try: 104 return json.loads(result) 105 except json.JSONDecodeError: 106 return {"exit_code": -1, "output": result} 107 108 # ------------------------------------------------------------------------- 109 # File tools 110 # ------------------------------------------------------------------------- 111 112 def read_file(self, path: str) -> Dict[str, Any]: 113 """ 114 Read a file from the rollout's filesystem. 115 116 Args: 117 path: File path to read 118 119 Returns: 120 Dict with file content or error 121 """ 122 result = handle_function_call( 123 "read_file", {"path": path}, task_id=self.task_id 124 ) 125 try: 126 return json.loads(result) 127 except json.JSONDecodeError: 128 return {"error": result} 129 130 def write_file(self, path: str, content: str) -> Dict[str, Any]: 131 """ 132 Write a TEXT file in the rollout's filesystem. 133 134 Uses a shell heredoc under the hood, so this is only safe for text content. 135 For binary files (images, compiled artifacts, etc.), use upload_file() instead. 
136 137 Args: 138 path: File path to write 139 content: Text content to write 140 141 Returns: 142 Dict with success status or error 143 """ 144 result = handle_function_call( 145 "write_file", {"path": path, "content": content}, task_id=self.task_id 146 ) 147 try: 148 return json.loads(result) 149 except json.JSONDecodeError: 150 return {"error": result} 151 152 def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]: 153 """ 154 Upload a local file to the rollout's sandbox (binary-safe). 155 156 Unlike write_file() which passes content through a shell heredoc (text-only), 157 this method base64-encodes the file and decodes it inside the sandbox. 158 Safe for any file type: binaries, images, archives, etc. 159 160 For large files (>1MB), the content is split into chunks to avoid 161 hitting shell command-length limits. 162 163 Args: 164 local_path: Path to a local file on the host 165 remote_path: Destination path inside the sandbox 166 167 Returns: 168 Dict with 'exit_code' and 'output' 169 """ 170 import base64 171 from pathlib import Path as _Path 172 173 local = _Path(local_path) 174 if not local.exists(): 175 return {"exit_code": -1, "output": f"Local file not found: {local_path}"} 176 177 raw = local.read_bytes() 178 b64 = base64.b64encode(raw).decode("ascii") 179 180 # Ensure parent directory exists in the sandbox 181 parent = str(_Path(remote_path).parent) 182 if parent not in (".", "/"): 183 self.terminal(f"mkdir -p {parent}", timeout=10) 184 185 # For small files, single command is fine 186 chunk_size = 60_000 # ~60KB per chunk (well within shell limits) 187 if len(b64) <= chunk_size: 188 result = self.terminal( 189 f"printf '%s' '{b64}' | base64 -d > {remote_path}", 190 timeout=30, 191 ) 192 else: 193 # For larger files, write base64 in chunks then decode 194 tmp_b64 = "/tmp/_hermes_upload.b64" 195 self.terminal(f": > {tmp_b64}", timeout=5) # truncate 196 for i in range(0, len(b64), chunk_size): 197 chunk = b64[i : i + chunk_size] 198 
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15) 199 result = self.terminal( 200 f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}", 201 timeout=30, 202 ) 203 204 return result 205 206 def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]: 207 """ 208 Upload an entire local directory to the rollout's sandbox (binary-safe). 209 210 Recursively uploads all files, preserving directory structure. 211 212 Args: 213 local_dir: Path to a local directory on the host 214 remote_dir: Destination directory inside the sandbox 215 216 Returns: 217 List of results, one per file uploaded 218 """ 219 from pathlib import Path as _Path 220 221 local = _Path(local_dir) 222 if not local.exists() or not local.is_dir(): 223 return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}] 224 225 results = [] 226 for file_path in sorted(local.rglob("*")): 227 if file_path.is_file(): 228 relative = file_path.relative_to(local) 229 target = f"{remote_dir}/{relative}" 230 results.append(self.upload_file(str(file_path), target)) 231 return results 232 233 def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]: 234 """ 235 Download a file from the rollout's sandbox to the host (binary-safe). 236 237 The inverse of upload_file(). Base64-encodes the file inside the sandbox, 238 reads the encoded data through the terminal, and decodes it locally. 239 Safe for any file type. 
240 241 Args: 242 remote_path: Path to the file inside the sandbox 243 local_path: Destination path on the host 244 245 Returns: 246 Dict with 'success' (bool) and 'bytes' (int) or 'error' (str) 247 """ 248 import base64 249 from pathlib import Path as _Path 250 251 # Base64-encode the file inside the sandbox and capture output 252 result = self.terminal( 253 f"base64 {remote_path} 2>/dev/null", 254 timeout=30, 255 ) 256 257 if result.get("exit_code", -1) != 0: 258 return { 259 "success": False, 260 "error": f"Failed to read remote file: {result.get('output', '')}", 261 } 262 263 b64_data = result.get("output", "").strip() 264 if not b64_data: 265 return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"} 266 267 try: 268 raw = base64.b64decode(b64_data) 269 except Exception as e: 270 return {"success": False, "error": f"Base64 decode failed: {e}"} 271 272 # Write to local host filesystem 273 local = _Path(local_path) 274 local.parent.mkdir(parents=True, exist_ok=True) 275 local.write_bytes(raw) 276 277 return {"success": True, "bytes": len(raw)} 278 279 def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]: 280 """ 281 Download a directory from the rollout's sandbox to the host (binary-safe). 282 283 Lists all files in the remote directory, then downloads each one. 284 Preserves directory structure. 
285 286 Args: 287 remote_dir: Path to the directory inside the sandbox 288 local_dir: Destination directory on the host 289 290 Returns: 291 List of results, one per file downloaded 292 """ 293 from pathlib import Path as _Path 294 295 # List files in the remote directory 296 ls_result = self.terminal( 297 f"find {remote_dir} -type f 2>/dev/null", 298 timeout=15, 299 ) 300 301 if ls_result.get("exit_code", -1) != 0: 302 return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}] 303 304 file_list = ls_result.get("output", "").strip() 305 if not file_list: 306 return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}] 307 308 results = [] 309 for remote_file in file_list.splitlines(): 310 remote_file = remote_file.strip() 311 if not remote_file: 312 continue 313 # Compute the relative path to preserve directory structure 314 if remote_file.startswith(remote_dir): 315 relative = remote_file[len(remote_dir):].lstrip("/") 316 else: 317 relative = _Path(remote_file).name 318 local_file = str(_Path(local_dir) / relative) 319 results.append(self.download_file(remote_file, local_file)) 320 321 return results 322 323 def search(self, query: str, path: str = ".") -> Dict[str, Any]: 324 """ 325 Search for text in the rollout's filesystem. 326 327 Args: 328 query: Search query 329 path: Directory to search in 330 331 Returns: 332 Dict with search results 333 """ 334 result = handle_function_call( 335 "search_files", {"pattern": query, "path": path}, task_id=self.task_id 336 ) 337 try: 338 return json.loads(result) 339 except json.JSONDecodeError: 340 return {"error": result} 341 342 # ------------------------------------------------------------------------- 343 # Web tools 344 # ------------------------------------------------------------------------- 345 346 def web_search(self, query: str) -> Dict[str, Any]: 347 """ 348 Search the web. 
349 350 Args: 351 query: Search query 352 353 Returns: 354 Dict with search results 355 """ 356 result = handle_function_call("web_search", {"query": query}) 357 try: 358 return json.loads(result) 359 except json.JSONDecodeError: 360 return {"error": result} 361 362 def web_extract(self, urls: List[str]) -> Dict[str, Any]: 363 """ 364 Extract content from URLs. 365 366 Args: 367 urls: List of URLs to extract content from 368 369 Returns: 370 Dict with extracted content 371 """ 372 result = handle_function_call("web_extract", {"urls": urls}) 373 try: 374 return json.loads(result) 375 except json.JSONDecodeError: 376 return {"error": result} 377 378 # ------------------------------------------------------------------------- 379 # Browser tools 380 # ------------------------------------------------------------------------- 381 382 def browser_navigate(self, url: str) -> Dict[str, Any]: 383 """ 384 Navigate the rollout's browser session to a URL. 385 386 Args: 387 url: URL to navigate to 388 389 Returns: 390 Dict with page snapshot or error 391 """ 392 result = handle_function_call( 393 "browser_navigate", {"url": url}, task_id=self.task_id 394 ) 395 try: 396 return json.loads(result) 397 except json.JSONDecodeError: 398 return {"error": result} 399 400 def browser_snapshot(self) -> Dict[str, Any]: 401 """ 402 Take a snapshot of the current browser page. 403 404 Returns: 405 Dict with page content/accessibility snapshot 406 """ 407 result = handle_function_call( 408 "browser_snapshot", {}, task_id=self.task_id 409 ) 410 try: 411 return json.loads(result) 412 except json.JSONDecodeError: 413 return {"error": result} 414 415 # ------------------------------------------------------------------------- 416 # Generic tool access 417 # ------------------------------------------------------------------------- 418 419 def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str: 420 """ 421 Call any hermes-agent tool by name. 
422 423 This is the generic escape hatch -- if a tool doesn't have a convenience 424 wrapper above, you can call it directly here. 425 426 Args: 427 tool_name: Name of the tool (e.g., "vision_analyze", "skills_list") 428 arguments: Dict of arguments for the tool 429 430 Returns: 431 Raw JSON string result from the tool 432 """ 433 return _run_tool_in_thread(tool_name, arguments, self.task_id) 434 435 # ------------------------------------------------------------------------- 436 # Cleanup 437 # ------------------------------------------------------------------------- 438 439 def cleanup(self): 440 """ 441 Release all resources (terminal VMs, browser sessions, background processes) 442 for this rollout. 443 444 Called automatically by the base environment via try/finally after 445 compute_reward() completes. You generally don't need to call this yourself. 446 """ 447 # Kill any background processes from this rollout (safety net) 448 try: 449 from tools.process_registry import process_registry 450 killed = process_registry.kill_all(task_id=self.task_id) 451 if killed: 452 logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed) 453 except Exception as e: 454 logger.debug("Process cleanup for task %s: %s", self.task_id, e) 455 456 try: 457 cleanup_vm(self.task_id) 458 except Exception as e: 459 logger.debug("VM cleanup for task %s: %s", self.task_id, e) 460 461 # Suppress browser_tool's noisy debug prints during cleanup. 462 # The cleanup still runs (safe), it just doesn't spam the console. 463 _prev_quiet = os.environ.get("HERMES_QUIET") 464 os.environ["HERMES_QUIET"] = "1" 465 try: 466 cleanup_browser(self.task_id) 467 except Exception as e: 468 logger.debug("Browser cleanup for task %s: %s", self.task_id, e) 469 finally: 470 if _prev_quiet is None: 471 os.environ.pop("HERMES_QUIET", None) 472 else: 473 os.environ["HERMES_QUIET"] = _prev_quiet