"""Tests for interrupt-aware tool-progress suppression in gateway.

When a user sends `stop` while the agent is executing a batch of parallel
tool calls, the gateway's progress_callback should stop queuing 🔍 bubbles
and the drain loop should drop any already-queued events. Without this
guard, the stop acknowledgement appears first but is followed by a trail
of tool-progress bubbles for calls that were already parsed from the LLM
response — making the interrupt feel ignored.
"""

import importlib
import sys
import time
import types
from types import SimpleNamespace

import pytest

from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, SendResult
from gateway.session import SessionSource


class ProgressCaptureAdapter(BasePlatformAdapter):
    """Fake platform adapter that records everything the gateway renders.

    Tests inspect ``sent`` / ``edits`` to assert which tool-progress
    bubbles reached the UI layer.
    """

    def __init__(self, platform=Platform.TELEGRAM):
        super().__init__(PlatformConfig(enabled=True, token="***"), platform)
        self.sent = []    # every send() payload, in order
        self.edits = []   # every edit_message() payload, in order
        self.typing = []  # chat_ids that received a typing indicator

    async def connect(self) -> bool:
        return True

    async def disconnect(self) -> None:
        return None

    async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
        self.sent.append({"chat_id": chat_id, "content": content})
        return SendResult(success=True, message_id="progress-1")

    async def edit_message(self, chat_id, message_id, content) -> SendResult:
        self.edits.append({"message_id": message_id, "content": content})
        return SendResult(success=True, message_id=message_id)

    async def send_typing(self, chat_id, metadata=None) -> None:
        self.typing.append(chat_id)

    async def stop_typing(self, chat_id) -> None:
        return None

    async def get_chat_info(self, chat_id: str):
        return {"id": chat_id}


class PreInterruptAgent:
    """Fires tool-progress events BEFORE the interrupt lands.

    These should render normally. Baseline for comparison with the
    interrupted case — proves the harness renders events when no
    interrupt is active.
    """

    def __init__(self, **kwargs):
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []
        self._interrupt_requested = False

    @property
    def is_interrupted(self) -> bool:
        return self._interrupt_requested

    def run_conversation(self, message, conversation_history=None, task_id=None):
        self.tool_progress_callback("tool.started", "web_search", "first search", {})
        time.sleep(0.35)  # let the drain loop process
        return {"final_response": "done", "messages": [], "api_calls": 1}


class InterruptedAgent:
    """Fires tool.started events AFTER interrupt — all should be suppressed.

    Mirrors the failure mode in the bug report: LLM returned N parallel
    web_search calls, interrupt flag flipped, remaining events still
    rendered as bubbles. With the fix, none of these should appear.
    """

    def __init__(self, **kwargs):
        self.tool_progress_callback = kwargs.get("tool_progress_callback")
        self.tools = []
        # Start already interrupted — simulates stop having already landed
        # by the time the agent batch starts firing tool.started events.
        self._interrupt_requested = True

    @property
    def is_interrupted(self) -> bool:
        return self._interrupt_requested

    def run_conversation(self, message, conversation_history=None, task_id=None):
        # Parallel tool batch — in production these come from one LLM
        # response with 5 tool_calls. All are post-interrupt.
        self.tool_progress_callback("tool.started", "web_search", "cognee hermes", {})
        self.tool_progress_callback("tool.started", "web_search", "McBee deer hunting", {})
        self.tool_progress_callback("tool.started", "web_search", "kuzu graph db", {})
        self.tool_progress_callback("tool.started", "web_search", "moonshot kimi api", {})
        self.tool_progress_callback("tool.started", "web_search", "platform.moonshot.cn", {})
        time.sleep(0.35)  # let the drain loop attempt to process the queue
        return {"final_response": "interrupted", "messages": [], "api_calls": 1}


def _make_runner(adapter):
    """Build a minimal GatewayRunner wired to *adapter*, skipping __init__.

    ``object.__new__`` avoids GatewayRunner's real constructor (config
    loading, adapter discovery); only the attributes _run_agent touches
    are stubbed in.
    """
    gateway_run = importlib.import_module("gateway.run")
    GatewayRunner = gateway_run.GatewayRunner

    runner = object.__new__(GatewayRunner)
    runner.adapters = {adapter.platform: adapter}
    runner._voice_mode = {}
    runner._prefill_messages = []
    runner._ephemeral_system_prompt = ""
    runner._reasoning_config = None
    runner._provider_routing = {}
    runner._fallback_model = None
    runner._session_db = None
    runner._running_agents = {}
    runner._session_run_generation = {}
    runner.hooks = SimpleNamespace(loaded_hooks=False)
    runner.config = SimpleNamespace(
        thread_sessions_per_user=False,
        group_sessions_per_user=False,
        stt_enabled=False,
    )
    return runner


async def _run_once(monkeypatch, tmp_path, agent_cls, session_id):
    """Run one _run_agent pass with *agent_cls* standing in for AIAgent.

    Stubs out dotenv and the run_agent module via sys.modules so the
    gateway imports our fake agent, then returns (adapter, result) for
    the caller to assert on rendered content.
    """
    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Neutralize dotenv so importing gateway.run doesn't read a real .env.
    fake_dotenv = types.ModuleType("dotenv")
    fake_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)

    # Swap the real agent module for one exposing our scripted agent class.
    fake_run_agent = types.ModuleType("run_agent")
    fake_run_agent.AIAgent = agent_cls
    monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)
    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(
        gateway_run,
        "_resolve_runtime_agent_kwargs",
        lambda: {"api_key": "fake"},
    )
    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )
    result = await runner._run_agent(
        message="hi",
        context_prompt="",
        history=[],
        source=source,
        session_id=session_id,
        session_key="agent:main:telegram:group:-1001:17585",
    )
    return adapter, result


@pytest.mark.asyncio
async def test_baseline_non_interrupted_agent_renders_progress(monkeypatch, tmp_path):
    """Sanity check: when is_interrupted is False, tool-progress renders normally."""
    adapter, result = await _run_once(monkeypatch, tmp_path, PreInterruptAgent, "sess-baseline")
    assert result["final_response"] == "done"
    rendered = " ".join(c["content"] for c in adapter.sent) + " " + " ".join(
        c["content"] for c in adapter.edits
    )
    assert "first search" in rendered, (
        "baseline agent should render its tool-progress event — "
        "if this fails the test harness is broken, not the fix"
    )


@pytest.mark.asyncio
async def test_progress_suppressed_when_agent_is_interrupted(monkeypatch, tmp_path):
    """Post-interrupt tool.started events must not render as bubbles.

    This is Bug B from the screenshot: user sends `stop`, agent acks with
    ⚡ Interrupting, but 5 more 🔍 web_search bubbles still render because
    their tool.started events were already parsed from the LLM response.
    With the fix, progress_callback and the drain loop both check
    is_interrupted and skip these events.
    """
    adapter, result = await _run_once(
        monkeypatch, tmp_path, InterruptedAgent, "sess-interrupted"
    )
    assert result["final_response"] == "interrupted"

    rendered = " ".join(c["content"] for c in adapter.sent) + " " + " ".join(
        c["content"] for c in adapter.edits
    )

    # None of the post-interrupt queries should appear.
    for leaked_query in (
        "cognee hermes",
        "McBee deer hunting",
        "kuzu graph db",
        "moonshot kimi api",
        "platform.moonshot.cn",
    ):
        assert leaked_query not in rendered, (
            f"event '{leaked_query}' leaked into the UI after interrupt — "
            f"progress_callback / drain loop is not checking is_interrupted"
        )