tests/tools/test_managed_modal_environment.py
test_managed_modal_environment.py
  1  import json
  2  import sys
  3  import tempfile
  4  import threading
  5  import types
  6  from importlib.util import module_from_spec, spec_from_file_location
  7  from pathlib import Path
  8  
  9  import pytest
 10  
 11  
 12  TOOLS_DIR = Path(__file__).resolve().parents[2] / "tools"
 13  
 14  
 15  def _load_tool_module(module_name: str, filename: str):
 16      spec = spec_from_file_location(module_name, TOOLS_DIR / filename)
 17      assert spec and spec.loader
 18      module = module_from_spec(spec)
 19      sys.modules[module_name] = module
 20      spec.loader.exec_module(module)
 21      return module
 22  
 23  
 24  def _reset_modules(prefixes: tuple[str, ...]):
 25      for name in list(sys.modules):
 26          if name.startswith(prefixes):
 27              sys.modules.pop(name, None)
 28  
 29  
 30  @pytest.fixture(autouse=True)
 31  def _restore_tool_and_agent_modules():
 32      """Save and restore sys.modules entries so fakes don't leak to other tests."""
 33      original_modules = {
 34          name: module
 35          for name, module in sys.modules.items()
 36          if name in ("tools", "agent", "hermes_cli")
 37          or name.startswith("tools.")
 38          or name.startswith("agent.")
 39          or name.startswith("hermes_cli.")
 40      }
 41      try:
 42          yield
 43      finally:
 44          _reset_modules(("tools", "agent", "hermes_cli"))
 45          sys.modules.update(original_modules)
 46  
 47  
 48  def _install_fake_tools_package(*, credential_mounts=None):
 49      _reset_modules(("tools", "agent", "hermes_cli"))
 50  
 51      hermes_cli = types.ModuleType("hermes_cli")
 52      hermes_cli.__path__ = []  # type: ignore[attr-defined]
 53      sys.modules["hermes_cli"] = hermes_cli
 54      sys.modules["hermes_cli.config"] = types.SimpleNamespace(
 55          get_hermes_home=lambda: Path(tempfile.gettempdir()) / "hermes-home",
 56      )
 57  
 58      tools_package = types.ModuleType("tools")
 59      tools_package.__path__ = [str(TOOLS_DIR)]  # type: ignore[attr-defined]
 60      sys.modules["tools"] = tools_package
 61  
 62      env_package = types.ModuleType("tools.environments")
 63      env_package.__path__ = [str(TOOLS_DIR / "environments")]  # type: ignore[attr-defined]
 64      sys.modules["tools.environments"] = env_package
 65  
 66      interrupt_event = threading.Event()
 67      sys.modules["tools.interrupt"] = types.SimpleNamespace(
 68          set_interrupt=lambda value=True: interrupt_event.set() if value else interrupt_event.clear(),
 69          is_interrupted=lambda: interrupt_event.is_set(),
 70          _interrupt_event=interrupt_event,
 71      )
 72  
 73      class _DummyBaseEnvironment:
 74          def __init__(self, cwd: str, timeout: int, env=None):
 75              self.cwd = cwd
 76              self.timeout = timeout
 77              self.env = env or {}
 78  
 79          def _prepare_command(self, command: str):
 80              return command, None
 81  
 82      sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
 83      sys.modules["tools.managed_tool_gateway"] = types.SimpleNamespace(
 84          resolve_managed_tool_gateway=lambda vendor: types.SimpleNamespace(
 85              vendor=vendor,
 86              gateway_origin="https://modal-gateway.example.com",
 87              nous_user_token="user-token",
 88              managed_mode=True,
 89          )
 90      )
 91      sys.modules["tools.credential_files"] = types.SimpleNamespace(
 92          get_credential_file_mounts=lambda: list(credential_mounts or []),
 93      )
 94  
 95      return interrupt_event
 96  
 97  
 98  class _FakeResponse:
 99      def __init__(self, status_code: int, payload=None, text: str = ""):
100          self.status_code = status_code
101          self._payload = payload
102          self.text = text
103  
104      def json(self):
105          if isinstance(self._payload, Exception):
106              raise self._payload
107          return self._payload
108  
109  
def test_managed_modal_execute_polls_until_completed(monkeypatch):
    """``execute`` keeps polling a running exec until the gateway reports completion."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []
    polls = 0

    def fake_request(method, url, headers=None, json=None, timeout=None):
        nonlocal polls
        calls.append((method, url, json, timeout))
        if method == "POST":
            if url.endswith("/v1/sandboxes"):
                return _FakeResponse(200, {"id": "sandbox-1"})
            if url.endswith("/execs"):
                return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
            if url.endswith("/terminate"):
                return _FakeResponse(200, {"status": "terminated"})
        elif method == "GET" and "/execs/" in url:
            polls += 1
            exec_status = {"execId": url.rsplit("/", 1)[-1], "status": "running"}
            if polls > 1:
                # The first poll still sees a running exec; later polls see it finished.
                exec_status.update(status="completed", output="hello", returncode=0)
            return _FakeResponse(200, exec_status)
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result == {"output": "hello", "returncode": 0}
    exec_posts = [url for method, url, _body, _timeout in calls if method == "POST" and url.endswith("/execs")]
    assert exec_posts
146      assert any(call[0] == "POST" and call[1].endswith("/execs") for call in calls)
147  
148  
def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
    """Sandbox creation is requested once and carries a non-empty idempotency key header."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    seen_headers = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST":
            if url.endswith("/v1/sandboxes"):
                seen_headers.append(headers or {})
                return _FakeResponse(200, {"id": "sandbox-1"})
            if url.endswith("/terminate"):
                return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    env.cleanup()

    assert len(seen_headers) == 1
    idempotency_key = seen_headers[0].get("x-idempotency-key")
    assert isinstance(idempotency_key, str)
    assert idempotency_key
171  
172  
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
    """An interrupt while polling cancels the exec and yields exit code 130."""
    interrupt_event = _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "GET" and "/execs/" in url:
            # Polls always report "running", so the exec never completes on its own.
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST":
            if url.endswith("/v1/sandboxes"):
                return _FakeResponse(200, {"id": "sandbox-1"})
            if url.endswith("/execs"):
                return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
            if url.endswith("/cancel"):
                return _FakeResponse(202, {"status": "cancelling"})
            if url.endswith("/terminate"):
                return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    # Any sleep in the poll loop raises the interrupt flag instead of waiting.
    monkeypatch.setattr(modal_common.time, "sleep", lambda _seconds: interrupt_event.set())

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30")
    env.cleanup()

    expected = {
        "output": "[Command interrupted - Modal sandbox exec cancelled]",
        "returncode": 130,
    }
    assert result == expected

    poll_timeouts = [t for m, u, _body, t in calls if m == "GET" and "/execs/" in u]
    cancel_timeouts = [t for m, u, _body, t in calls if m == "POST" and u.endswith("/cancel")]
    assert cancel_timeouts
    assert poll_timeouts[0] == (1.0, 5.0)
    assert cancel_timeouts[0] == (1.0, 5.0)
213  
214  
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
    """A 404 while polling an exec becomes a failed result that mentions "not found"."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "GET":
            if "/execs/" in url:
                # The gateway no longer knows about the exec it just accepted.
                return _FakeResponse(404, {"error": "not found"}, text="not found")
        elif method == "POST":
            if url.endswith("/v1/sandboxes"):
                return _FakeResponse(200, {"id": "sandbox-1"})
            if url.endswith("/execs"):
                return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
            if url.endswith("/terminate"):
                return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("echo hello")
    env.cleanup()

    assert result["returncode"] == 1
    message = result["output"].lower()
    assert "not found" in message
240  
241  
def test_managed_modal_create_and_cleanup_preserve_gateway_persistence_fields(monkeypatch):
    """Create and terminate payloads carry the persistence settings given to the constructor."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    create_payloads = []
    terminate_payloads = []

    def fake_request(method, url, headers=None, json=None, timeout=None):
        if method == "POST":
            if url.endswith("/v1/sandboxes"):
                create_payloads.append(json)
                return _FakeResponse(200, {"id": "sandbox-1"})
            if url.endswith("/terminate"):
                terminate_payloads.append(json)
                return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)

    env = managed_modal.ManagedModalEnvironment(
        image="python:3.11",
        task_id="task-managed-persist",
        persistent_filesystem=False,
    )
    env.cleanup()

    expected_create = {
        "image": "python:3.11",
        "cwd": "/root",
        "cpu": 1.0,
        "memoryMiB": 5120.0,
        "timeoutMs": 3_600_000,
        "idleTimeoutMs": 300_000,
        "persistentFilesystem": False,
        "logicalKey": "task-managed-persist",
    }
    expected_terminate = {"snapshotBeforeTerminate": False}
    assert create_payloads == [expected_create]
    assert terminate_payloads == [expected_terminate]
278  
279  
def test_managed_modal_rejects_host_credential_passthrough(monkeypatch):
    """Host credential-file mounts are refused before any gateway request is made."""
    _install_fake_tools_package(
        credential_mounts=[{
            "host_path": "/tmp/token.json",
            "container_path": "/root/.hermes/token.json",
        }]
    )
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")

    def fail_on_request(method, url, headers=None, json=None, timeout=None):
        raise AssertionError(f"Unexpected request: {method} {url}")

    # Without this stub, a regression that contacted the gateway before
    # validating credential mounts would attempt a real HTTP request to the
    # fake gateway origin instead of failing with a clear message.
    monkeypatch.setattr(managed_modal.requests, "request", fail_on_request)

    with pytest.raises(ValueError, match="credential-file passthrough"):
        managed_modal.ManagedModalEnvironment(image="python:3.11")
291  
292  
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
    """An exec that outlives its timeout is cancelled and reported with exit code 124."""
    _install_fake_tools_package()
    managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
    modal_common = sys.modules["tools.environments.modal_utils"]

    calls = []

    # Scripted clock: the same readings as before (0.0 x3, then 12.5, which is
    # past timeout=2). The final reading is then held indefinitely. A finite
    # iterator would raise StopIteration on any extra clock read, making the
    # test depend on the exact number of time.monotonic() calls.
    clock_script = [0.0, 0.0, 0.0, 12.5]

    def fake_monotonic():
        if len(clock_script) > 1:
            return clock_script.pop(0)
        return clock_script[0]

    def fake_request(method, url, headers=None, json=None, timeout=None):
        calls.append((method, url, json, timeout))
        if method == "POST" and url.endswith("/v1/sandboxes"):
            return _FakeResponse(200, {"id": "sandbox-1"})
        if method == "POST" and url.endswith("/execs"):
            return _FakeResponse(202, {"execId": json["execId"], "status": "running"})
        if method == "GET" and "/execs/" in url:
            return _FakeResponse(200, {"execId": url.rsplit("/", 1)[-1], "status": "running"})
        if method == "POST" and url.endswith("/cancel"):
            return _FakeResponse(202, {"status": "cancelling"})
        if method == "POST" and url.endswith("/terminate"):
            return _FakeResponse(200, {"status": "terminated"})
        raise AssertionError(f"Unexpected request: {method} {url}")

    monkeypatch.setattr(managed_modal.requests, "request", fake_request)
    monkeypatch.setattr(modal_common.time, "monotonic", fake_monotonic)
    monkeypatch.setattr(modal_common.time, "sleep", lambda _: None)

    env = managed_modal.ManagedModalEnvironment(image="python:3.11")
    result = env.execute("sleep 30", timeout=2)
    env.cleanup()

    assert result == {
        "output": "Managed Modal exec timed out after 2s",
        "returncode": 124,
    }
    assert any(call[0] == "POST" and call[1].endswith("/cancel") for call in calls)