/ tests / run_agent / test_concurrent_interrupt.py
test_concurrent_interrupt.py
  1  """Tests for interrupt handling in concurrent tool execution."""
  2  
  3  import concurrent.futures
  4  import threading
  5  import time
  6  from unittest.mock import MagicMock, patch
  7  
  8  import pytest
  9  
 10  
 11  @pytest.fixture(autouse=True)
 12  def _isolate_hermes(tmp_path, monkeypatch):
 13      monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
 14      (tmp_path / ".hermes").mkdir(exist_ok=True)
 15  
 16  
 17  def _make_agent(monkeypatch):
 18      """Create a minimal AIAgent-like object with just the methods under test."""
 19      monkeypatch.setenv("OPENROUTER_API_KEY", "")
 20      monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "")
 21      # Avoid full AIAgent init — just import the class and build a stub
 22      import run_agent as _ra
 23  
 24      class _Stub:
 25          _interrupt_requested = False
 26          _interrupt_message = None
 27          # Bind to this thread's ident so interrupt() targets a real tid.
 28          _execution_thread_id = threading.current_thread().ident
 29          _interrupt_thread_signal_pending = False
 30          log_prefix = ""
 31          quiet_mode = True
 32          verbose_logging = False
 33          log_prefix_chars = 200
 34          _checkpoint_mgr = MagicMock(enabled=False)
 35          _subdirectory_hints = MagicMock()
 36          tool_progress_callback = None
 37          tool_start_callback = None
 38          tool_complete_callback = None
 39          _todo_store = MagicMock()
 40          _session_db = None
 41          valid_tool_names = set()
 42          _turns_since_memory = 0
 43          _iters_since_skill = 0
 44          _current_tool = None
 45          _last_activity = 0
 46          _print_fn = print
 47          # Worker-thread tracking state mirrored from AIAgent.__init__ so the
 48          # real interrupt() method can fan out to concurrent-tool workers.
 49          _active_children: list = []
 50  
 51          def __init__(self):
 52              # Instance-level (not class-level) so each test gets a fresh set.
 53              self._tool_worker_threads: set = set()
 54              self._tool_worker_threads_lock = threading.Lock()
 55              self._active_children_lock = threading.Lock()
 56  
 57          def _touch_activity(self, desc):
 58              self._last_activity = time.time()
 59  
 60          def _vprint(self, msg, force=False):
 61              pass
 62  
 63          def _safe_print(self, msg):
 64              pass
 65  
 66          def _should_emit_quiet_tool_messages(self):
 67              return False
 68  
 69          def _should_start_quiet_spinner(self):
 70              return False
 71  
 72          def _has_stream_consumers(self):
 73              return False
 74  
 75      stub = _Stub()
 76      # Bind the real methods under test
 77      stub._execute_tool_calls_concurrent = _ra.AIAgent._execute_tool_calls_concurrent.__get__(stub)
 78      stub.interrupt = _ra.AIAgent.interrupt.__get__(stub)
 79      stub.clear_interrupt = _ra.AIAgent.clear_interrupt.__get__(stub)
 80      # /steer injection (added in PR #12116) fires after every concurrent
 81      # tool batch. Stub it as a no-op — this test exercises interrupt
 82      # fanout, not steer injection.
 83      stub._apply_pending_steer_to_tool_results = lambda *a, **kw: None
 84      stub._invoke_tool = MagicMock(side_effect=lambda *a, **kw: '{"ok": true}')
 85      return stub
 86  
 87  
 88  class _FakeToolCall:
 89      def __init__(self, name, args="{}", call_id="tc_1"):
 90          self.function = MagicMock(name=name, arguments=args)
 91          self.function.name = name
 92          self.id = call_id
 93  
 94  
 95  class _FakeAssistantMsg:
 96      def __init__(self, tool_calls):
 97          self.tool_calls = tool_calls
 98  
 99  
def test_concurrent_interrupt_cancels_pending(monkeypatch):
    """When _interrupt_requested is set during concurrent execution,
    the wait loop should exit early and every tool call must still end up
    with a result message (cancelled tools get interrupt messages)."""
    agent = _make_agent(monkeypatch)

    # Keeps the slow tool blocked until the helper thread sets the interrupt.
    barrier = threading.Event()

    def slow_tool(name, args, task_id, call_id=None, messages=None):
        # Accept ``messages`` like the polling tool in
        # test_running_concurrent_worker_sees_is_interrupted does, so a
        # ``messages=`` kwarg from the executor can't turn this into a
        # TypeError that bypasses the blocking path.
        if name == "slow_one":
            # Block until the test sets the interrupt
            barrier.wait(timeout=10)
            return '{"slow": true}'
        return '{"fast": true}'

    agent._invoke_tool = MagicMock(side_effect=slow_tool)

    tc1 = _FakeToolCall("fast_one", call_id="tc_fast")
    tc2 = _FakeToolCall("slow_one", call_id="tc_slow")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    def _set_interrupt_after_delay():
        time.sleep(0.3)
        agent._interrupt_requested = True
        barrier.set()  # unblock the slow tool

    t = threading.Thread(target=_set_interrupt_after_delay)
    t.start()
    try:
        agent._execute_tool_calls_concurrent(msg, messages, "test_task")
    finally:
        # Bounded join so a wedged helper can't hang the whole suite; also
        # release the slow tool in case that join timed out.
        t.join(timeout=5.0)
        barrier.set()

    # Both tools should have results in messages
    assert len(messages) == 2
    # The executor must not silently reset the flag set by the helper thread.
    assert agent._interrupt_requested is True
139  
140  
def test_concurrent_preflight_interrupt_skips_all(monkeypatch):
    """An interrupt already pending before the concurrent batch starts must
    skip every tool call with a cancellation message and dispatch none."""
    agent = _make_agent(monkeypatch)
    agent._interrupt_requested = True

    batch = _FakeAssistantMsg([
        _FakeToolCall("tool_a", call_id="tc_a"),
        _FakeToolCall("tool_b", call_id="tc_b"),
    ])
    out = []

    agent._execute_tool_calls_concurrent(batch, out, "test_task")

    assert len(out) == 2
    for entry in out:
        assert "skipped due to user interrupt" in entry["content"]
    # Pre-flight short-circuit: no tool may have been invoked at all.
    agent._invoke_tool.assert_not_called()
159  
160  
def test_running_concurrent_worker_sees_is_interrupted(monkeypatch):
    """Regression guard for the "interrupt-doesn't-reach-hung-tool" class of
    bug Physikal reported in April 2026.

    Before this fix, `AIAgent.interrupt()` called `_set_interrupt(True,
    _execution_thread_id)` — which only flagged the agent's *main* thread.
    Tools running inside `_execute_tool_calls_concurrent` execute on
    ThreadPoolExecutor worker threads whose tids are NOT the agent's, so
    `is_interrupted()` (which checks the *current* thread's tid) returned
    False inside those tools no matter how many times the gateway called
    `.interrupt()`.  Hung ssh / long curl / big make-build tools would run
    to their own timeout.

    This test runs a fake tool in the concurrent path that polls
    `is_interrupted()` like a real terminal command does, then calls
    `agent.interrupt()` from another thread, and asserts the poll sees True
    within one second.  The interrupt bits it sets (on this thread's tid and
    on the worker tids) are reset afterwards so they cannot leak into later
    tests.
    """
    from tools.interrupt import is_interrupted, set_interrupt

    agent = _make_agent(monkeypatch)

    # Counter plus observation hooks so we can prove the worker saw the flip.
    observed = {"saw_true": False, "poll_count": 0, "worker_tid": None}
    # Every worker tid that ran the tool, so its interrupt bit can be reset
    # during cleanup (pool-thread tids may be reused by later threads).
    seen_worker_tids = set()
    seen_worker_tids_lock = threading.Lock()
    worker_started = threading.Event()

    def polling_tool(name, args, task_id, call_id=None, messages=None):
        tid = threading.current_thread().ident
        observed["worker_tid"] = tid
        with seen_worker_tids_lock:
            seen_worker_tids.add(tid)
        worker_started.set()
        deadline = time.monotonic() + 5.0
        while time.monotonic() < deadline:
            observed["poll_count"] += 1
            if is_interrupted():
                observed["saw_true"] = True
                return '{"interrupted": true}'
            time.sleep(0.05)
        return '{"timed_out": true}'

    agent._invoke_tool = MagicMock(side_effect=polling_tool)

    tc1 = _FakeToolCall("hung_fake_tool_1", call_id="tc1")
    tc2 = _FakeToolCall("hung_fake_tool_2", call_id="tc2")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    def _interrupt_after_start():
        # Wait until at least one worker is running so its tid is tracked.
        worker_started.wait(timeout=2.0)
        time.sleep(0.2)  # let the other worker enter too
        agent.interrupt("stop requested by test")

    t = threading.Thread(target=_interrupt_after_start)
    t.start()
    try:
        start = time.monotonic()
        agent._execute_tool_calls_concurrent(msg, messages, "test_task")
        elapsed = time.monotonic() - start
        t.join(timeout=2.0)

        # The worker must have actually polled is_interrupted — otherwise the
        # test isn't exercising what it claims to.
        assert observed["poll_count"] > 0, (
            "polling_tool never ran — test scaffold issue"
        )
        # The worker must see the interrupt within ~1 s of agent.interrupt()
        # being called.  Before the fix this loop ran until its 5 s own-timeout.
        assert observed["saw_true"], (
            f"is_interrupted() never returned True inside the concurrent worker "
            f"after agent.interrupt() — interrupt-propagation hole regressed. "
            f"worker_tid={observed['worker_tid']!r} poll_count={observed['poll_count']}"
        )
        assert elapsed < 3.0, (
            f"concurrent execution took {elapsed:.2f}s after interrupt — the fan-out "
            f"to worker tids didn't shortcut the tool's poll loop as expected"
        )
        # Also verify cleanup: no stale worker tids should remain after all
        # tools finished.
        assert agent._tool_worker_threads == set(), (
            f"worker tids leaked after run: {agent._tool_worker_threads}"
        )
    finally:
        # agent.interrupt() flags this (the test runner's) thread via
        # _execution_thread_id and fans out to worker tids in interrupt state
        # shared across threads.  Make sure the helper is done, then reset
        # those bits so the interrupt can't leak into later tests.
        t.join(timeout=2.0)
        agent.clear_interrupt()
        with seen_worker_tids_lock:
            stale_tids = list(seen_worker_tids)
        for tid in stale_tids:
            set_interrupt(False, tid)
240  
241  
def test_clear_interrupt_clears_worker_tids(monkeypatch):
    """After clear_interrupt(), stale worker-tid bits must be cleared so the
    next turn's tools — which may be scheduled onto recycled tids — don't
    see a false interrupt."""
    from tools.interrupt import is_interrupted, set_interrupt

    agent = _make_agent(monkeypatch)
    # Simulate a worker having registered but not yet exited cleanly (e.g. a
    # hypothetical bug in the tear-down).  Use this thread's real tid so
    # is_interrupted() — which checks the current thread — can observe it.
    tracked_tid = threading.current_thread().ident
    with agent._tool_worker_threads_lock:
        agent._tool_worker_threads.add(tracked_tid)
    set_interrupt(True, tracked_tid)
    try:
        assert is_interrupted() is True  # sanity

        agent.clear_interrupt()

        assert is_interrupted() is False, (
            "clear_interrupt() did not clear the interrupt bit for a tracked "
            "worker tid — stale interrupt can leak into the next turn"
        )
    finally:
        # If the regression this guards ever recurs, don't leave the test
        # runner's own thread flagged interrupted for every later test.
        set_interrupt(False, tracked_tid)
264