/ environments / tool_context.py
tool_context.py
  1  """
  2  ToolContext -- Unrestricted Tool Access for Reward Functions
  3  
  4  A per-rollout handle that gives reward/verification functions direct access to
  5  ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
  6  the terminal/browser session is the SAME one the model used during its rollout --
  7  all state (files, processes, browser tabs) is preserved.
  8  
  9  The verifier author decides which tools to use. Nothing is hardcoded or gated.
 10  
 11  Example usage in a compute_reward():
 12      async def compute_reward(self, item, result, ctx):
 13          # Run tests in the model's terminal sandbox
 14          test = ctx.terminal("pytest -v")
 15          if test["exit_code"] == 0:
 16              return 1.0
 17  
 18          # Check if a file was created
 19          content = ctx.read_file("/workspace/solution.py")
 20          if content.get("content"):
 21              return 0.5
 22  
 23          return 0.0
 24  """
 25  
 26  import json
 27  import logging
 28  import os
 29  from typing import Any, Dict, List, Optional
 30  
 31  import asyncio
 32  import concurrent.futures
 33  
 34  from model_tools import handle_function_call
 35  from tools.terminal_tool import cleanup_vm
 36  from tools.browser_tool import cleanup_browser
 37  
logger = logging.getLogger(__name__)

# Thread pool for running sync tool calls that internally use asyncio.run().
# NOTE(review): nothing visible in this module submits work to this pool --
# _run_tool_in_thread() below creates its own single-worker pool per call.
# It is kept because other modules may import it; confirm before removing.
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
 42  
 43  
 44  def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
 45      """
 46      Run a tool call in a thread pool executor so backends that use asyncio.run()
 47      internally (modal, docker, daytona) get a clean event loop.
 48  
 49      If we're already in an async context, executes handle_function_call() in a
 50      disposable worker thread and blocks for the result.
 51      If not (e.g., called from sync code), runs directly.
 52      """
 53      try:
 54          loop = asyncio.get_running_loop()
 55          # We're in an async context -- need to run in thread
 56          with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
 57              future = pool.submit(
 58                  handle_function_call, tool_name, arguments, task_id
 59              )
 60              return future.result(timeout=300)
 61      except RuntimeError:
 62          # No running event loop -- safe to call directly
 63          return handle_function_call(tool_name, arguments, task_id)
 64  
 65  
 66  class ToolContext:
 67      """
 68      Open-ended access to all hermes-agent tools for a specific rollout.
 69  
 70      Passed to compute_reward() so verifiers can use any tool they need:
 71      terminal commands, file reads/writes, web searches, browser automation, etc.
 72      All calls share the rollout's task_id for session isolation.
 73      """
 74  
 75      def __init__(self, task_id: str):
 76          self.task_id = task_id
 77  
 78      # -------------------------------------------------------------------------
 79      # Terminal tools
 80      # -------------------------------------------------------------------------
 81  
 82      def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
 83          """
 84          Run a command in the rollout's terminal session.
 85  
 86          Args:
 87              command: Shell command to execute
 88              timeout: Command timeout in seconds
 89  
 90          Returns:
 91              Dict with 'exit_code' (int) and 'output' (str)
 92          """
 93          import os
 94          backend = os.getenv("TERMINAL_ENV", "local")
 95          logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
 96  
 97          # Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
 98          result = _run_tool_in_thread(
 99              "terminal",
100              {"command": command, "timeout": timeout},
101              self.task_id,
102          )
103          try:
104              return json.loads(result)
105          except json.JSONDecodeError:
106              return {"exit_code": -1, "output": result}
107  
108      # -------------------------------------------------------------------------
109      # File tools
110      # -------------------------------------------------------------------------
111  
112      def read_file(self, path: str) -> Dict[str, Any]:
113          """
114          Read a file from the rollout's filesystem.
115  
116          Args:
117              path: File path to read
118  
119          Returns:
120              Dict with file content or error
121          """
122          result = handle_function_call(
123              "read_file", {"path": path}, task_id=self.task_id
124          )
125          try:
126              return json.loads(result)
127          except json.JSONDecodeError:
128              return {"error": result}
129  
130      def write_file(self, path: str, content: str) -> Dict[str, Any]:
131          """
132          Write a TEXT file in the rollout's filesystem.
133  
134          Uses a shell heredoc under the hood, so this is only safe for text content.
135          For binary files (images, compiled artifacts, etc.), use upload_file() instead.
136  
137          Args:
138              path: File path to write
139              content: Text content to write
140  
141          Returns:
142              Dict with success status or error
143          """
144          result = handle_function_call(
145              "write_file", {"path": path, "content": content}, task_id=self.task_id
146          )
147          try:
148              return json.loads(result)
149          except json.JSONDecodeError:
150              return {"error": result}
151  
152      def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
153          """
154          Upload a local file to the rollout's sandbox (binary-safe).
155  
156          Unlike write_file() which passes content through a shell heredoc (text-only),
157          this method base64-encodes the file and decodes it inside the sandbox.
158          Safe for any file type: binaries, images, archives, etc.
159  
160          For large files (>1MB), the content is split into chunks to avoid
161          hitting shell command-length limits.
162  
163          Args:
164              local_path: Path to a local file on the host
165              remote_path: Destination path inside the sandbox
166  
167          Returns:
168              Dict with 'exit_code' and 'output'
169          """
170          import base64
171          from pathlib import Path as _Path
172  
173          local = _Path(local_path)
174          if not local.exists():
175              return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
176  
177          raw = local.read_bytes()
178          b64 = base64.b64encode(raw).decode("ascii")
179  
180          # Ensure parent directory exists in the sandbox
181          parent = str(_Path(remote_path).parent)
182          if parent not in (".", "/"):
183              self.terminal(f"mkdir -p {parent}", timeout=10)
184  
185          # For small files, single command is fine
186          chunk_size = 60_000  # ~60KB per chunk (well within shell limits)
187          if len(b64) <= chunk_size:
188              result = self.terminal(
189                  f"printf '%s' '{b64}' | base64 -d > {remote_path}",
190                  timeout=30,
191              )
192          else:
193              # For larger files, write base64 in chunks then decode
194              tmp_b64 = "/tmp/_hermes_upload.b64"
195              self.terminal(f": > {tmp_b64}", timeout=5)  # truncate
196              for i in range(0, len(b64), chunk_size):
197                  chunk = b64[i : i + chunk_size]
198                  self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
199              result = self.terminal(
200                  f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
201                  timeout=30,
202              )
203  
204          return result
205  
206      def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
207          """
208          Upload an entire local directory to the rollout's sandbox (binary-safe).
209  
210          Recursively uploads all files, preserving directory structure.
211  
212          Args:
213              local_dir: Path to a local directory on the host
214              remote_dir: Destination directory inside the sandbox
215  
216          Returns:
217              List of results, one per file uploaded
218          """
219          from pathlib import Path as _Path
220  
221          local = _Path(local_dir)
222          if not local.exists() or not local.is_dir():
223              return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
224  
225          results = []
226          for file_path in sorted(local.rglob("*")):
227              if file_path.is_file():
228                  relative = file_path.relative_to(local)
229                  target = f"{remote_dir}/{relative}"
230                  results.append(self.upload_file(str(file_path), target))
231          return results
232  
233      def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
234          """
235          Download a file from the rollout's sandbox to the host (binary-safe).
236  
237          The inverse of upload_file(). Base64-encodes the file inside the sandbox,
238          reads the encoded data through the terminal, and decodes it locally.
239          Safe for any file type.
240  
241          Args:
242              remote_path: Path to the file inside the sandbox
243              local_path: Destination path on the host
244  
245          Returns:
246              Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
247          """
248          import base64
249          from pathlib import Path as _Path
250  
251          # Base64-encode the file inside the sandbox and capture output
252          result = self.terminal(
253              f"base64 {remote_path} 2>/dev/null",
254              timeout=30,
255          )
256  
257          if result.get("exit_code", -1) != 0:
258              return {
259                  "success": False,
260                  "error": f"Failed to read remote file: {result.get('output', '')}",
261              }
262  
263          b64_data = result.get("output", "").strip()
264          if not b64_data:
265              return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
266  
267          try:
268              raw = base64.b64decode(b64_data)
269          except Exception as e:
270              return {"success": False, "error": f"Base64 decode failed: {e}"}
271  
272          # Write to local host filesystem
273          local = _Path(local_path)
274          local.parent.mkdir(parents=True, exist_ok=True)
275          local.write_bytes(raw)
276  
277          return {"success": True, "bytes": len(raw)}
278  
279      def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
280          """
281          Download a directory from the rollout's sandbox to the host (binary-safe).
282  
283          Lists all files in the remote directory, then downloads each one.
284          Preserves directory structure.
285  
286          Args:
287              remote_dir: Path to the directory inside the sandbox
288              local_dir: Destination directory on the host
289  
290          Returns:
291              List of results, one per file downloaded
292          """
293          from pathlib import Path as _Path
294  
295          # List files in the remote directory
296          ls_result = self.terminal(
297              f"find {remote_dir} -type f 2>/dev/null",
298              timeout=15,
299          )
300  
301          if ls_result.get("exit_code", -1) != 0:
302              return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
303  
304          file_list = ls_result.get("output", "").strip()
305          if not file_list:
306              return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
307  
308          results = []
309          for remote_file in file_list.splitlines():
310              remote_file = remote_file.strip()
311              if not remote_file:
312                  continue
313              # Compute the relative path to preserve directory structure
314              if remote_file.startswith(remote_dir):
315                  relative = remote_file[len(remote_dir):].lstrip("/")
316              else:
317                  relative = _Path(remote_file).name
318              local_file = str(_Path(local_dir) / relative)
319              results.append(self.download_file(remote_file, local_file))
320  
321          return results
322  
323      def search(self, query: str, path: str = ".") -> Dict[str, Any]:
324          """
325          Search for text in the rollout's filesystem.
326  
327          Args:
328              query: Search query
329              path: Directory to search in
330  
331          Returns:
332              Dict with search results
333          """
334          result = handle_function_call(
335              "search_files", {"pattern": query, "path": path}, task_id=self.task_id
336          )
337          try:
338              return json.loads(result)
339          except json.JSONDecodeError:
340              return {"error": result}
341  
342      # -------------------------------------------------------------------------
343      # Web tools
344      # -------------------------------------------------------------------------
345  
346      def web_search(self, query: str) -> Dict[str, Any]:
347          """
348          Search the web.
349  
350          Args:
351              query: Search query
352  
353          Returns:
354              Dict with search results
355          """
356          result = handle_function_call("web_search", {"query": query})
357          try:
358              return json.loads(result)
359          except json.JSONDecodeError:
360              return {"error": result}
361  
362      def web_extract(self, urls: List[str]) -> Dict[str, Any]:
363          """
364          Extract content from URLs.
365  
366          Args:
367              urls: List of URLs to extract content from
368  
369          Returns:
370              Dict with extracted content
371          """
372          result = handle_function_call("web_extract", {"urls": urls})
373          try:
374              return json.loads(result)
375          except json.JSONDecodeError:
376              return {"error": result}
377  
378      # -------------------------------------------------------------------------
379      # Browser tools
380      # -------------------------------------------------------------------------
381  
382      def browser_navigate(self, url: str) -> Dict[str, Any]:
383          """
384          Navigate the rollout's browser session to a URL.
385  
386          Args:
387              url: URL to navigate to
388  
389          Returns:
390              Dict with page snapshot or error
391          """
392          result = handle_function_call(
393              "browser_navigate", {"url": url}, task_id=self.task_id
394          )
395          try:
396              return json.loads(result)
397          except json.JSONDecodeError:
398              return {"error": result}
399  
400      def browser_snapshot(self) -> Dict[str, Any]:
401          """
402          Take a snapshot of the current browser page.
403  
404          Returns:
405              Dict with page content/accessibility snapshot
406          """
407          result = handle_function_call(
408              "browser_snapshot", {}, task_id=self.task_id
409          )
410          try:
411              return json.loads(result)
412          except json.JSONDecodeError:
413              return {"error": result}
414  
415      # -------------------------------------------------------------------------
416      # Generic tool access
417      # -------------------------------------------------------------------------
418  
419      def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
420          """
421          Call any hermes-agent tool by name.
422  
423          This is the generic escape hatch -- if a tool doesn't have a convenience
424          wrapper above, you can call it directly here.
425  
426          Args:
427              tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
428              arguments: Dict of arguments for the tool
429  
430          Returns:
431              Raw JSON string result from the tool
432          """
433          return _run_tool_in_thread(tool_name, arguments, self.task_id)
434  
435      # -------------------------------------------------------------------------
436      # Cleanup
437      # -------------------------------------------------------------------------
438  
439      def cleanup(self):
440          """
441          Release all resources (terminal VMs, browser sessions, background processes)
442          for this rollout.
443  
444          Called automatically by the base environment via try/finally after
445          compute_reward() completes. You generally don't need to call this yourself.
446          """
447          # Kill any background processes from this rollout (safety net)
448          try:
449              from tools.process_registry import process_registry
450              killed = process_registry.kill_all(task_id=self.task_id)
451              if killed:
452                  logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed)
453          except Exception as e:
454              logger.debug("Process cleanup for task %s: %s", self.task_id, e)
455  
456          try:
457              cleanup_vm(self.task_id)
458          except Exception as e:
459              logger.debug("VM cleanup for task %s: %s", self.task_id, e)
460  
461          # Suppress browser_tool's noisy debug prints during cleanup.
462          # The cleanup still runs (safe), it just doesn't spam the console.
463          _prev_quiet = os.environ.get("HERMES_QUIET")
464          os.environ["HERMES_QUIET"] = "1"
465          try:
466              cleanup_browser(self.task_id)
467          except Exception as e:
468              logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
469          finally:
470              if _prev_quiet is None:
471                  os.environ.pop("HERMES_QUIET", None)
472              else:
473                  os.environ["HERMES_QUIET"] = _prev_quiet