# test_concurrent_interrupt.py
"""Tests for interrupt handling in concurrent tool execution."""

import concurrent.futures
import threading
import time
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture(autouse=True)
def _isolate_hermes(tmp_path, monkeypatch):
    """Point HERMES_HOME at a per-test temporary directory and create it."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
    (tmp_path / ".hermes").mkdir(exist_ok=True)


def _make_agent(monkeypatch):
    """Create a minimal AIAgent-like object with just the methods under test.

    The stub carries the attributes and no-op helper methods set up below and
    has the real ``_execute_tool_calls_concurrent``, ``interrupt`` and
    ``clear_interrupt`` methods of ``run_agent.AIAgent`` bound onto it.
    ``_invoke_tool`` is a ``MagicMock`` returning ``'{"ok": true}'`` that
    individual tests may replace.
    """
    monkeypatch.setenv("OPENROUTER_API_KEY", "")
    monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "")
    # Avoid full AIAgent init — just import the class and build a stub
    import run_agent as _ra

    class _Stub:
        _interrupt_requested = False
        _interrupt_message = None
        # Bind to this thread's ident so interrupt() targets a real tid.
        _execution_thread_id = threading.current_thread().ident
        _interrupt_thread_signal_pending = False
        log_prefix = ""
        quiet_mode = True
        verbose_logging = False
        log_prefix_chars = 200
        _checkpoint_mgr = MagicMock(enabled=False)
        _subdirectory_hints = MagicMock()
        tool_progress_callback = None
        tool_start_callback = None
        tool_complete_callback = None
        _todo_store = MagicMock()
        _session_db = None
        valid_tool_names = set()
        _turns_since_memory = 0
        _iters_since_skill = 0
        _current_tool = None
        _last_activity = 0
        _print_fn = print
        # Worker-thread tracking state mirrored from AIAgent.__init__ so the
        # real interrupt() method can fan out to concurrent-tool workers.
        _active_children: list = []

        def __init__(self):
            # Instance-level (not class-level) so each test gets a fresh set.
            self._tool_worker_threads: set = set()
            self._tool_worker_threads_lock = threading.Lock()
            self._active_children_lock = threading.Lock()

        def _touch_activity(self, desc):
            self._last_activity = time.time()

        def _vprint(self, msg, force=False):
            pass

        def _safe_print(self, msg):
            pass

        def _should_emit_quiet_tool_messages(self):
            return False

        def _should_start_quiet_spinner(self):
            return False

        def _has_stream_consumers(self):
            return False

    stub = _Stub()
    # Bind the real methods under test
    stub._execute_tool_calls_concurrent = (
        _ra.AIAgent._execute_tool_calls_concurrent.__get__(stub)
    )
    stub.interrupt = _ra.AIAgent.interrupt.__get__(stub)
    stub.clear_interrupt = _ra.AIAgent.clear_interrupt.__get__(stub)
    # /steer injection (added in PR #12116) fires after every concurrent
    # tool batch. Stub it as a no-op — this test exercises interrupt
    # fanout, not steer injection.
    stub._apply_pending_steer_to_tool_results = lambda *a, **kw: None
    stub._invoke_tool = MagicMock(side_effect=lambda *a, **kw: '{"ok": true}')
    return stub


class _FakeToolCall:
    """Stand-in for a tool-call object exposing ``.function`` and ``.id``."""

    def __init__(self, name, args="{}", call_id="tc_1"):
        self.function = MagicMock(name=name, arguments=args)
        # MagicMock's ``name`` keyword labels the mock itself rather than
        # creating a ``.name`` attribute, so assign the attribute explicitly.
        self.function.name = name
        self.id = call_id


class _FakeAssistantMsg:
    """Stand-in for an assistant message carrying a ``tool_calls`` list."""

    def __init__(self, tool_calls):
        self.tool_calls = tool_calls


def test_concurrent_interrupt_cancels_pending(monkeypatch):
    """When _interrupt_requested is set during concurrent execution,
    the wait loop should exit early and cancelled tools get interrupt messages."""
    agent = _make_agent(monkeypatch)

    # Create a tool that blocks until interrupted
    barrier = threading.Event()

    # Accepts ``messages`` like polling_tool below so both fake tools tolerate
    # the same call shapes.
    def slow_tool(name, args, task_id, call_id=None, messages=None):
        if name == "slow_one":
            # Block until the test sets the interrupt
            barrier.wait(timeout=10)
            return '{"slow": true}'
        return '{"fast": true}'

    agent._invoke_tool = MagicMock(side_effect=slow_tool)

    tc1 = _FakeToolCall("fast_one", call_id="tc_fast")
    tc2 = _FakeToolCall("slow_one", call_id="tc_slow")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    def _set_interrupt_after_delay():
        time.sleep(0.3)
        agent._interrupt_requested = True
        barrier.set()  # unblock the slow tool

    t = threading.Thread(target=_set_interrupt_after_delay)
    t.start()
    try:
        agent._execute_tool_calls_concurrent(msg, messages, "test_task")
    finally:
        # Always reap the helper thread, even when the call under test raises.
        t.join()

    # Both tools should have results in messages
    assert len(messages) == 2
    # The interrupt was detected
    assert agent._interrupt_requested is True


def test_concurrent_preflight_interrupt_skips_all(monkeypatch):
    """When _interrupt_requested is already set before concurrent execution,
    all tools are skipped with cancellation messages."""
    agent = _make_agent(monkeypatch)
    agent._interrupt_requested = True

    tc1 = _FakeToolCall("tool_a", call_id="tc_a")
    tc2 = _FakeToolCall("tool_b", call_id="tc_b")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    agent._execute_tool_calls_concurrent(msg, messages, "test_task")

    assert len(messages) == 2
    assert "skipped due to user interrupt" in messages[0]["content"]
    assert "skipped due to user interrupt" in messages[1]["content"]
    # _invoke_tool should never have been called
    agent._invoke_tool.assert_not_called()


def test_running_concurrent_worker_sees_is_interrupted(monkeypatch):
    """Regression guard for the "interrupt-doesn't-reach-hung-tool" class of
    bug Physikal reported in April 2026.

    Before this fix, `AIAgent.interrupt()` called `_set_interrupt(True,
    _execution_thread_id)` — which only flagged the agent's *main* thread.
    Tools running inside `_execute_tool_calls_concurrent` execute on
    ThreadPoolExecutor worker threads whose tids are NOT the agent's, so
    `is_interrupted()` (which checks the *current* thread's tid) returned
    False inside those tools no matter how many times the gateway called
    `.interrupt()`. Hung ssh / long curl / big make-build tools would run
    to their own timeout.

    This test runs a fake tool in the concurrent path that polls
    `is_interrupted()` like a real terminal command does, then calls
    `agent.interrupt()` from another thread, and asserts the poll sees True
    within one second.
    """
    from tools.interrupt import is_interrupted, set_interrupt

    agent = _make_agent(monkeypatch)

    # Counter plus observation hooks so we can prove the worker saw the flip.
    observed = {"saw_true": False, "poll_count": 0, "worker_tid": None}
    worker_started = threading.Event()

    def polling_tool(name, args, task_id, call_id=None, messages=None):
        observed["worker_tid"] = threading.current_thread().ident
        worker_started.set()
        deadline = time.monotonic() + 5.0
        while time.monotonic() < deadline:
            observed["poll_count"] += 1
            if is_interrupted():
                observed["saw_true"] = True
                return '{"interrupted": true}'
            time.sleep(0.05)
        return '{"timed_out": true}'

    agent._invoke_tool = MagicMock(side_effect=polling_tool)

    tc1 = _FakeToolCall("hung_fake_tool_1", call_id="tc1")
    tc2 = _FakeToolCall("hung_fake_tool_2", call_id="tc2")
    msg = _FakeAssistantMsg([tc1, tc2])
    messages = []

    def _interrupt_after_start():
        # Wait until at least one worker is running so its tid is tracked.
        worker_started.wait(timeout=2.0)
        time.sleep(0.2)  # let the other worker enter too
        agent.interrupt("stop requested by test")

    t = threading.Thread(target=_interrupt_after_start)
    t.start()
    try:
        start = time.monotonic()
        try:
            agent._execute_tool_calls_concurrent(msg, messages, "test_task")
        finally:
            # Reap the helper thread even when the call under test raises.
            t.join(timeout=2.0)
        elapsed = time.monotonic() - start

        # The worker must have actually polled is_interrupted — otherwise the
        # test isn't exercising what it claims to.
        assert observed["poll_count"] > 0, (
            "polling_tool never ran — test scaffold issue"
        )
        # The worker must see the interrupt within ~1 s of agent.interrupt()
        # being called. Before the fix this loop ran until its 5 s own-timeout.
        assert observed["saw_true"], (
            f"is_interrupted() never returned True inside the concurrent worker "
            f"after agent.interrupt() — interrupt-propagation hole regressed. "
            f"worker_tid={observed['worker_tid']!r} poll_count={observed['poll_count']}"
        )
        assert elapsed < 3.0, (
            f"concurrent execution took {elapsed:.2f}s after interrupt — the fan-out "
            f"to worker tids didn't shortcut the tool's poll loop as expected"
        )
        # Also verify cleanup: no stale worker tids should remain after all
        # tools finished.
        assert agent._tool_worker_threads == set(), (
            f"worker tids leaked after run: {agent._tool_worker_threads}"
        )
    finally:
        # The stub's _execution_thread_id is this (pytest) thread, so
        # interrupt() presumably leaves this thread flagged. Reset it so the
        # flag cannot leak into later tests; resetting an unset flag is
        # harmless.
        set_interrupt(False, agent._execution_thread_id)


def test_clear_interrupt_clears_worker_tids(monkeypatch):
    """After clear_interrupt(), stale worker-tid bits must be cleared so the
    next turn's tools — which may be scheduled onto recycled tids — don't
    see a false interrupt."""
    from tools.interrupt import is_interrupted, set_interrupt

    agent = _make_agent(monkeypatch)
    # Simulate a worker having registered but not yet exited cleanly (e.g. a
    # hypothetical bug in the tear-down). Put a fake tid in the set and
    # flag it interrupted.
    fake_tid = threading.current_thread().ident  # use real tid so is_interrupted can see it
    with agent._tool_worker_threads_lock:
        agent._tool_worker_threads.add(fake_tid)
    set_interrupt(True, fake_tid)
    try:
        assert is_interrupted() is True  # sanity

        agent.clear_interrupt()

        assert is_interrupted() is False, (
            "clear_interrupt() did not clear the interrupt bit for a tracked "
            "worker tid — stale interrupt can leak into the next turn"
        )
    finally:
        # fake_tid is the pytest thread itself; never leave it flagged when an
        # assertion above fails, or later tests would start out interrupted.
        set_interrupt(False, fake_tid)