# test_managed_modal_environment.py
"""Tests for ``ManagedModalEnvironment`` using stubbed packages and fake HTTP.

Each test installs lightweight stand-ins for the ``tools`` / ``hermes_cli``
packages in ``sys.modules``, loads ``environments/managed_modal.py`` straight
from ``TOOLS_DIR``, and replaces ``requests.request`` with an in-memory fake
so no network traffic is generated.
"""

import json
import sys
import tempfile
import threading
import types
from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path

import pytest


TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"

# Top-level packages whose ``sys.modules`` entries the tests replace with
# fakes and that the autouse fixture snapshots and restores.
_FAKED_PACKAGES = ("tools", "agent", "hermes_cli")


def _load_tool_module(module_name: str, filename: str):
    """Load ``TOOLS_DIR / filename`` and register it as ``module_name``.

    The module is inserted into ``sys.modules`` before it is executed, and the
    executed module object is returned.
    """
    spec = spec_from_file_location(module_name, TOOLS_DIR / filename)
    assert spec and spec.loader
    module = module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def _is_under_packages(name: str, packages: tuple[str, ...]) -> bool:
    """Return True if ``name`` is one of ``packages`` or a submodule of one."""
    dotted_prefixes = tuple(f"{package}." for package in packages)
    return name in packages or name.startswith(dotted_prefixes)


def _reset_modules(prefixes: tuple[str, ...]):
    """Drop the packages named in ``prefixes`` (and their submodules).

    Matching is done on whole package names: ``"tools"`` removes ``tools`` and
    ``tools.x`` but leaves an unrelated module such as ``toolsets`` alone. A
    bare ``str.startswith`` check would evict those look-alike modules too,
    and the restoring fixture would never put them back.
    """
    # Iterate over a copy because entries are removed while looping.
    for name in list(sys.modules):
        if _is_under_packages(name, prefixes):
            sys.modules.pop(name, None)


@pytest.fixture(autouse=True)
def _restore_tool_and_agent_modules():
    """Save and restore sys.modules entries so fakes don't leak to other tests."""
    original_modules = {
        name: module
        for name, module in sys.modules.items()
        if _is_under_packages(name, _FAKED_PACKAGES)
    }
    try:
        yield
    finally:
        # Remove whatever the test installed, then reinstate the snapshot.
        _reset_modules(_FAKED_PACKAGES)
        sys.modules.update(original_modules)


def _install_fake_tools_package(*, credential_mounts=None):
    """Replace the ``tools`` / ``hermes_cli`` packages with minimal fakes.

    Installs stub entries for ``hermes_cli.config``, ``tools.interrupt``,
    ``tools.environments.base``, ``tools.managed_tool_gateway`` and
    ``tools.credential_files``. The ``tools`` and ``tools.environments``
    packages are empty module objects whose ``__path__`` points at the real
    source directories, so other submodules can still be imported from disk.

    Args:
        credential_mounts: Optional list returned (as a copy) by the fake
            ``get_credential_file_mounts``; ``None`` yields an empty list.

    Returns:
        The ``threading.Event`` backing the fake ``tools.interrupt`` module,
        so a test can trigger an interrupt directly.
    """
    _reset_modules(_FAKED_PACKAGES)

    hermes_cli = types.ModuleType("hermes_cli")
    hermes_cli.__path__ = []  # type: ignore[attr-defined]
    sys.modules["hermes_cli"] = hermes_cli
    sys.modules["hermes_cli.config"] = types.SimpleNamespace(
        get_hermes_home=lambda: Path(tempfile.gettempdir()) / "hermes-home",
    )

    tools_package = types.ModuleType("tools")
    tools_package.__path__ = [str(TOOLS_DIR)]  # type: ignore[attr-defined]
    sys.modules["tools"] = tools_package

    env_package = types.ModuleType("tools.environments")
    env_package.__path__ = [str(TOOLS_DIR / "environments")]  # type: ignore[attr-defined]
    sys.modules["tools.environments"] = env_package

    interrupt_event = threading.Event()
    sys.modules["tools.interrupt"] = types.SimpleNamespace(
        set_interrupt=lambda value=True: interrupt_event.set() if value else interrupt_event.clear(),
        is_interrupted=lambda: interrupt_event.is_set(),
        _interrupt_event=interrupt_event,
    )

    class _DummyBaseEnvironment:
        """Bare-bones stand-in for ``BaseEnvironment``; stores its arguments."""

        def __init__(self, cwd: str, timeout: int, env=None):
            self.cwd = cwd
            self.timeout = timeout
            self.env = env or {}

        def _prepare_command(self, command: str):
            # Pass the command through untouched; the second slot is None.
            return command, None

    sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
    sys.modules["tools.managed_tool_gateway"] = types.SimpleNamespace(
        resolve_managed_tool_gateway=lambda vendor: types.SimpleNamespace(
            vendor=vendor,
            gateway_origin="https://modal-gateway.example.com",
            nous_user_token="user-token",
            managed_mode=True,
        )
    )
    sys.modules["tools.credential_files"] = types.SimpleNamespace(
        get_credential_file_mounts=lambda: list(credential_mounts or []),
    )

    return interrupt_event


class _FakeResponse:
    """Minimal response double exposing ``status_code``, ``text`` and ``json()``."""

    def __init__(self, status_code: int, payload=None, text: str = ""):
        self.status_code = status_code
        self._payload = payload
        self.text = text

    def json(self):
        """Return the payload, or raise it when the payload is an exception."""
        if isinstance(self._payload, Exception):
            raise self._payload
        return self._payload


def test_managed_modal_execute_polls_until_completed(monkeypatch):
    """``execute`` returns the output once a poll reports ``completed``."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []
    poll_count = {"value": 0}

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            poll_count["value"] += 1
            # First poll is still running; every later poll is completed.
            if poll_count["value"] == 1:
                return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
            return _FakeResponse(200, {
                "execId": url.rsplit("/", 1)[-1],
                "status": "completed",
                "output": "hello",
                "returncode": 0,
            })
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result == {"output": "hello", "returncode": 0}
    assert any(call[0] == "POST" and call[1].endswith("/execs") for call in calls)


def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
    """Sandbox creation sends a non-empty string ``x-idempotency-key`` header."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    create_headers = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            create_headers.append(headers or {})
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    env.cleanup()

    assert len(create_headers) == 1
    assert isinstance(create_headers[0].get("x-idempotency-key"), str)
    assert create_headers[0]["x-idempotency-key"]


def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
    """An interrupt during polling cancels the exec and returns code 130."""
    interrupt_event = _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST" and url.endswith("/cancel"):
            return _FakeResponse(202, {"status": "cancelling"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    def fake_sleep(_seconds):
        # Raise the interrupt flag the first time the code under test sleeps.
        interrupt_event.set()

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", fake_sleep)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30")
    env.cleanup()

    assert result == {
        "output": "[Command interrupted - Modal sandbox exec cancelled]",
        "returncode": 130,
    }
    assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)
    poll_calls = [call for call in calls if call[0] == "GET" and "/execs/" in call[1]]
    cancel_calls = [call for call in calls if call[0] == "POST" and call[1].endswith("/cancel")]
    # Fail with a readable message rather than an IndexError below.
    assert poll_calls, "expected at least one exec poll request"
    assert cancel_calls, "expected at least one cancel request"
    # Index 3 of each recorded call is the ``timeout`` argument.
    assert poll_calls[0][3] == (1.0, 5.0)
    assert cancel_calls[0][3] == (1.0, 5.0)


def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
    """A 404 while polling yields returncode 1 and a "not found" message."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(404, {"error": "not found"}, text="not found")
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result["returncode"] == 1
    assert "not found" in result["output"].lower()


def test_managed_modal_create_and_cleanup_preserve_gateway_persistence_fields(monkeypatch):
    """Create and terminate payloads carry the expected persistence fields."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    create_payloads = []
    terminate_payloads = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST" and url.endswith("/v1/sandboxes"):
            create_payloads.append(json)
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/terminate"):
            terminate_payloads.append(json)
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(
        image="python:3.11",
        task_id="task-managed-persist",
        persistent_filesystem=False,
    )
    env.cleanup()

    assert create_payloads == [{
        "image": "python:3.11",
        "cwd": "/root",
        "cpu": 1.0,
        "memoryMiB": 5120.0,
        "timeoutMs": 3_600_000,
        "idleTimeoutMs": 300_000,
        "persistentFilesystem": False,
        "logicalKey": "task-managed-persist",
    }]
    assert terminate_payloads == [{"snapshotBeforeTerminate": False}]


def test_managed_modal_rejects_host_credential_passthrough():
    """Construction raises ``ValueError`` when credential mounts are configured."""
    _install_fake_tools_package(
        credential_mounts=[{
            "host_path": "/tmp/token.json",
            "container_path": "/root/.hermes/token.json",
        }]
    )
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    with pytest.raises(ValueError, match="credential-file passthrough"):
        managed_modal.ManagedModalEnvironment(image="python:3.11")


def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
    """An exec that outlives its timeout is cancelled and returns code 124."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []
    # Fake clock: stays at 0.0 for the first reads, then jumps past the 2s
    # timeout used below.
    monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5])

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST" and url.endswith("/cancel"):
            return _FakeResponse(202, {"status": "cancelling"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    # Hold the clock at its final value once the scripted readings run out, so
    # an extra ``monotonic()`` call cannot raise StopIteration mid-test.
    # NOTE(review): ``modal_common.time`` is presumably the stdlib ``time``
    # module, in which case this patch is visible process-wide - confirm.
    monkeypatch.setattr(modal_common.time, "monotonic", lambda: next(monotonic_values, 12.5))
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30", timeout=2)
    env.cleanup()

    assert result == {
        "output": "Managed Modal exec timed out after 2s",
        "returncode": 124,
    }
    assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)