# tests/test_code_execution.py
#!/usr/bin/env python3
"""
Tests for the code execution sandbox (programmatic tool calling).

These tests monkeypatch handle_function_call so they don't require API keys
or a running terminal backend. They verify the core sandbox mechanics:
UDS socket lifecycle, hermes_tools generation, timeout enforcement,
output capping, tool call counting, and error propagation.

Run with: python -m pytest tests/test_code_execution.py -v
      or: python tests/test_code_execution.py
"""

import pytest
# pytestmark removed — tests run fine (61 pass, ~99s)

import json
import os

# Must be set BEFORE tools.code_execution_tool is imported below so the
# sandbox module sees a local terminal backend at import time.
os.environ["TERMINAL_ENV"] = "local"


@pytest.fixture(autouse=True)
def _force_local_terminal(monkeypatch):
    """Re-set TERMINAL_ENV=local before every test.

    The module-level assignment above covers import time, but under xdist
    another worker can overwrite os.environ between tests. monkeypatch
    ensures each test starts (and ends) with the correct value.
    """
    monkeypatch.setenv("TERMINAL_ENV", "local")


# NOTE: these imports intentionally come after the TERMINAL_ENV assignment
# above (see the fixture docstring) — do not hoist them.
import sys
import time
import threading
import unittest
from unittest.mock import patch, MagicMock

from tools.code_execution_tool import (
    SANDBOX_ALLOWED_TOOLS,
    execute_code,
    generate_hermes_tools_module,
    check_sandbox_requirements,
    build_execute_code_schema,
    EXECUTE_CODE_SCHEMA,
    _TOOL_DOC_LINES,
    _execute_remote,
)


def _mock_handle_function_call(function_name, function_args, task_id=None, user_task=None):
    """Mock dispatcher that returns canned responses for each tool."""
    if function_name == "terminal":
        cmd = function_args.get("command", "")
        return json.dumps({"output": f"mock output for: {cmd}", "exit_code": 0})
    if function_name == "web_search":
        return json.dumps({"results": [{"url": "https://example.com", "title": "Example", "description": "A test result"}]})
    if function_name == "read_file":
        return json.dumps({"content": "line 1\nline 2\nline 3\n", "total_lines": 3})
    if function_name == "write_file":
        return json.dumps({"status": "ok", "path": function_args.get("path", "")})
    if function_name == "search_files":
        return json.dumps({"matches": [{"file": "test.py", "line": 1, "text": "match"}]})
    if function_name == "patch":
        return json.dumps({"status": "ok", "replacements": 1})
    if function_name == "web_extract":
        return json.dumps("# Extracted content\nSome text from the page.")
    return json.dumps({"error": f"Unknown tool in mock: {function_name}"})


class TestSandboxRequirements(unittest.TestCase):
    def test_available_on_posix(self):
        if sys.platform != "win32":
            self.assertTrue(check_sandbox_requirements())

    def test_schema_is_valid(self):
        self.assertEqual(EXECUTE_CODE_SCHEMA["name"], "execute_code")
        self.assertIn("code", EXECUTE_CODE_SCHEMA["parameters"]["properties"])
        self.assertIn("code", EXECUTE_CODE_SCHEMA["parameters"]["required"])


class TestHermesToolsGeneration(unittest.TestCase):
    def test_generates_all_allowed_tools(self):
        src = generate_hermes_tools_module(list(SANDBOX_ALLOWED_TOOLS))
        for tool in SANDBOX_ALLOWED_TOOLS:
            self.assertIn(f"def {tool}(", src)

    def test_generates_subset(self):
        src = generate_hermes_tools_module(["terminal", "web_search"])
        self.assertIn("def terminal(", src)
        self.assertIn("def web_search(", src)
        self.assertNotIn("def read_file(", src)

    def test_empty_list_generates_nothing(self):
        src = generate_hermes_tools_module([])
        self.assertNotIn("def terminal(", src)
        self.assertIn("def _call(", src)  # infrastructure still present

    def test_non_allowed_tools_ignored(self):
        src = generate_hermes_tools_module(["vision_analyze", "terminal"])
        self.assertIn("def terminal(", src)
        self.assertNotIn("def vision_analyze(", src)

    def test_rpc_infrastructure_present(self):
        src = generate_hermes_tools_module(["terminal"])
        self.assertIn("HERMES_RPC_SOCKET", src)
        self.assertIn("AF_UNIX", src)
        self.assertIn("def _connect(", src)
        self.assertIn("def _call(", src)

    def test_convenience_helpers_present(self):
        """Verify json_parse, shell_quote, and retry helpers are generated."""
        src = generate_hermes_tools_module(["terminal"])
        self.assertIn("def json_parse(", src)
        self.assertIn("def shell_quote(", src)
        self.assertIn("def retry(", src)
        self.assertIn("import json, os, socket, shlex, threading, time", src)

    def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self):
        src = generate_hermes_tools_module(["terminal"], transport="file")
        self.assertIn("import json, os, shlex, tempfile, threading, time", src)
        self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src)
        self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src)

    def test_uds_transport_serializes_concurrent_calls(self):
        """Regression: UDS _call() must hold a lock across send+recv so that
        concurrent tool calls from multiple threads don't interleave on the
        shared socket and receive each other's responses."""
        src = generate_hermes_tools_module(["terminal"], transport="uds")
        self.assertIn("_call_lock = threading.Lock()", src)
        self.assertIn("with _call_lock:", src)

    def test_file_transport_serializes_seq_allocation(self):
        """Regression: file transport _call() must allocate `_seq` under a
        lock, otherwise concurrent threads can pick the same seq and clobber
        each other's request files."""
        src = generate_hermes_tools_module(["terminal"], transport="file")
        self.assertIn("_seq_lock = threading.Lock()", src)
        self.assertIn("with _seq_lock:", src)


class TestExecuteCodeRemoteTempDir(unittest.TestCase):
    def test_execute_remote_uses_backend_temp_dir_for_sandbox(self):
        class FakeEnv:
            def __init__(self):
                self.commands = []

            def get_temp_dir(self):
                return "/data/data/com.termux/files/usr/tmp"

            def execute(self, command, cwd=None, timeout=None):
                self.commands.append((command, cwd, timeout))
                if "command -v python3" in command:
                    return {"output": "OK\n"}
                if "python3 script.py" in command:
                    return {"output": "hello\n", "returncode": 0}
                return {"output": ""}

        env = FakeEnv()
        fake_thread = MagicMock()

        with patch("tools.code_execution_tool._load_config", return_value={"timeout": 30, "max_tool_calls": 5}), \
             patch("tools.code_execution_tool._get_or_create_env", return_value=(env, "ssh")), \
             patch("tools.code_execution_tool._ship_file_to_remote"), \
             patch("tools.code_execution_tool.threading.Thread", return_value=fake_thread):
            result = json.loads(_execute_remote("print('hello')", "task-1", ["terminal"]))

        self.assertEqual(result["status"], "success")
        mkdir_cmd = env.commands[1][0]
        run_cmd = next(cmd for cmd, _, _ in env.commands if "python3 script.py" in cmd)
        cleanup_cmd = env.commands[-1][0]
        self.assertIn("mkdir -p /data/data/com.termux/files/usr/tmp/hermes_exec_", mkdir_cmd)
        self.assertIn("HERMES_RPC_DIR=/data/data/com.termux/files/usr/tmp/hermes_exec_", run_cmd)
        self.assertIn("rm -rf /data/data/com.termux/files/usr/tmp/hermes_exec_", cleanup_cmd)
        self.assertNotIn("mkdir -p /tmp/hermes_exec_", mkdir_cmd)


@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestExecuteCode(unittest.TestCase):
    """Integration tests using the mock dispatcher."""

    def _run(self, code, enabled_tools=None):
        """Helper: run code with mocked handle_function_call.

        Runs the full integration path, mocking only at the model_tools
        dispatcher level. (An earlier no-op `with patch(... _rpc_server_loop)`
        block here was dead code — it started and immediately stopped the
        patch — and has been removed.)
        """
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            result = execute_code(
                code=code,
                task_id="test-task",
                enabled_tools=enabled_tools or list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(result)

    def test_basic_print(self):
        """Script that just prints -- no tool calls."""
        result = self._run('print("hello world")')
        self.assertEqual(result["status"], "success")
        self.assertIn("hello world", result["output"])
        self.assertEqual(result["tool_calls_made"], 0)

    def test_repo_root_modules_are_importable(self):
        """Sandboxed scripts can import modules that live at the repo root."""
        result = self._run('import hermes_constants; print(hermes_constants.__file__)')
        self.assertEqual(result["status"], "success")
        self.assertIn("hermes_constants.py", result["output"])

    def test_single_tool_call(self):
        """Script calls terminal and prints the result."""
        code = """
from hermes_tools import terminal
result = terminal("echo hello")
print(result.get("output", ""))
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("mock output for: echo hello", result["output"])
        self.assertEqual(result["tool_calls_made"], 1)

    def test_multi_tool_chain(self):
        """Script calls multiple tools sequentially."""
        code = """
from hermes_tools import terminal, read_file
r1 = terminal("ls")
r2 = read_file("test.py")
print(f"terminal: {r1['output'][:20]}")
print(f"file lines: {r2['total_lines']}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertEqual(result["tool_calls_made"], 2)

    def test_syntax_error(self):
        """Script with a syntax error returns error status."""
        result = self._run("def broken(")
        self.assertEqual(result["status"], "error")
        self.assertIn("SyntaxError", result.get("error", "") + result.get("output", ""))

    def test_runtime_exception(self):
        """Script with a runtime error returns error status."""
        result = self._run("raise ValueError('test error')")
        self.assertEqual(result["status"], "error")

    def test_concurrent_tool_calls_match_responses(self):
        """Regression for the UDS RPC race: multiple threads inside the
        sandbox calling terminal() concurrently must each receive their own
        response, not another thread's.

        Before the fix, `_sock` and the recv-loop were shared without a
        lock, so responses (written FIFO by the single-threaded server)
        got delivered to whichever client thread happened to win the
        recv() race. That surfaced as each thread seeing another thread's
        output.

        The mock dispatcher sleeps briefly to guarantee the requests
        overlap on the socket.
        """
        code = '''
import threading
from concurrent.futures import ThreadPoolExecutor
from hermes_tools import terminal

N = 10

def call(i):
    r = terminal(f"echo TAG-{i}")
    return i, r.get("output", "")

with ThreadPoolExecutor(max_workers=N) as ex:
    results = list(ex.map(call, range(N)))

mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out]
if mismatches:
    print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}")
else:
    print(f"OK {N}/{N}")
'''

        def slow_mock(function_name, function_args, task_id=None, user_task=None):
            import time as _t
            if function_name == "terminal":
                _t.sleep(0.05)  # ensure requests overlap on the socket
                cmd = function_args.get("command", "")
                # Echo semantics: strip leading "echo " and return the rest
                out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}"
                return json.dumps({"output": out, "exit_code": 0})
            return _mock_handle_function_call(
                function_name, function_args, task_id=task_id, user_task=user_task
            )

        with patch("model_tools.handle_function_call", side_effect=slow_mock):
            raw = execute_code(
                code=code,
                task_id="test-concurrent",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        result = json.loads(raw)

        self.assertEqual(result["status"], "success", msg=result)
        self.assertIn("OK 10/10", result["output"],
                      msg=f"Concurrent tool calls mismatched: {result['output']!r}")

    def test_excluded_tool_returns_error(self):
        """Script calling a tool not in the allow-list gets an error from RPC."""
        code = """
from hermes_tools import terminal
result = terminal("echo hi")
print(result)
"""
        # Only enable web_search -- terminal should be excluded
        result = self._run(code, enabled_tools=["web_search"])
        # terminal won't be in hermes_tools.py, so import fails
        self.assertEqual(result["status"], "error")

    def test_empty_code(self):
        """Empty code string returns an error."""
        result = json.loads(execute_code("", task_id="test"))
        self.assertIn("error", result)

    def test_output_captured(self):
        """Multiple print statements are captured in order."""
        code = """
for i in range(5):
    print(f"line {i}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        for i in range(5):
            self.assertIn(f"line {i}", result["output"])

    def test_stderr_on_error(self):
        """Traceback from stderr is included in the response."""
        code = """
import sys
print("before error")
raise RuntimeError("deliberate crash")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "error")
        self.assertIn("before error", result["output"])
        self.assertIn("RuntimeError", result.get("error", "") + result.get("output", ""))

    def test_timeout_enforcement(self):
        """Script that sleeps too long is killed."""
        code = "import time; time.sleep(999)"
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            # Override config to use a very short timeout
            with patch("tools.code_execution_tool._load_config", return_value={"timeout": 2, "max_tool_calls": 50}):
                result = json.loads(execute_code(
                    code=code,
                    task_id="test-task",
                    enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
                ))
        self.assertEqual(result["status"], "timeout")
        self.assertIn("timed out", result.get("error", ""))
        # The timeout message must also appear in output so the LLM always
        # surfaces it to the user (#10807).
        self.assertIn("timed out", result.get("output", ""))
        self.assertIn("\u23f0", result.get("output", ""))

    def test_web_search_tool(self):
        """Script calls web_search and processes results."""
        code = """
from hermes_tools import web_search
results = web_search("test query")
print(f"Found {len(results.get('results', []))} results")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("Found 1 results", result["output"])

    def test_json_parse_helper(self):
        """json_parse handles control characters that json.loads(strict=True) rejects."""
        code = r"""
from hermes_tools import json_parse
# This JSON has a literal tab character which strict mode rejects
text = '{"body": "line1\tline2\nline3"}'
result = json_parse(text)
print(result["body"])
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("line1", result["output"])

    def test_shell_quote_helper(self):
        """shell_quote properly escapes dangerous characters."""
        code = """
from hermes_tools import shell_quote
# String with backticks, quotes, and special chars
dangerous = '`rm -rf /` && $(whoami) "hello"'
escaped = shell_quote(dangerous)
print(escaped)
# Verify it's wrapped in single quotes with proper escaping
assert "rm -rf" in escaped
assert escaped.startswith("'")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")

    def test_retry_helper_success(self):
        """retry returns on first success."""
        code = """
from hermes_tools import retry
counter = [0]
def flaky():
    counter[0] += 1
    return f"ok on attempt {counter[0]}"
result = retry(flaky)
print(result)
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("ok on attempt 1", result["output"])

    def test_retry_helper_eventual_success(self):
        """retry retries on failure and succeeds eventually."""
        code = """
from hermes_tools import retry
counter = [0]
def flaky():
    counter[0] += 1
    if counter[0] < 3:
        raise ConnectionError(f"fail {counter[0]}")
    return "success"
result = retry(flaky, max_attempts=3, delay=0.01)
print(result)
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("success", result["output"])

    def test_retry_helper_all_fail(self):
        """retry raises the last error when all attempts fail."""
        code = """
from hermes_tools import retry
def always_fail():
    raise ValueError("nope")
try:
    retry(always_fail, max_attempts=2, delay=0.01)
    print("should not reach here")
except ValueError as e:
    print(f"caught: {e}")
"""
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        self.assertIn("caught: nope", result["output"])


class TestStubSchemaDrift(unittest.TestCase):
    """Verify that _TOOL_STUBS in code_execution_tool.py stay in sync with
    the real tool schemas registered in tools/registry.py.

    If a tool gains a new parameter but the sandbox stub isn't updated,
    the LLM will try to use the parameter (it sees it in the system prompt)
    and get a TypeError. This test catches that drift.
    """

    # Parameters that are internal (injected by the handler, not user-facing)
    _INTERNAL_PARAMS = {"task_id", "user_task"}
    # Parameters intentionally blocked in the sandbox
    _BLOCKED_TERMINAL_PARAMS = {"background", "pty", "notify_on_complete", "watch_patterns"}

    def test_stubs_cover_all_schema_params(self):
        """Every user-facing parameter in the real schema must appear in the
        corresponding _TOOL_STUBS entry."""
        import re
        from tools.code_execution_tool import _TOOL_STUBS

        # Import the registry and trigger tool registration
        from tools.registry import registry
        import tools.file_tools  # noqa: F401 - registers read_file, write_file, patch, search_files
        import tools.web_tools  # noqa: F401 - registers web_search, web_extract

        for tool_name, (func_name, sig, doc, args_expr) in _TOOL_STUBS.items():
            entry = registry._tools.get(tool_name)
            if not entry:
                # Tool might not be registered yet (e.g., terminal uses a
                # different registration path). Skip gracefully.
                continue

            schema_props = entry.schema.get("parameters", {}).get("properties", {})
            schema_params = set(schema_props.keys()) - self._INTERNAL_PARAMS
            if tool_name == "terminal":
                schema_params -= self._BLOCKED_TERMINAL_PARAMS

            # Extract parameter names from the stub signature string
            # Match word before colon: "pattern: str, target: str = ..."
            stub_params = set(re.findall(r'(\w+)\s*:', sig))

            missing = schema_params - stub_params
            self.assertEqual(
                missing, set(),
                f"Stub for '{tool_name}' is missing parameters that exist in "
                f"the real schema: {missing}. Update _TOOL_STUBS in "
                f"code_execution_tool.py to include them."
            )

    def test_stubs_pass_all_params_to_rpc(self):
        """The args_dict_expr in each stub must include every parameter from
        the signature, so that all params are actually sent over RPC."""
        import re
        from tools.code_execution_tool import _TOOL_STUBS

        for tool_name, (func_name, sig, doc, args_expr) in _TOOL_STUBS.items():
            stub_params = set(re.findall(r'(\w+)\s*:', sig))
            # Check that each param name appears in the args dict expression
            for param in stub_params:
                self.assertIn(
                    f'"{param}"',
                    args_expr,
                    f"Stub for '{tool_name}' has parameter '{param}' in its "
                    f"signature but doesn't pass it in the args dict: {args_expr}"
                )

    def test_search_files_target_uses_current_values(self):
        """search_files stub should use 'content'/'files', not old 'grep'/'find'."""
        from tools.code_execution_tool import _TOOL_STUBS
        _, sig, doc, _ = _TOOL_STUBS["search_files"]
        self.assertIn('"content"', sig,
                      "search_files stub should default target to 'content', not 'grep'")
        self.assertNotIn('"grep"', sig,
                         "search_files stub still uses obsolete 'grep' target value")
        self.assertNotIn('"find"', doc,
                         "search_files stub docstring still uses obsolete 'find' target value")

    def test_generated_module_accepts_all_params(self):
        """The generated hermes_tools.py module should accept all current params
        without TypeError when called with keyword arguments."""
        src = generate_hermes_tools_module(list(SANDBOX_ALLOWED_TOOLS))

        # Compile the generated module to check for syntax errors
        compile(src, "hermes_tools.py", "exec")

        # Verify specific parameter signatures are in the source
        # search_files must accept context, offset, output_mode
        self.assertIn("context", src)
        self.assertIn("offset", src)
        self.assertIn("output_mode", src)

        # patch must accept mode and patch params
        self.assertIn("mode", src)


# ---------------------------------------------------------------------------
# build_execute_code_schema
# ---------------------------------------------------------------------------

class TestBuildExecuteCodeSchema(unittest.TestCase):
    """Tests for build_execute_code_schema — the dynamic schema generator."""

    def test_default_includes_all_tools(self):
        schema = build_execute_code_schema()
        desc = schema["description"]
        for name, _ in _TOOL_DOC_LINES:
            self.assertIn(name, desc, f"Default schema should mention '{name}'")

    def test_schema_structure(self):
        schema = build_execute_code_schema()
        self.assertEqual(schema["name"], "execute_code")
        self.assertIn("parameters", schema)
        self.assertIn("code", schema["parameters"]["properties"])
        self.assertEqual(schema["parameters"]["required"], ["code"])

    def test_subset_only_lists_enabled_tools(self):
        enabled = {"terminal", "read_file"}
        schema = build_execute_code_schema(enabled)
        desc = schema["description"]
        self.assertIn("terminal(", desc)
        self.assertIn("read_file(", desc)
        self.assertNotIn("web_search(", desc)
        self.assertNotIn("web_extract(", desc)
        self.assertNotIn("write_file(", desc)

    def test_single_tool(self):
        schema = build_execute_code_schema({"terminal"})
        desc = schema["description"]
        self.assertIn("terminal(", desc)
        self.assertNotIn("web_search(", desc)

    def test_import_examples_prefer_web_search_and_terminal(self):
        enabled = {"web_search", "terminal", "read_file"}
        schema = build_execute_code_schema(enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertIn("web_search", code_desc)
        self.assertIn("terminal", code_desc)

    def test_import_examples_fallback_when_no_preferred(self):
        """When neither web_search nor terminal are enabled, falls back to
        sorted first two tools."""
        enabled = {"read_file", "write_file", "patch"}
        schema = build_execute_code_schema(enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        # Should use sorted first 2: patch, read_file
        self.assertIn("patch", code_desc)
        self.assertIn("read_file", code_desc)

    def test_empty_set_produces_valid_description(self):
        """build_execute_code_schema(set()) must not produce 'import , ...'
        in the code property description."""
        schema = build_execute_code_schema(set())
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc,
                         "Empty enabled set produces broken import syntax in description")

    def test_real_scenario_all_sandbox_tools_disabled(self):
        """Reproduce the exact code path from model_tools.py:231-234.

        Scenario: user runs `hermes tools code_execution` (only code_execution
        toolset enabled). tools_to_include = {"execute_code"}.

        model_tools.py does:
            sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
            dynamic_schema = build_execute_code_schema(sandbox_enabled)

        SANDBOX_ALLOWED_TOOLS = {web_search, web_extract, read_file, write_file,
                                 search_files, patch, terminal}
        tools_to_include = {"execute_code"}
        intersection = empty set
        """
        # Simulate model_tools.py:233
        tools_to_include = {"execute_code"}
        sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include

        self.assertEqual(sandbox_enabled, set(),
                         "Intersection should be empty when only execute_code is enabled")

        schema = build_execute_code_schema(sandbox_enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc,
                         "Bug: broken import syntax sent to the model")

    def test_real_scenario_only_vision_enabled(self):
        """Another real path: user runs `hermes tools code_execution,vision`.

        tools_to_include = {"execute_code", "vision_analyze"}
        SANDBOX_ALLOWED_TOOLS has neither, so intersection is empty.
        """
        tools_to_include = {"execute_code", "vision_analyze"}
        sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include

        self.assertEqual(sandbox_enabled, set())

        schema = build_execute_code_schema(sandbox_enabled)
        code_desc = schema["parameters"]["properties"]["code"]["description"]
        self.assertNotIn("import , ...", code_desc)

    def test_description_mentions_limits(self):
        schema = build_execute_code_schema()
        desc = schema["description"]
        self.assertIn("5-minute timeout", desc)
        self.assertIn("50KB", desc)
        self.assertIn("50 tool calls", desc)

    def test_description_mentions_helpers(self):
        schema = build_execute_code_schema()
        desc = schema["description"]
        self.assertIn("json_parse", desc)
        self.assertIn("shell_quote", desc)
        self.assertIn("retry", desc)

    def test_none_defaults_to_all_tools(self):
        schema_none = build_execute_code_schema(None)
        schema_all = build_execute_code_schema(SANDBOX_ALLOWED_TOOLS)
        self.assertEqual(schema_none["description"], schema_all["description"])


# ---------------------------------------------------------------------------
# Environment variable filtering (security critical)
# ---------------------------------------------------------------------------

@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestEnvVarFiltering(unittest.TestCase):
    """Verify that execute_code filters environment variables correctly.

    The child process should NOT receive API keys, tokens, or secrets.
    It should receive safe vars like PATH, HOME, LANG, etc.
    """

    def _get_child_env(self, extra_env=None):
        """Run a script that dumps its environment and return the env dict."""
        code = (
            "import os, json\n"
            "print(json.dumps(dict(os.environ)))\n"
        )
        env_backup = os.environ.copy()
        try:
            if extra_env:
                os.environ.update(extra_env)
            with patch("model_tools.handle_function_call", return_value='{}'), \
                 patch("tools.code_execution_tool._load_config",
                       return_value={"timeout": 10, "max_tool_calls": 50}):
                raw = execute_code(code, task_id="test-env",
                                   enabled_tools=list(SANDBOX_ALLOWED_TOOLS))
        finally:
            os.environ.clear()
            os.environ.update(env_backup)

        result = json.loads(raw)
        self.assertEqual(result["status"], "success", result.get("error", ""))
        return json.loads(result["output"].strip())

    def test_api_keys_excluded(self):
        child_env = self._get_child_env({
            "OPENAI_API_KEY": "sk-secret123",
            "ANTHROPIC_API_KEY": "sk-ant-secret",
            "FIRECRAWL_API_KEY": "fc-secret",
        })
        self.assertNotIn("OPENAI_API_KEY", child_env)
        self.assertNotIn("ANTHROPIC_API_KEY", child_env)
        self.assertNotIn("FIRECRAWL_API_KEY", child_env)

    def test_tokens_excluded(self):
        child_env = self._get_child_env({
            "GITHUB_TOKEN": "ghp_secret",
            "MODAL_TOKEN_ID": "tok-123",
            "MODAL_TOKEN_SECRET": "tok-sec",
        })
        self.assertNotIn("GITHUB_TOKEN", child_env)
        self.assertNotIn("MODAL_TOKEN_ID", child_env)
        self.assertNotIn("MODAL_TOKEN_SECRET", child_env)

    def test_password_vars_excluded(self):
        child_env = self._get_child_env({
            "DB_PASSWORD": "hunter2",
            "MY_PASSWD": "secret",
            "AUTH_CREDENTIAL": "cred",
        })
        self.assertNotIn("DB_PASSWORD", child_env)
        self.assertNotIn("MY_PASSWD", child_env)
        self.assertNotIn("AUTH_CREDENTIAL", child_env)

    def test_path_included(self):
        child_env = self._get_child_env()
        self.assertIn("PATH", child_env)

    def test_home_included(self):
        child_env = self._get_child_env()
        self.assertIn("HOME", child_env)

    def test_hermes_rpc_socket_injected(self):
        child_env = self._get_child_env()
        self.assertIn("HERMES_RPC_SOCKET", child_env)

    def test_pythondontwritebytecode_set(self):
        child_env = self._get_child_env()
        self.assertEqual(child_env.get("PYTHONDONTWRITEBYTECODE"), "1")

    def test_timezone_injected_when_set(self):
        env_backup = os.environ.copy()
        try:
            os.environ["HERMES_TIMEZONE"] = "America/New_York"
            child_env = self._get_child_env()
            self.assertEqual(child_env.get("TZ"), "America/New_York")
        finally:
            os.environ.clear()
            os.environ.update(env_backup)

    def test_timezone_not_set_when_empty(self):
        env_backup = os.environ.copy()
        try:
            os.environ.pop("HERMES_TIMEZONE", None)
            child_env = self._get_child_env()
            if "TZ" in child_env:
                self.assertNotEqual(child_env["TZ"], "")
        finally:
            os.environ.clear()
            os.environ.update(env_backup)


# ---------------------------------------------------------------------------
# execute_code edge cases
# ---------------------------------------------------------------------------

class TestExecuteCodeEdgeCases(unittest.TestCase):

    def test_windows_returns_error(self):
        """On Windows (or when SANDBOX_AVAILABLE is False), returns error JSON."""
        with patch("tools.code_execution_tool.SANDBOX_AVAILABLE", False):
            result = json.loads(execute_code("print('hi')", task_id="test"))
        self.assertIn("error", result)
        self.assertIn("Windows", result["error"])

    def test_whitespace_only_code(self):
        result = json.loads(execute_code(" \n\t ", task_id="test"))
        self.assertIn("error", result)
        self.assertIn("No code", result["error"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_none_enabled_tools_uses_all(self):
        """When enabled_tools is None, all sandbox tools should be available."""
        code = (
            "from hermes_tools import terminal, web_search, read_file\n"
            "print('all imports ok')\n"
        )
        with patch("model_tools.handle_function_call",
                   return_value=json.dumps({"ok": True})):
            result = json.loads(execute_code(code, task_id="test-none",
                                             enabled_tools=None))
        self.assertEqual(result["status"], "success")
        self.assertIn("all imports ok", result["output"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_empty_enabled_tools_uses_all(self):
        """When enabled_tools is [] (empty), all sandbox tools should be available."""
        code = (
            "from hermes_tools import terminal, web_search\n"
            "print('imports ok')\n"
        )
        with patch("model_tools.handle_function_call",
                   return_value=json.dumps({"ok": True})):
            result = json.loads(execute_code(code, task_id="test-empty",
                                             enabled_tools=[]))
        self.assertEqual(result["status"], "success")
        self.assertIn("imports ok", result["output"])

    @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
    def test_nonoverlapping_tools_fallback(self):
        """When enabled_tools has no overlap with SANDBOX_ALLOWED_TOOLS,
        should fall back to all allowed tools."""
        code = (
            "from hermes_tools import terminal\n"
            "print('fallback ok')\n"
        )
        with patch("model_tools.handle_function_call",
                   return_value=json.dumps({"ok": True})):
            result = json.loads(execute_code(
                code, task_id="test-nonoverlap",
                enabled_tools=["vision_analyze", "browser_snapshot"],
            ))
        self.assertEqual(result["status"], "success")
        self.assertIn("fallback ok", result["output"])


# ---------------------------------------------------------------------------
# _load_config
# ---------------------------------------------------------------------------

class TestLoadConfig(unittest.TestCase):
    def test_returns_empty_dict_when_cli_config_unavailable(self):
        from tools.code_execution_tool import _load_config
        with patch.dict("sys.modules", {"cli": None}):
            result = _load_config()
        self.assertIsInstance(result, dict)

    def test_returns_code_execution_section(self):
        from tools.code_execution_tool import _load_config
        with patch("hermes_cli.config.read_raw_config",
                   return_value={"code_execution": {"timeout": 120, "max_tool_calls": 10}}):
            result = _load_config()
        self.assertEqual(result, {"timeout": 120, "max_tool_calls": 10})

    def test_does_not_import_interactive_cli(self):
        from tools.code_execution_tool import _load_config
        mock_cli = MagicMock()
        mock_cli.CLI_CONFIG = {"code_execution": {"timeout": 999}}
        with patch.dict("sys.modules", {"cli": mock_cli}), \
             patch("hermes_cli.config.read_raw_config", return_value={}):
            result = _load_config()
        self.assertEqual(result, {})


# ---------------------------------------------------------------------------
# Interrupt event
# ---------------------------------------------------------------------------

@unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
class TestInterruptHandling(unittest.TestCase):
    def test_interrupt_event_stops_execution(self):
        """When interrupt is set for the execution thread, execute_code should stop."""
        code = "import time; time.sleep(60); print('should not reach')"
        from tools.interrupt import set_interrupt

        # Capture the main thread ID so we can target the interrupt correctly.
        # execute_code runs in the current thread; set_interrupt needs its ID.
        main_tid = threading.current_thread().ident

        def set_interrupt_after_delay():
            import time as _t
            _t.sleep(1)
            set_interrupt(True, main_tid)

        t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
        t.start()

        try:
            with patch("model_tools.handle_function_call",
                       return_value=json.dumps({"ok": True})), \
                 patch("tools.code_execution_tool._load_config",
                       return_value={"timeout": 30, "max_tool_calls": 50}):
                result = json.loads(execute_code(
                    code, task_id="test-interrupt",
                    enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
                ))
            self.assertEqual(result["status"], "interrupted")
            self.assertIn("interrupted", result["output"])
        finally:
            set_interrupt(False, main_tid)
            t.join(timeout=3)


class TestHeadTailTruncation(unittest.TestCase):
    """Tests for head+tail truncation of large stdout in execute_code."""

    def _run(self, code):
        with patch("model_tools.handle_function_call", side_effect=_mock_handle_function_call):
            result = execute_code(
                code=code,
                task_id="test-task",
                enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
            )
        return json.loads(result)

    def test_short_output_not_truncated(self):
        """Output under MAX_STDOUT_BYTES should not be truncated."""
        result = self._run('print("small output")')
        self.assertEqual(result["status"], "success")
        self.assertIn("small output", result["output"])
        self.assertNotIn("TRUNCATED", result["output"])

    def test_large_output_preserves_head_and_tail(self):
        """Output exceeding MAX_STDOUT_BYTES keeps both head and tail."""
        code = '''
# Print HEAD marker, then filler, then TAIL marker
print("HEAD_MARKER_START")
for i in range(15000):
    print(f"filler_line_{i:06d}_padding_to_fill_buffer")
print("TAIL_MARKER_END")
'''
        result = self._run(code)
        self.assertEqual(result["status"], "success")
        output = result["output"]
        # Head should be preserved
self.assertIn("HEAD_MARKER_START", output) 934 # Tail should be preserved (this is the key improvement) 935 self.assertIn("TAIL_MARKER_END", output) 936 # Truncation notice should be present 937 self.assertIn("TRUNCATED", output) 938 939 def test_truncation_notice_format(self): 940 """Truncation notice includes character counts.""" 941 code = ''' 942 for i in range(15000): 943 print(f"padding_line_{i:06d}_xxxxxxxxxxxxxxxxxxxxxxxxxx") 944 ''' 945 result = self._run(code) 946 output = result["output"] 947 if "TRUNCATED" in output: 948 self.assertIn("chars omitted", output) 949 self.assertIn("total", output) 950 951 952 if __name__ == "__main__": 953 unittest.main()