/ tests / tools / test_code_execution_modes.py
test_code_execution_modes.py
  1  #!/usr/bin/env python3
  2  """Tests for execute_code's strict / project execution modes.
  3  
  4  The mode switch controls two things:
  5    - working directory: staging tmpdir (strict) vs session CWD (project)
  6    - interpreter:       sys.executable (strict) vs active venv's python (project)
  7  
Security-critical invariants — env scrubbing, tool whitelist, resource caps —
must apply identically in both modes. The security tests below cover env
scrubbing and the tool whitelist in both modes; resource caps are not
exercised in this file.
 10  
 11  Mode is sourced exclusively from ``code_execution.mode`` in config.yaml —
 12  there is no env-var override. Tests patch ``_load_config`` directly.
 13  """
 14  
 15  import json
 16  import os
 17  import sys
 18  import unittest
 19  from contextlib import contextmanager
 20  from unittest.mock import patch
 21  
 22  import pytest
 23  
 24  os.environ["TERMINAL_ENV"] = "local"
 25  
 26  
 27  @pytest.fixture(autouse=True)
 28  def _force_local_terminal(monkeypatch):
 29      """Mirror test_code_execution.py — guarantee local backend under xdist."""
 30      monkeypatch.setenv("TERMINAL_ENV", "local")
 31  
 32  
 33  from tools.code_execution_tool import (
 34      SANDBOX_ALLOWED_TOOLS,
 35      DEFAULT_EXECUTION_MODE,
 36      EXECUTION_MODES,
 37      _get_execution_mode,
 38      _is_usable_python,
 39      _resolve_child_cwd,
 40      _resolve_child_python,
 41      build_execute_code_schema,
 42      execute_code,
 43  )
 44  
 45  
 46  @contextmanager
 47  def _mock_mode(mode):
 48      """Context manager that pins code_execution.mode to the given value."""
 49      with patch("tools.code_execution_tool._load_config",
 50                 return_value={"mode": mode}):
 51          yield
 52  
 53  
 54  def _mock_handle_function_call(function_name, function_args, task_id=None, user_task=None):
 55      """Minimal mock dispatcher reused across tests."""
 56      if function_name == "terminal":
 57          return json.dumps({"output": "mock", "exit_code": 0})
 58      if function_name == "read_file":
 59          return json.dumps({"content": "line1\n", "total_lines": 1})
 60      return json.dumps({"error": f"Unknown tool: {function_name}"})
 61  
 62  
 63  # ---------------------------------------------------------------------------
 64  # Mode resolution
 65  # ---------------------------------------------------------------------------
 66  
 67  class TestGetExecutionMode(unittest.TestCase):
 68      """_get_execution_mode reads config.yaml only (no env var surface)."""
 69  
 70      def test_default_is_project(self):
 71          self.assertEqual(DEFAULT_EXECUTION_MODE, "project")
 72  
 73      def test_config_project(self):
 74          with patch("tools.code_execution_tool._load_config",
 75                     return_value={"mode": "project"}):
 76              self.assertEqual(_get_execution_mode(), "project")
 77  
 78      def test_config_strict(self):
 79          with patch("tools.code_execution_tool._load_config",
 80                     return_value={"mode": "strict"}):
 81              self.assertEqual(_get_execution_mode(), "strict")
 82  
 83      def test_config_case_insensitive(self):
 84          with patch("tools.code_execution_tool._load_config",
 85                     return_value={"mode": "STRICT"}):
 86              self.assertEqual(_get_execution_mode(), "strict")
 87  
 88      def test_config_strips_whitespace(self):
 89          with patch("tools.code_execution_tool._load_config",
 90                     return_value={"mode": "  project  "}):
 91              self.assertEqual(_get_execution_mode(), "project")
 92  
 93      def test_empty_config_falls_back_to_default(self):
 94          with patch("tools.code_execution_tool._load_config", return_value={}):
 95              self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
 96  
 97      def test_bogus_config_falls_back_to_default(self):
 98          with patch("tools.code_execution_tool._load_config",
 99                     return_value={"mode": "banana"}):
100              self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
101  
102      def test_none_config_falls_back_to_default(self):
103          with patch("tools.code_execution_tool._load_config",
104                     return_value={"mode": None}):
105              # str(None).lower() = "none" → not in EXECUTION_MODES → default
106              self.assertEqual(_get_execution_mode(), DEFAULT_EXECUTION_MODE)
107  
108      def test_execution_modes_tuple(self):
109          """Canonical set of modes — tests + config layer rely on this shape."""
110          self.assertEqual(set(EXECUTION_MODES), {"project", "strict"})
111  
112  
113  # ---------------------------------------------------------------------------
114  # Interpreter resolver
115  # ---------------------------------------------------------------------------
116  
class TestResolveChildPython(unittest.TestCase):
    """_resolve_child_python — picks the right interpreter per mode.

    Every test starts and ends with an empty ``_is_usable_python`` cache.
    The resolver memoizes per path, and entries for temp venvs that have
    since been deleted must not leak into other tests.
    """

    def setUp(self):
        _is_usable_python.cache_clear()
        self.addCleanup(_is_usable_python.cache_clear)

    @staticmethod
    def _make_fake_venv(root):
        """Create ``<root>/bin/python`` as a symlink to the running interpreter.

        Symlinking to sys.executable means the usability check runs against
        a real python. Returns the created interpreter path as a str.
        """
        import pathlib
        bin_dir = pathlib.Path(root) / "bin"
        bin_dir.mkdir()
        python = bin_dir / "python"
        python.symlink_to(sys.executable)
        return str(python)

    def test_strict_always_sys_executable(self):
        """Strict mode never leaves sys.executable, even if venv is set."""
        with patch.dict(os.environ, {"VIRTUAL_ENV": "/some/venv"}):
            self.assertEqual(_resolve_child_python("strict"), sys.executable)

    def test_project_with_no_venv_falls_back(self):
        """Project mode without VIRTUAL_ENV or CONDA_PREFIX → sys.executable."""
        # patch.dict restores the removed keys on exit.
        with patch.dict(os.environ):
            os.environ.pop("VIRTUAL_ENV", None)
            os.environ.pop("CONDA_PREFIX", None)
            self.assertEqual(_resolve_child_python("project"), sys.executable)

    def test_project_with_virtualenv_picks_venv_python(self):
        """Project mode + VIRTUAL_ENV pointing at a real venv → that python."""
        import tempfile
        with tempfile.TemporaryDirectory() as td:
            venv_python = self._make_fake_venv(td)
            with patch.dict(os.environ, {"VIRTUAL_ENV": td}):
                self.assertEqual(_resolve_child_python("project"), venv_python)

    def test_project_with_broken_venv_falls_back(self):
        """VIRTUAL_ENV set but bin/python missing → sys.executable."""
        import tempfile
        with tempfile.TemporaryDirectory() as td:
            # No bin/python inside — broken venv.
            with patch.dict(os.environ, {"VIRTUAL_ENV": td}):
                # Drop any ambient CONDA_PREFIX (e.g. tests run from a conda
                # env) so the resolver cannot pick that interpreter instead of
                # falling back to sys.executable.
                os.environ.pop("CONDA_PREFIX", None)
                self.assertEqual(_resolve_child_python("project"), sys.executable)

    def test_project_prefers_virtualenv_over_conda(self):
        """If both VIRTUAL_ENV and CONDA_PREFIX are set, VIRTUAL_ENV wins."""
        import tempfile
        with tempfile.TemporaryDirectory() as ve_td, tempfile.TemporaryDirectory() as conda_td:
            ve_python = self._make_fake_venv(ve_td)
            self._make_fake_venv(conda_td)
            with patch.dict(os.environ, {"VIRTUAL_ENV": ve_td, "CONDA_PREFIX": conda_td}):
                self.assertEqual(_resolve_child_python("project"), ve_python)

    def test_is_usable_python_rejects_nonexistent(self):
        self.assertFalse(_is_usable_python("/does/not/exist/python"))

    def test_is_usable_python_accepts_real_python(self):
        self.assertTrue(_is_usable_python(sys.executable))
179  
180  
181  # ---------------------------------------------------------------------------
182  # CWD resolver
183  # ---------------------------------------------------------------------------
184  
class TestResolveChildCwd(unittest.TestCase):
    """_resolve_child_cwd — staging dir in strict mode, session CWD in project mode."""

    STAGING = "/tmp/staging"

    def test_strict_uses_staging_dir(self):
        resolved = _resolve_child_cwd("strict", self.STAGING)
        self.assertEqual(resolved, self.STAGING)

    def test_project_without_terminal_cwd_uses_getcwd(self):
        # patch.dict snapshots os.environ and restores TERMINAL_CWD on exit.
        with patch.dict(os.environ):
            os.environ.pop("TERMINAL_CWD", None)
            resolved = _resolve_child_cwd("project", self.STAGING)
            self.assertEqual(resolved, os.getcwd())

    def test_project_uses_terminal_cwd_when_set(self):
        import tempfile
        with tempfile.TemporaryDirectory() as session_dir, \
                patch.dict(os.environ, {"TERMINAL_CWD": session_dir}):
            resolved = _resolve_child_cwd("project", self.STAGING)
            self.assertEqual(resolved, session_dir)

    def test_project_bogus_terminal_cwd_falls_back_to_getcwd(self):
        missing_dir = "/does/not/exist/anywhere"
        with patch.dict(os.environ, {"TERMINAL_CWD": missing_dir}):
            resolved = _resolve_child_cwd("project", self.STAGING)
            self.assertEqual(resolved, os.getcwd())

    def test_project_expands_tilde(self):
        from pathlib import Path
        expected_home = str(Path.home())
        with patch.dict(os.environ, {"TERMINAL_CWD": "~"}):
            resolved = _resolve_child_cwd("project", self.STAGING)
            self.assertEqual(resolved, expected_home)
210  
211  
212  # ---------------------------------------------------------------------------
213  # Schema description
214  # ---------------------------------------------------------------------------
215  
class TestModeAwareSchema(unittest.TestCase):
    """build_execute_code_schema — the description text follows the active mode."""

    @staticmethod
    def _description(mode):
        """Return the schema description generated for an explicit *mode*."""
        return build_execute_code_schema(mode=mode)["description"]

    def test_strict_description_mentions_temp_dir(self):
        strict_desc = self._description("strict")
        self.assertIn("temp dir", strict_desc)

    def test_project_description_mentions_session_and_venv(self):
        project_desc = self._description("project")
        for phrase in ("session", "venv"):
            self.assertIn(phrase, project_desc)

    def test_neither_description_uses_sandbox_language(self):
        """REGRESSION GUARD for commit 39b83f34.

        Agents on local backends wrongly concluded they were sandboxed and
        refused networking tasks. Keep 'sandbox' / 'isolated' / 'cloud'
        wording out of the tool description in every mode.
        """
        banned = ("sandbox", "isolated", "cloud")
        for mode in EXECUTION_MODES:
            lowered = self._description(mode).lower()
            for word in banned:
                self.assertNotIn(word, lowered,
                                 f"mode={mode}: '{word}' leaked into description")

    def test_descriptions_are_similar_length(self):
        """Both modes should have roughly the same-size description."""
        strict_len, project_len = (
            len(self._description(m)) for m in ("strict", "project")
        )
        self.assertLess(abs(strict_len - project_len), 200)

    def test_default_mode_reads_config(self):
        """build_execute_code_schema() with no mode argument consults config.yaml."""
        for pinned, marker in (("strict", "temp dir"), ("project", "session")):
            with _mock_mode(pinned):
                self.assertIn(marker, build_execute_code_schema()["description"])
254  
255  
256  # ---------------------------------------------------------------------------
257  # Integration: what actually happens when execute_code runs per mode
258  # ---------------------------------------------------------------------------
259  
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestExecuteCodeModeIntegration(unittest.TestCase):
    """End-to-end: verify the subprocess actually runs where we expect."""

    def _run(self, code, mode, enabled_tools=None, extra_env=None):
        """Run *code* through execute_code with *mode* pinned and tool RPCs mocked.

        ``enabled_tools=None`` means "every tool in SANDBOX_ALLOWED_TOOLS".
        An explicit list, including an empty one, is passed through unchanged.
        ``extra_env`` entries are layered onto os.environ for this call only.
        Returns the decoded JSON result dict.
        """
        if enabled_tools is None:
            enabled_tools = list(SANDBOX_ALLOWED_TOOLS)
        with _mock_mode(mode), \
                patch.dict(os.environ, extra_env or {}), \
                patch("model_tools.handle_function_call",
                      side_effect=_mock_handle_function_call):
            raw = execute_code(
                code=code,
                task_id=f"test-{mode}",
                enabled_tools=enabled_tools,
            )
        return json.loads(raw)

    def test_strict_mode_runs_in_tmpdir(self):
        """Strict mode: script's os.getcwd() is the staging tmpdir."""
        result = self._run("import os; print(os.getcwd())", mode="strict")
        self.assertEqual(result["status"], "success")
        self.assertIn("hermes_sandbox_", result["output"])

    def test_project_mode_runs_in_session_cwd(self):
        """Project mode: script's os.getcwd() is the session's working dir."""
        import tempfile
        with tempfile.TemporaryDirectory() as td:
            result = self._run(
                "import os; print(os.getcwd())",
                mode="project",
                extra_env={"TERMINAL_CWD": td},
            )
            self.assertEqual(result["status"], "success")
            # Resolve symlinks (macOS /tmp → /private/tmp) on both sides
            self.assertEqual(
                os.path.realpath(result["output"].strip()),
                os.path.realpath(td),
            )

    def test_project_mode_interpreter_is_venv_python(self):
        """Project mode: sys.executable inside the child is the venv's python
        when VIRTUAL_ENV is set to a real venv."""
        # The hermes-agent venv is normally active during tests, so this may
        # equal the parent's sys.executable. The assertion is only that the
        # resolver picked a path under VIRTUAL_ENV (or fell back to
        # sys.executable), not that the two differ.
        result = self._run("import sys; print(sys.executable)", mode="project")
        self.assertEqual(result["status"], "success")
        output = result["output"].strip()
        ve = os.environ.get("VIRTUAL_ENV", "").strip()
        if ve:
            # Compare against "<ve>/" so a sibling such as "<ve>-other/bin/python"
            # is not mistaken for a path inside VIRTUAL_ENV.
            under_ve = output.startswith(os.path.join(ve, ""))
            self.assertTrue(
                under_ve or output == sys.executable,
                f"project-mode python should be under VIRTUAL_ENV={ve} or sys.executable={sys.executable}, got {output}",
            )

    def test_project_mode_can_still_import_hermes_tools(self):
        """Regression: hermes_tools still importable from non-tmpdir CWD.

        This is the PYTHONPATH fix — without it, switching to session CWD
        breaks `from hermes_tools import terminal`.
        """
        import tempfile
        with tempfile.TemporaryDirectory() as td:
            code = (
                "from hermes_tools import terminal\n"
                "r = terminal('echo x')\n"
                "print(r.get('output', 'MISSING'))\n"
            )
            result = self._run(code, mode="project", extra_env={"TERMINAL_CWD": td})
            self.assertEqual(result["status"], "success")
            self.assertIn("mock", result["output"])

    def test_strict_mode_can_still_import_hermes_tools(self):
        """Regression: strict mode's tmpdir CWD still works for imports."""
        code = (
            "from hermes_tools import terminal\n"
            "r = terminal('echo x')\n"
            "print(r.get('output', 'MISSING'))\n"
        )
        result = self._run(code, mode="strict")
        self.assertEqual(result["status"], "success")
        self.assertIn("mock", result["output"])
344  
345  
346  # ---------------------------------------------------------------------------
347  # SECURITY-CRITICAL regression guards
348  #
349  # These MUST pass in both strict and project mode. The whole tiered-mode
350  # proposition rests on the claim that switching from strict to project only
351  # changes CWD + interpreter, not the security posture.
352  # ---------------------------------------------------------------------------
353  
@pytest.mark.skipif(sys.platform == "win32", reason="execute_code is POSIX-only")
class TestSecurityInvariantsAcrossModes(unittest.TestCase):
    """Env scrubbing and the tool whitelist must behave the same in both modes.

    Each check lives in one ``_assert_*`` helper parameterised by mode. The
    public tests call it once per mode, so strict and project are held to
    exactly the same expectations.
    """

    # Credentials planted in the parent env; none may reach the child.
    _API_KEY_ENV = {
        "OPENAI_API_KEY": "sk-should-not-leak",
        "ANTHROPIC_API_KEY": "ant-should-not-leak",
        "GITHUB_TOKEN": "ghp-should-not-leak",
    }
    _API_KEY_CODE = (
        "import os\n"
        "print('KEY=' + os.environ.get('OPENAI_API_KEY', 'MISSING'))\n"
        "print('TOK=' + os.environ.get('ANTHROPIC_API_KEY', 'MISSING'))\n"
        "print('SEC=' + os.environ.get('GITHUB_TOKEN', 'MISSING'))\n"
    )

    # Names hitting the SECRET/PASSWORD/CREDENTIAL/PASSWD/AUTH substring filters.
    _SECRET_SUBSTRING_ENV = {
        "MY_SECRET": "secret-should-not-leak",
        "DB_PASSWORD": "password-should-not-leak",
        "VAULT_CREDENTIAL": "cred-should-not-leak",
        "LDAP_PASSWD": "passwd-should-not-leak",
        "AUTH_TOKEN": "auth-should-not-leak",
    }
    _SECRET_SUBSTRING_CODE = (
        "import os\n"
        "for k in ('MY_SECRET', 'DB_PASSWORD', 'VAULT_CREDENTIAL', "
        "'LDAP_PASSWD', 'AUTH_TOKEN'):\n"
        "    print(f'{k}=' + os.environ.get(k, 'MISSING'))\n"
    )

    _WHITELIST_CODE = (
        "import hermes_tools as ht\n"
        "print('execute_code_available:', hasattr(ht, 'execute_code'))\n"
        "print('delegate_task_available:', hasattr(ht, 'delegate_task'))\n"
    )

    def _run(self, code, mode):
        """Run *code* via execute_code with *mode* pinned and tool RPCs mocked."""
        with _mock_mode(mode):
            with patch("model_tools.handle_function_call",
                       side_effect=_mock_handle_function_call):
                raw = execute_code(
                    code=code,
                    task_id=f"test-sec-{mode}",
                    enabled_tools=list(SANDBOX_ALLOWED_TOOLS),
                )
        return json.loads(raw)

    def _assert_api_keys_scrubbed(self, mode):
        """Planted API keys / tokens are absent from the child env in *mode*."""
        with patch.dict(os.environ, self._API_KEY_ENV):
            result = self._run(self._API_KEY_CODE, mode=mode)
        self.assertEqual(result["status"], "success")
        for needle in ("KEY=MISSING", "TOK=MISSING", "SEC=MISSING"):
            self.assertIn(needle, result["output"])
        for leaked in self._API_KEY_ENV.values():
            self.assertNotIn(leaked, result["output"])

    def _assert_secret_substrings_scrubbed(self, mode):
        """Values of secret-looking env names never appear in *mode*'s output."""
        with patch.dict(os.environ, self._SECRET_SUBSTRING_ENV):
            result = self._run(self._SECRET_SUBSTRING_CODE, mode=mode)
        self.assertEqual(result["status"], "success")
        for leaked in self._SECRET_SUBSTRING_ENV.values():
            self.assertNotIn(leaked, result["output"])

    def _assert_tool_whitelist_enforced(self, mode):
        """hermes_tools in *mode* exposes neither execute_code nor delegate_task."""
        result = self._run(self._WHITELIST_CODE, mode=mode)
        self.assertEqual(result["status"], "success")
        self.assertIn("execute_code_available: False", result["output"])
        self.assertIn("delegate_task_available: False", result["output"])

    def test_api_keys_scrubbed_in_strict_mode(self):
        self._assert_api_keys_scrubbed("strict")

    def test_api_keys_scrubbed_in_project_mode(self):
        """CRITICAL: the project-mode default does NOT leak user credentials."""
        self._assert_api_keys_scrubbed("project")

    def test_secret_substrings_scrubbed_in_strict_mode(self):
        """SECRET/PASSWORD/CREDENTIAL/PASSWD/AUTH filters apply in strict mode too."""
        self._assert_secret_substrings_scrubbed("strict")

    def test_secret_substrings_scrubbed_in_project_mode(self):
        """SECRET/PASSWORD/CREDENTIAL/PASSWD/AUTH filters still apply."""
        self._assert_secret_substrings_scrubbed("project")

    def test_tool_whitelist_enforced_in_strict_mode(self):
        """A script cannot RPC-call tools outside SANDBOX_ALLOWED_TOOLS."""
        # execute_code is NOT in SANDBOX_ALLOWED_TOOLS (no recursion)
        self.assertNotIn("execute_code", SANDBOX_ALLOWED_TOOLS)
        self._assert_tool_whitelist_enforced("strict")

    def test_tool_whitelist_enforced_in_project_mode(self):
        """CRITICAL: project mode does NOT widen the tool whitelist."""
        self._assert_tool_whitelist_enforced("project")
452  
453  
if __name__ == "__main__":
    # Direct-run entry point: discover and run the TestCase classes in this module.
    unittest.main(module=__name__)