# tests/tools/test_code_execution.py
#!/usr/bin/env python3
"""
Tests for the code execution sandbox (programmatic tool calling).

These tests monkeypatch handle_function_call so they don't require API keys
or a running terminal backend. They verify the core sandbox mechanics:
UDS socket lifecycle, hermes_tools generation, timeout enforcement,
output capping, tool call counting, and error propagation.

Run with:  python -m pytest tests/tools/test_code_execution.py -v
   or:     python tests/tools/test_code_execution.py
"""
 14  
 15  import pytest
 16  # pytestmark removed — tests run fine (61 pass, ~99s)
 17  
 18  import json
 19  import os
 20  
 21  os.environ["TERMINAL_ENV"] = "local"
 22  
 23  
 24  @pytest.fixture(autouse=True)
 25  def _force_local_terminal(monkeypatch):
 26      """Re-set TERMINAL_ENV=local before every test.
 27  
 28      The module-level assignment above covers import time, but under xdist
 29      another worker can overwrite os.environ between tests.  monkeypatch
 30      ensures each test starts (and ends) with the correct value.
 31      """
 32      monkeypatch.setenv("TERMINAL_ENV", "local")
 33  import sys
 34  import time
 35  import threading
 36  import unittest
 37  from unittest.mock import patch, MagicMock
 38  
 39  from tools.code_execution_tool import (
 40      SANDBOX_ALLOWED_TOOLS,
 41      execute_code,
 42      generate_hermes_tools_module,
 43      check_sandbox_requirements,
 44      build_execute_code_schema,
 45      EXECUTE_CODE_SCHEMA,
 46      _TOOL_DOC_LINES,
 47      _execute_remote,
 48  )
 49  
 50  
 51  def _mock_handle_function_call(function_name, function_args, task_id=None, user_task=None):
 52      """Mock dispatcher that returns canned responses for each tool."""
 53      if function_name == "terminal":
 54          cmd = function_args.get("command", "")
 55          return json.dumps({"output": f"mock output for: {cmd}", "exit_code": 0})
 56      if function_name == "web_search":
 57          return json.dumps({"results": [{"url": "https://example.com", "title": "Example", "description": "A test result"}]})
 58      if function_name == "read_file":
 59          return json.dumps({"content": "line 1\nline 2\nline 3\n", "total_lines": 3})
 60      if function_name == "write_file":
 61          return json.dumps({"status": "ok", "path": function_args.get("path", "")})
 62      if function_name == "search_files":
 63          return json.dumps({"matches": [{"file": "test.py", "line": 1, "text": "match"}]})
 64      if function_name == "patch":
 65          return json.dumps({"status": "ok", "replacements": 1})
 66      if function_name == "web_extract":
 67          return json.dumps("# Extracted content\nSome text from the page.")
 68      return json.dumps({"error": f"Unknown tool in mock: {function_name}"})
 69  
 70  
 71  class TestSandboxRequirements(unittest.TestCase):
 72      def test_available_on_posix(self):
 73          if sys.platform != "win32":
 74              self.assertTrue(check_sandbox_requirements())
 75  
 76      def test_schema_is_valid(self):
 77          self.assertEqual(EXECUTE_CODE_SCHEMA["name"], "execute_code")
 78          self.assertIn("code", EXECUTE_CODE_SCHEMA["parameters"]["properties"])
 79          self.assertIn("code", EXECUTE_CODE_SCHEMA["parameters"]["required"])
 80  
 81  
class TestHermesToolsGeneration(unittest.TestCase):
    """Tests for generate_hermes_tools_module(), which emits the Python
    source of the `hermes_tools` helper module injected into the sandbox."""

    def test_generates_all_allowed_tools(self):
        # Every allow-listed tool gets a generated wrapper function.
        src = generate_hermes_tools_module(list(SANDBOX_ALLOWED_TOOLS))
        for tool in SANDBOX_ALLOWED_TOOLS:
            self.assertIn(f"def {tool}(", src)

    def test_generates_subset(self):
        # Only the requested tools are emitted; the rest are omitted.
        src = generate_hermes_tools_module(["terminal", "web_search"])
        self.assertIn("def terminal(", src)
        self.assertIn("def web_search(", src)
        self.assertNotIn("def read_file(", src)

    def test_empty_list_generates_nothing(self):
        src = generate_hermes_tools_module([])
        self.assertNotIn("def terminal(", src)
        self.assertIn("def _call(", src)  # infrastructure still present

    def test_non_allowed_tools_ignored(self):
        # Names outside SANDBOX_ALLOWED_TOOLS are silently dropped.
        src = generate_hermes_tools_module(["vision_analyze", "terminal"])
        self.assertIn("def terminal(", src)
        self.assertNotIn("def vision_analyze(", src)

    def test_rpc_infrastructure_present(self):
        # The generated module must carry the UDS RPC plumbing it relies on.
        src = generate_hermes_tools_module(["terminal"])
        self.assertIn("HERMES_RPC_SOCKET", src)
        self.assertIn("AF_UNIX", src)
        self.assertIn("def _connect(", src)
        self.assertIn("def _call(", src)

    def test_convenience_helpers_present(self):
        """Verify json_parse, shell_quote, and retry helpers are generated."""
        src = generate_hermes_tools_module(["terminal"])
        self.assertIn("def json_parse(", src)
        self.assertIn("def shell_quote(", src)
        self.assertIn("def retry(", src)
        self.assertIn("import json, os, socket, shlex, threading, time", src)

    def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
        # The file transport must derive its RPC dir via tempfile rather
        # than a hard-coded "/tmp/hermes_rpc" default.
        src = generate_hermes_tools_module(["terminal"], transport="file")
        self.assertIn("import json, os, shlex, tempfile, threading, time", src)
        self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
        self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)

    def test_uds_transport_serializes_concurrent_calls(self):
        """Regression: UDS _call() must hold a lock across send+recv so that
        concurrent tool calls from multiple threads don't interleave on the
        shared socket and receive each other's responses."""
        src = generate_hermes_tools_module(["terminal"], transport="uds")
        self.assertIn("_call_lock = threading.Lock()", src)
        self.assertIn("with _call_lock:", src)

    def test_file_transport_serializes_seq_allocation(self):
        """Regression: file transport _call() must allocate `_seq` under a
        lock, otherwise concurrent threads can pick the same seq and clobber
        each other's request files."""
        src = generate_hermes_tools_module(["terminal"], transport="file")
        self.assertIn("_seq_lock = threading.Lock()", src)
        self.assertIn("with _seq_lock:", src)
140  
141  
class TestExecuteCodeRemoteTempDir(unittest.TestCase):
    """_execute_remote must build its scratch dir under the backend's
    reported temp dir (e.g. a Termux prefix), never under a hard-coded /tmp."""

    def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
        class FakeEnv:
            # Minimal stand-in for the remote terminal backend; records every
            # command so the test can assert on the paths used.
            def __init__(self):
                self.commands = []

            def get_temp_dir(self):
                # Termux-style prefix — anything other than /tmp works here.
                return "/data/data/com.termux/files/usr/tmp"

            def execute(self, command, cwd=None, timeout=None):
                self.commands.append((command, cwd, timeout))
                if "command -v python3" in command:
                    return {"output": "OK\n"}
                if "python3 script.py" in command:
                    return {"output": "hello\n", "returncode": 0}
                return {"output": ""}

        env = FakeEnv()
        fake_thread = MagicMock()  # keeps the real RPC thread from starting

        with patch("tools.code_execution_tool._load_config", return_value={"timeout": 30, "max_tool_calls": 5}), \
             patch("tools.code_execution_tool._get_or_create_env", return_value=(env, "ssh")), \
             patch("tools.code_execution_tool._ship_file_to_remote"), \
             patch("tools.code_execution_tool.threading.Thread", return_value=fake_thread):
            result = json.loads(_execute_remote("print('hello')", "task-1", ["terminal"]))

        self.assertEqual(result["status"], "success")
        # NOTE(review): relies on _execute_remote's command order —
        # commands[1] is assumed to be the mkdir and the last command the
        # cleanup; confirm against _execute_remote if this ever flakes.
        mkdir_cmd = env.commands[1][0]
        run_cmd = next(cmd for cmd, _, _ in env.commands if "python3 script.py" in cmd)
        cleanup_cmd = env.commands[-1][0]
        self.assertIn("mkdir -p /data/data/com.termux/files/usr/tmp/hermes_exec_", mkdir_cmd)
        self.assertIn("HERMES_RPC_DIR=/data/data/com.termux/files/usr/tmp/hermes_exec_", run_cmd)
        self.assertIn("rm -rf /data/data/com.termux/files/usr/tmp/hermes_exec_", cleanup_cmd)
        self.assertNotIn("mkdir -p /tmp/hermes_exec_", mkdir_cmd)
176  
177  
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestExecuteCode(unittest.TestCase):
    """Integration tests using the mock dispatcher.

    Each test runs real sandbox machinery (execute_code) but replaces the
    tool dispatcher (model_tools.handle_function_call) with a canned mock,
    so no API keys or terminal backend are needed.
    """

    def _run(self, code, enabled_tools=None):
        """Helper: run `code` with handle_function_call mocked at the
        model_tools level and return the parsed JSON result.

        (A previous revision also opened a no-op `with patch(...): pass`
        context around _rpc_server_loop here; it had no effect and has
        been removed.)
        """
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            result = execute_code(
                code=code,
                task_id="test-task",
                enabled_tools=enabled_tools or list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(result)

    def test_basic_print(self):
        """Script that just prints -- no tool calls."""
        result = self._run('print("hello world")')
        self.assertEqual(result["status"], "success")
        self.assertIn("hello world", result["output"])
        self.assertEqual(result["tool_calls_made"], 0)

    def test_repo_root_modules_are_importable(self):
        """Sandboxed scripts can import modules that live at the repo root."""
        result = self._run('import hermes_constants; print(hermes_constants.__file__)')
        self.assertEqual(result["status"], "success")
        self.assertIn("hermes_constants.py", result["output"])

    def test_single_tool_call(self):
        """Script calls terminal and prints the result."""
        code = """
from hermes_tools import terminal
result = terminal("echo hello")
print(result.get("output", ""))
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("mock output for: echo hello", result["output"])
        self.assertEqual(result["tool_calls_made"], 1)

    def test_multi_tool_chain(self):
        """Script calls multiple tools sequentially."""
        code = """
from hermes_tools import terminal, read_file
r1 = terminal("ls")
r2 = read_file("test.py")
print(f"terminal: {r1['output'][:20]}")
print(f"file lines: {r2['total_lines']}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertEqual(result["tool_calls_made"], 2)

    def test_syntax_error(self):
        """Script with a syntax error returns error status."""
        result = self._run("def broken(")
        self.assertEqual(result["status"], "error")
        # The SyntaxError text may land in either field depending on how the
        # sandbox captures it, so check their concatenation.
        self.assertIn("SyntaxError", result.get("error", "") + result.get("output", ""))

    def test_runtime_exception(self):
        """Script with a runtime error returns error status."""
        result = self._run("raise ValueError('test error')")
        self.assertEqual(result["status"], "error")

    def test_concurrent_tool_calls_match_responses(self):
        """Regression for the UDS RPC race: multiple threads inside the
        sandbox calling terminal() concurrently must each receive their own
        response, not another thread's.

        Before the fix, `_sock` and the recv-loop were shared without a
        lock, so responses (written FIFO by the single-threaded server)
        got delivered to whichever client thread happened to win the
        recv() race. That surfaced as each thread seeing another thread's
        output.

        The mock dispatcher sleeps briefly to guarantee the requests
        overlap on the socket.
        """
        code = '''
import threading
from concurrent.futures import ThreadPoolExecutor
from hermes_tools import terminal

N = 10

def call(i):
    r = terminal(f"echo TAG-{i}")
    return i, r.get("output", "")

with ThreadPoolExecutor(max_workers=N) as ex:
    results = list(ex.map(call, range(N)))

mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
if mismatches:
    print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
else:
    print(f"OK {N}/{N}")
'''

        def slow_mock(function_name, function_args, task_id=None, user_task=None):
            import time as _t
            if function_name == "terminal":
                _t.sleep(0.05)  # ensure requests overlap on the socket
                cmd = function_args.get("command", "")
                # Echo semantics: strip leading "echo " and return the rest
                out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
                return json.dumps({"output": out, "exit_code": 0})
            return _mock_handle_function_call(
                function_name, function_args, task_id=task_id, user_task=user_task
            )

        with patch("model_tools.handle_function_call", side_effect=slow_mock):
            raw = execute_code(
                code=code,
                task_id="test-concurrent",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        result = json.loads(raw)
        self.assertEqual(result["status"], "success", msg=result)
        self.assertIn("OK 10/10", result["output"],
                      msg=f"Concurrent tool calls mismatched: {result['output']!r}")

    def test_excluded_tool_returns_error(self):
        """Script calling a tool not in the allow-list gets an error from RPC."""
        code = """
from hermes_tools import terminal
result = terminal("echo hi")
print(result)
"""
        # Only enable web_search -- terminal should be excluded
        result = self._run(code, enabled_tools=["web_search"])
        # terminal won't be in hermes_tools.py, so import fails
        self.assertEqual(result["status"], "error")

    def test_empty_code(self):
        """Empty code string returns an error."""
        result = json.loads(execute_code("", task_id="test"))
        self.assertIn("error", result)

    def test_output_captured(self):
        """Multiple print statements are captured in order."""
        code = """
for i in range(5):
    print(f"line {i}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        for i in range(5):
            self.assertIn(f"line {i}", result["output"])

    def test_stderr_on_error(self):
        """Traceback from stderr is included in the response."""
        code = """
import sys
print("before error")
raise RuntimeError("deliberate crash")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "error")
        self.assertIn("before error", result["output"])
        self.assertIn("RuntimeError", result.get("error", "") + result.get("output", ""))

    def test_timeout_enforcement(self):
        """Script that sleeps too long is killed."""
        code = "import time; time.sleep(999)"
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            # Override config to use a very short timeout
            with patch("tools.code_execution_tool._load_config", return_value={"timeout": 2, "max_tool_calls": 50}):
                result = json.loads(execute_code(
                    code=code,
                    task_id="test-task",
                    enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
                ))
        self.assertEqual(result["status"], "timeout")
        self.assertIn("timed out", result.get("error", ""))
        # The timeout message must also appear in output so the LLM always
        # surfaces it to the user (#10807).
        self.assertIn("timed out", result.get("output", ""))
        self.assertIn("\u23f0", result.get("output", ""))

    def test_web_search_tool(self):
        """Script calls web_search and processes results."""
        code = """
from hermes_tools import web_search
results = web_search("test query")
print(f"Found {len(results.get('results', []))} results")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("Found 1 results", result["output"])

    def test_json_parse_helper(self):
        """json_parse handles control characters that json.loads(strict=True) rejects."""
        code = r"""
from hermes_tools import json_parse
# This JSON has a literal tab character which strict mode rejects
text = '{"body": "line1\tline2\nline3"}'
result = json_parse(text)
print(result["body"])
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("line1", result["output"])

    def test_shell_quote_helper(self):
        """shell_quote properly escapes dangerous characters."""
        code = """
from hermes_tools import shell_quote
# String with backticks, quotes, and special chars
dangerous = '`rm -rf /` && $(whoami) "hello"'
escaped = shell_quote(dangerous)
print(escaped)
# Verify it's wrapped in single quotes with proper escaping
assert "rm -rf" in escaped
assert escaped.startswith("'")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")

    def test_retry_helper_success(self):
        """retry returns on first success."""
        code = """
from hermes_tools import retry
counter = [0]
def flaky():
    counter[0] += 1
    return f"ok on attempt {counter[0]}"
result = retry(flaky)
print(result)
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("ok on attempt 1", result["output"])

    def test_retry_helper_eventual_success(self):
        """retry retries on failure and succeeds eventually."""
        code = """
from hermes_tools import retry
counter = [0]
def flaky():
    counter[0] += 1
    if counter[0] < 3:
        raise ConnectionError(f"fail {counter[0]}")
    return "success"
result = retry(flaky, max_attempts=3, delay=0.01)
print(result)
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("success", result["output"])

    def test_retry_helper_all_fail(self):
        """retry raises the last error when all attempts fail."""
        code = """
from hermes_tools import retry
def always_fail():
    raise ValueError("nope")
try:
    retry(always_fail, max_attempts=2, delay=0.01)
    print("should not reach here")
except ValueError as e:
    print(f"caught: {e}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("caught: nope", result["output"])
447  
448  
class TestStubSchemaDrift(unittest.TestCase):
    """Verify that _TOOL_STUBS in code_execution_tool.py stay in sync with
    the real tool schemas registered in tools/registry.py.

    If a tool gains a new parameter but the sandbox stub isn't updated,
    the LLM will try to use the parameter (it sees it in the system prompt)
    and get a TypeError.  This test catches that drift.
    """

    # Parameters that are internal (injected by the handler, not user-facing)
    _INTERNAL_PARAMS = {"task_id", "user_task"}
    # Parameters intentionally blocked in the sandbox
    _BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete", "watch_patterns"}

    def test_stubs_cover_all_schema_params(self):
        """Every user-facing parameter in the real schema must appear in the
        corresponding _TOOL_STUBS entry."""
        import re
        from tools.code_execution_tool import _TOOL_STUBS

        # Import the registry and trigger tool registration
        from tools.registry import registry
        import tools.file_tools  # noqa: F401 - registers read_file, write_file, patch, search_files
        import tools.web_tools  # noqa: F401 - registers web_search, web_extract

        for tool_name, (func_name, sig, doc, args_expr) in _TOOL_STUBS.items():
            entry = registry._tools.get(tool_name)
            if not entry:
                # Tool might not be registered yet (e.g., terminal uses a
                # different registration path).  Skip gracefully.
                continue

            schema_props = entry.schema.get("parameters", {}).get("properties", {})
            schema_params = set(schema_props.keys()) - self._INTERNAL_PARAMS
            if tool_name == "terminal":
                schema_params -= self._BLOCKED_TERMINAL_PARAMS

            # Extract parameter names from the stub signature string
            # Match word before colon: "pattern: str, target: str = ..."
            stub_params = set(re.findall(r'(\w+)\s*:', sig))

            missing = schema_params - stub_params
            self.assertEqual(
                missing, set(),
                f"Stub for '{tool_name}' is missing parameters that exist in "
                f"the real schema: {missing}. Update _TOOL_STUBS in "
                f"code_execution_tool.py to include them."
            )

    def test_stubs_pass_all_params_to_rpc(self):
        """The args_dict_expr in each stub must include every parameter from
        the signature, so that all params are actually sent over RPC."""
        import re
        from tools.code_execution_tool import _TOOL_STUBS

        for tool_name, (func_name, sig, doc, args_expr) in _TOOL_STUBS.items():
            stub_params = set(re.findall(r'(\w+)\s*:', sig))
            # Check that each param name appears in the args dict expression
            for param in stub_params:
                self.assertIn(
                    f'"{param}"',
                    args_expr,
                    f"Stub for '{tool_name}' has parameter '{param}' in its "
                    f"signature but doesn't pass it in the args dict: {args_expr}"
                )

    def test_search_files_target_uses_current_values(self):
        """search_files stub should use 'content'/'files', not old 'grep'/'find'."""
        from tools.code_execution_tool import _TOOL_STUBS
        _, sig, doc, _ = _TOOL_STUBS["search_files"]
        self.assertIn('"content"', sig,
                      "search_files stub should default target to 'content', not 'grep'")
        self.assertNotIn('"grep"', sig,
                         "search_files stub still uses obsolete 'grep' target value")
        self.assertNotIn('"find"', doc,
                         "search_files stub docstring still uses obsolete 'find' target value")

    def test_generated_module_accepts_all_params(self):
        """The generated hermes_tools.py module should accept all current params
        without TypeError when called with keyword arguments."""
        src = generate_hermes_tools_module(list(SANDBOX_ALLOWED_TOOLS))

        # Compile the generated module to check for syntax errors
        compile(src, "hermes_tools.py", "exec")

        # Verify specific parameter signatures are in the source
        # search_files must accept context, offset, output_mode
        self.assertIn("context", src)
        self.assertIn("offset", src)
        self.assertIn("output_mode", src)

        # patch must accept mode and patch params
        self.assertIn("mode", src)
542  
543  
544  # ---------------------------------------------------------------------------
545  # build_execute_code_schema
546  # ---------------------------------------------------------------------------
547  
class TestBuildExecuteCodeSchema(unittest.TestCase):
    """Tests for build_execute_code_schema — the dynamic schema generator."""

    def test_default_includes_all_tools(self):
        # With no argument the schema should advertise every documented tool.
        schema = build_execute_code_schema()
        desc = schema["description"]
        for name, _ in _TOOL_DOC_LINES:
            self.assertIn(name, desc, f"Default schema should mention '{name}'")

    def test_schema_structure(self):
        schema = build_execute_code_schema()
        self.assertEqual(schema["name"], "execute_code")
        self.assertIn("parameters", schema)
        self.assertIn("code", schema["parameters"]["properties"])
        self.assertEqual(schema["parameters"]["required"], ["code"])

    def test_subset_only_lists_enabled_tools(self):
        # Disabled tools must not be advertised to the model at all.
        enabled = {"terminal", "read_file"}
        schema = build_execute_code_schema(enabled)
        desc = schema["description"]
        self.assertIn("terminal(", desc)
        self.assertIn("read_file(", desc)
        self.assertNotIn("web_search(", desc)
        self.assertNotIn("web_extract(", desc)
        self.assertNotIn("write_file(", desc)

    def test_single_tool(self):
        schema = build_execute_code_schema({"terminal"})
        desc = schema["description"]
        self.assertIn("terminal(", desc)
        self.assertNotIn("web_search(", desc)

    def test_import_examples_prefer_web_search_and_terminal(self):
        enabled = {"web_search", "terminal", "read_file"}
        schema = build_execute_code_schema(enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertIn("web_search", code_desc)
        self.assertIn("terminal", code_desc)

    def test_import_examples_fallback_when_no_preferred(self):
        """When neither web_search nor terminal are enabled, falls back to
        sorted first two tools."""
        enabled = {"read_file", "write_file", "patch"}
        schema = build_execute_code_schema(enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        # Should use sorted first 2: patch, read_file
        self.assertIn("patch", code_desc)
        self.assertIn("read_file", code_desc)

    def test_empty_set_produces_valid_description(self):
        """build_execute_code_schema(set()) must not produce 'import , ...'
        in the code property description."""
        schema = build_execute_code_schema(set())
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc,
                         "Empty enabled set produces broken import syntax in description")

    def test_real_scenario_all_sandbox_tools_disabled(self):
        """Reproduce the exact code path from model_tools.py:231-234.

        Scenario: user runs `hermes tools code_execution` (only code_execution
        toolset enabled). tools_to_include = {"execute_code"}.

        model_tools.py does:
            sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
            dynamic_schema = build_execute_code_schema(sandbox_enabled)

        SANDBOX_ALLOWED_TOOLS = {web_search, web_extract, read_file, write_file,
                                  search_files, patch, terminal}
        tools_to_include  = {"execute_code"}
        intersection      = empty set
        """
        # Simulate model_tools.py:233
        tools_to_include = {"execute_code"}
        sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include

        self.assertEqual(sandbox_enabled, set(),
                         "Intersection should be empty when only execute_code is enabled")

        schema = build_execute_code_schema(sandbox_enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc,
                         "Bug: broken import syntax sent to the model")

    def test_real_scenario_only_vision_enabled(self):
        """Another real path: user runs `hermes tools code_execution,vision`.

        tools_to_include = {"execute_code", "vision_analyze"}
        SANDBOX_ALLOWED_TOOLS has neither, so intersection is empty.
        """
        tools_to_include = {"execute_code", "vision_analyze"}
        sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include

        self.assertEqual(sandbox_enabled, set())

        schema = build_execute_code_schema(sandbox_enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc)

    def test_description_mentions_limits(self):
        # Advertised limits; update the expectations if config defaults change.
        schema = build_execute_code_schema()
        desc = schema["description"]
        self.assertIn("5-minute timeout", desc)
        self.assertIn("50KB", desc)
        self.assertIn("50 tool calls", desc)

    def test_description_mentions_helpers(self):
        schema = build_execute_code_schema()
        desc = schema["description"]
        self.assertIn("json_parse", desc)
        self.assertIn("shell_quote", desc)
        self.assertIn("retry", desc)

    def test_none_defaults_to_all_tools(self):
        # Passing None must be equivalent to passing the full allow-list.
        schema_none = build_execute_code_schema(None)
        schema_all = build_execute_code_schema(SANDBOX_ALLOWED_TOOLS)
        self.assertEqual(schema_none["description"], schema_all["description"])
665  
666  
667  # ---------------------------------------------------------------------------
668  # Environment variable filtering (security critical)
669  # ---------------------------------------------------------------------------
670  
671  @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
672  class TestEnvVarFiltering(unittest.TestCase):
673      """Verify that execute_code filters environment variables correctly.
674  
675      The child process should NOT receive API keys, tokens, or secrets.
676      It should receive safe vars like PATH, HOME, LANG, etc.
677      """
678  
    def _get_child_env(self, extra_env=None):
        """Run a script that dumps its environment and return the env dict.

        `extra_env` entries are injected into the parent process environment
        before execute_code runs; the parent environment is restored
        afterwards regardless of outcome.
        """
        code = (
            "import os, json\n"
            "print(json.dumps(dict(os.environ)))\n"
        )
        env_backup = os.environ.copy()
        try:
            if extra_env:
                os.environ.update(extra_env)
            with patch("model_tools.handle_function_call", return_value='{}'), \
                 patch("tools.code_execution_tool._load_config",
                       return_value={"timeout": 10, "max_tool_calls": 50}):
                raw = execute_code(code, task_id="test-env",
                                   enabled_tools=list(SANDBOX_ALLOWED_TOOLS))
        finally:
            # Always restore the parent environment, even if execute_code raises.
            os.environ.clear()
            os.environ.update(env_backup)

        result = json.loads(raw)
        self.assertEqual(result["status"], "success", result.get("error", ""))
        # The child printed exactly one JSON object: its full os.environ.
        return json.loads(result["output"].strip())
701  
702      def test_api_keys_excluded(self):
703          child_env = self._get_child_env({
704              "OPENAI_API_KEY": "sk-secret123",
705              "ANTHROPIC_API_KEY": "sk-ant-secret",
706              "FIRECRAWL_API_KEY": "fc-secret",
707          })
708          self.assertNotIn("OPENAI_API_KEY", child_env)
709          self.assertNotIn("ANTHROPIC_API_KEY", child_env)
710          self.assertNotIn("FIRECRAWL_API_KEY", child_env)
711  
712      def test_tokens_excluded(self):
713          child_env = self._get_child_env({
714              "GITHUB_TOKEN": "ghp_secret",
715              "MODAL_TOKEN_ID": "tok-123",
716              "MODAL_TOKEN_SECRET": "tok-sec",
717          })
718          self.assertNotIn("GITHUB_TOKEN", child_env)
719          self.assertNotIn("MODAL_TOKEN_ID", child_env)
720          self.assertNotIn("MODAL_TOKEN_SECRET", child_env)
721  
722      def test_password_vars_excluded(self):
723          child_env = self._get_child_env({
724              "DB_PASSWORD": "hunter2",
725              "MY_PASSWD": "secret",
726              "AUTH_CREDENTIAL": "cred",
727          })
728          self.assertNotIn("DB_PASSWORD", child_env)
729          self.assertNotIn("MY_PASSWD", child_env)
730          self.assertNotIn("AUTH_CREDENTIAL", child_env)
731  
732      def test_path_included(self):
733          child_env = self._get_child_env()
734          self.assertIn("PATH", child_env)
735  
736      def test_home_included(self):
737          child_env = self._get_child_env()
738          self.assertIn("HOME", child_env)
739  
740      def test_hermes_rpc_socket_injected(self):
741          child_env = self._get_child_env()
742          self.assertIn("HERMES_RPC_SOCKET", child_env)
743  
744      def test_pythondontwritebytecode_set(self):
745          child_env = self._get_child_env()
746          self.assertEqual(child_env.get("PYTHONDONTWRITEBYTECODE"), "1")
747  
748      def test_timezone_injected_when_set(self):
749          env_backup = os.environ.copy()
750          try:
751              os.environ["HERMES_TIMEZONE"] = "America/New_York"
752              child_env = self._get_child_env()
753              self.assertEqual(child_env.get("TZ"), "America/New_York")
754          finally:
755              os.environ.clear()
756              os.environ.update(env_backup)
757  
758      def test_timezone_not_set_when_empty(self):
759          env_backup = os.environ.copy()
760          try:
761              os.environ.pop("HERMES_TIMEZONE", None)
762              child_env = self._get_child_env()
763              if "TZ" in child_env:
764                  self.assertNotEqual(child_env["TZ"], "")
765          finally:
766              os.environ.clear()
767              os.environ.update(env_backup)
768  
769  
770  # ---------------------------------------------------------------------------
771  # execute_code edge cases
772  # ---------------------------------------------------------------------------
773  
class TestExecuteCodeEdgeCases(unittest.TestCase):
    """Edge cases around execute_code input validation and tool selection."""

    def _run_with_tools(self, code, task_id, enabled_tools):
        """Run *code* against a stubbed dispatcher and decode the JSON result."""
        with patch("model_tools.handle_function_call",
                   return_value=json.dumps({"ok": True})):
            raw = execute_code(code, task_id=task_id,
                               enabled_tools=enabled_tools)
        return json.loads(raw)

    def test_windows_returns_error(self):
        """On Windows (or when SANDBOX_AVAILABLE is False), returns error JSON."""
        with patch("tools.code_execution_tool.SANDBOX_AVAILABLE", False):
            payload = json.loads(execute_code("print('hi')", task_id="test"))
        self.assertIn("error", payload)
        self.assertIn("Windows", payload["error"])

    def test_whitespace_only_code(self):
        """Whitespace-only code is rejected with a 'No code' error."""
        payload = json.loads(execute_code("   \n\t  ", task_id="test"))
        self.assertIn("error", payload)
        self.assertIn("No code", payload["error"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_none_enabled_tools_uses_all(self):
        """When enabled_tools is None, all sandbox tools should be available."""
        code = (
            "from hermes_tools import terminal, web_search, read_file\n"
            "print('all imports ok')\n"
        )
        result = self._run_with_tools(code, "test-none", None)
        self.assertEqual(result["status"], "success")
        self.assertIn("all imports ok", result["output"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_empty_enabled_tools_uses_all(self):
        """When enabled_tools is [] (empty), all sandbox tools should be available."""
        code = (
            "from hermes_tools import terminal, web_search\n"
            "print('imports ok')\n"
        )
        result = self._run_with_tools(code, "test-empty", [])
        self.assertEqual(result["status"], "success")
        self.assertIn("imports ok", result["output"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_nonoverlapping_tools_fallback(self):
        """When enabled_tools has no overlap with SANDBOX_ALLOWED_TOOLS,
        should fall back to all allowed tools."""
        code = (
            "from hermes_tools import terminal\n"
            "print('fallback ok')\n"
        )
        result = self._run_with_tools(
            code, "test-nonoverlap", ["vision_analyze", "browser_snapshot"])
        self.assertEqual(result["status"], "success")
        self.assertIn("fallback ok", result["output"])
833  
834  # ---------------------------------------------------------------------------
835  # _load_config
836  # ---------------------------------------------------------------------------
837  
class TestLoadConfig(unittest.TestCase):
    """Behavior of the private _load_config helper."""

    def test_returns_empty_dict_when_cli_config_unavailable(self):
        """If the cli module cannot be imported, a dict is still returned."""
        from tools.code_execution_tool import _load_config
        with patch.dict("sys.modules", {"cli": None}):
            self.assertIsInstance(_load_config(), dict)

    def test_returns_code_execution_section(self):
        """Only the 'code_execution' section of the raw config is returned."""
        from tools.code_execution_tool import _load_config
        fake_raw = {"code_execution": {"timeout": 120, "max_tool_calls": 10}}
        with patch("hermes_cli.config.read_raw_config", return_value=fake_raw):
            section = _load_config()
        self.assertEqual(section, {"timeout": 120, "max_tool_calls": 10})

    def test_does_not_import_interactive_cli(self):
        """Config must come from hermes_cli, never the interactive cli module."""
        from tools.code_execution_tool import _load_config
        fake_cli = MagicMock()
        fake_cli.CLI_CONFIG = {"code_execution": {"timeout": 999}}
        with patch.dict("sys.modules", {"cli": fake_cli}), \
             patch("hermes_cli.config.read_raw_config", return_value={}):
            section = _load_config()
        self.assertEqual(section, {})
860  
861  
862  # ---------------------------------------------------------------------------
863  # Interrupt event
864  # ---------------------------------------------------------------------------
865  
@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestInterruptHandling(unittest.TestCase):
    """execute_code should honor the per-thread interrupt flag."""

    def test_interrupt_event_stops_execution(self):
        """When interrupt is set for the execution thread, execute_code should stop."""
        from tools.interrupt import set_interrupt

        code = "import time; time.sleep(60); print('should not reach')"

        # execute_code runs on the current thread, so the background thread
        # must target this thread's ident when raising the interrupt.
        target_tid = threading.current_thread().ident

        def _fire_interrupt():
            import time as _t
            _t.sleep(1)
            set_interrupt(True, target_tid)

        interrupter = threading.Thread(target=_fire_interrupt, daemon=True)
        interrupter.start()

        try:
            with patch("model_tools.handle_function_call",
                       return_value=json.dumps({"ok": True})), \
                 patch("tools.code_execution_tool._load_config",
                       return_value={"timeout": 30, "max_tool_calls": 50}):
                result = json.loads(execute_code(
                    code, task_id="test-interrupt",
                    enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
                ))
            self.assertEqual(result["status"], "interrupted")
            self.assertIn("interrupted", result["output"])
        finally:
            # Always clear the flag so later tests on this thread are unaffected.
            set_interrupt(False, target_tid)
            interrupter.join(timeout=3)
900  
class TestHeadTailTruncation(unittest.TestCase):
    """Tests for head+tail truncation of large stdout in execute_code."""

    def _run(self, code):
        """Execute *code* with the mock dispatcher and decode the JSON result."""
        with patch("model_tools.handle_function_call",
                   side_effect=_mock_handle_function_call):
            raw = execute_code(
                code=code,
                task_id="test-task",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(raw)

    def test_short_output_not_truncated(self):
        """Output under MAX_STDOUT_BYTES should not be truncated."""
        result = self._run('print("small output")')
        self.assertEqual(result["status"], "success")
        self.assertIn("small output", result["output"])
        self.assertNotIn("TRUNCATED", result["output"])

    def test_large_output_preserves_head_and_tail(self):
        """Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
        code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
    print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        captured = result["output"]
        # Both ends of the stream must survive, along with the notice —
        # preserving the tail is the key improvement over plain truncation.
        for marker in ("HEAD_MARKER_START", "TAIL_MARKER_END", "TRUNCATED"):
            self.assertIn(marker, captured)

    def test_truncation_notice_format(self):
        """Truncation notice includes character counts."""
        code = '''
for i in range(15000):
    print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx")
'''
        captured = self._run(code)["output"]
        if "TRUNCATED" in captured:
            self.assertIn("chars omitted", captured)
            self.assertIn("total", captured)
951  
if __name__ == "__main__":
    # Allow running this file directly: python tests/test_code_execution.py
    unittest.main()