# tests/gateway/test_run_progress_interrupt.py
  1  """Tests for interrupt-aware tool-progress suppression in gateway.
  2  
  3  When a user sends `stop` while the agent is executing a batch of parallel
  4  tool calls, the gateway's progress_callback should stop queuing 🔍 bubbles
  5  and the drain loop should drop any already-queued events.  Without this
  6  guard, the stop acknowledgement appears first but is followed by a trail
  7  of tool-progress bubbles for calls that were already parsed from the LLM
  8  response — making the interrupt feel ignored.
  9  """
 10  
 11  import asyncio
 12  import importlib
 13  import sys
 14  import time
 15  import types
 16  from types import SimpleNamespace
 17  
 18  import pytest
 19  
 20  from gateway.config import Platform, PlatformConfig
 21  from gateway.platforms.base import BasePlatformAdapter, SendResult
 22  from gateway.session import SessionSource
 23  
 24  
 25  class ProgressCaptureAdapter(BasePlatformAdapter):
 26      def __init__(self, platform=Platform.TELEGRAM):
 27          super().__init__(PlatformConfig(enabled=True, token="***"), platform)
 28          self.sent = []
 29          self.edits = []
 30          self.typing = []
 31  
 32      async def connect(self) -> bool:
 33          return True
 34  
 35      async def disconnect(self) -> None:
 36          return None
 37  
 38      async def send(self, chat_id, content, reply_to=None, metadata=None) -> SendResult:
 39          self.sent.append({"chat_id": chat_id, "content": content})
 40          return SendResult(success=True, message_id="progress-1")
 41  
 42      async def edit_message(self, chat_id, message_id, content) -> SendResult:
 43          self.edits.append({"message_id": message_id, "content": content})
 44          return SendResult(success=True, message_id=message_id)
 45  
 46      async def send_typing(self, chat_id, metadata=None) -> None:
 47          self.typing.append(chat_id)
 48  
 49      async def stop_typing(self, chat_id) -> None:
 50          return None
 51  
 52      async def get_chat_info(self, chat_id: str):
 53          return {"id": chat_id}
 54  
 55  
 56  class PreInterruptAgent:
 57      """Fires tool-progress events BEFORE the interrupt lands.
 58  
 59      These should render normally.  Baseline for comparison with the
 60      interrupted case — proves the harness renders events when no
 61      interrupt is active.
 62      """
 63  
 64      def __init__(self, **kwargs):
 65          self.tool_progress_callback = kwargs.get("tool_progress_callback")
 66          self.tools = []
 67          self._interrupt_requested = False
 68  
 69      @property
 70      def is_interrupted(self) -> bool:
 71          return self._interrupt_requested
 72  
 73      def run_conversation(self, message, conversation_history=None, task_id=None):
 74          self.tool_progress_callback("tool.started", "web_search", "first search", {})
 75          time.sleep(0.35)  # let the drain loop process
 76          return {"final_response": "done", "messages": [], "api_calls": 1}
 77  
 78  
 79  class InterruptedAgent:
 80      """Fires tool.started events AFTER interrupt — all should be suppressed.
 81  
 82      Mirrors the failure mode in the bug report: LLM returned N parallel
 83      web_search calls, interrupt flag flipped, remaining events still
 84      rendered as bubbles.  With the fix, none of these should appear.
 85      """
 86  
 87      def __init__(self, **kwargs):
 88          self.tool_progress_callback = kwargs.get("tool_progress_callback")
 89          self.tools = []
 90          # Start already interrupted — simulates stop having already landed
 91          # by the time the agent batch starts firing tool.started events.
 92          self._interrupt_requested = True
 93  
 94      @property
 95      def is_interrupted(self) -> bool:
 96          return self._interrupt_requested
 97  
 98      def run_conversation(self, message, conversation_history=None, task_id=None):
 99          # Parallel tool batch — in production these come from one LLM
100          # response with 5 tool_calls.  All are post-interrupt.
101          self.tool_progress_callback("tool.started", "web_search", "cognee hermes", {})
102          self.tool_progress_callback("tool.started", "web_search", "McBee deer hunting", {})
103          self.tool_progress_callback("tool.started", "web_search", "kuzu graph db", {})
104          self.tool_progress_callback("tool.started", "web_search", "moonshot kimi api", {})
105          self.tool_progress_callback("tool.started", "web_search", "platform.moonshot.cn", {})
106          time.sleep(0.35)  # let the drain loop attempt to process the queue
107          return {"final_response": "interrupted", "messages": [], "api_calls": 1}
108  
109  
def _make_runner(adapter):
    """Build a bare GatewayRunner wired to *adapter*, bypassing __init__.

    ``object.__new__`` skips the real constructor (and whatever I/O it
    performs); the attributes ``_run_agent`` reads are then set directly.
    """
    runner_cls = importlib.import_module("gateway.run").GatewayRunner
    runner = object.__new__(runner_cls)

    runner.adapters = {adapter.platform: adapter}
    # Minimal attribute surface that _run_agent touches.
    plain_attrs = {
        "_voice_mode": {},
        "_prefill_messages": [],
        "_ephemeral_system_prompt": "",
        "_reasoning_config": None,
        "_provider_routing": {},
        "_fallback_model": None,
        "_session_db": None,
        "_running_agents": {},
        "_session_run_generation": {},
    }
    for attr, value in plain_attrs.items():
        setattr(runner, attr, value)

    runner.hooks = SimpleNamespace(loaded_hooks=False)
    runner.config = SimpleNamespace(
        thread_sessions_per_user=False,
        group_sessions_per_user=False,
        stt_enabled=False,
    )
    return runner
132  
133  
async def _run_once(monkeypatch, tmp_path, agent_cls, session_id):
    """Drive one GatewayRunner._run_agent call with *agent_cls* as the agent.

    Returns ``(adapter, result)`` so callers can inspect both what was
    rendered to the platform and what the agent reported back.
    """
    monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")

    # Stub dotenv so importing gateway.run has no filesystem side effects.
    stub_dotenv = types.ModuleType("dotenv")
    stub_dotenv.load_dotenv = lambda *args, **kwargs: None
    monkeypatch.setitem(sys.modules, "dotenv", stub_dotenv)

    # Swap the real agent module for one exposing the test double.
    stub_agent_module = types.ModuleType("run_agent")
    stub_agent_module.AIAgent = agent_cls
    monkeypatch.setitem(sys.modules, "run_agent", stub_agent_module)

    adapter = ProgressCaptureAdapter()
    runner = _make_runner(adapter)

    gateway_run = importlib.import_module("gateway.run")
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.setattr(
        gateway_run,
        "_resolve_runtime_agent_kwargs",
        lambda: {"api_key": "fake"},
    )

    source = SessionSource(
        platform=Platform.TELEGRAM,
        chat_id="-1001",
        chat_type="group",
        thread_id="17585",
    )
    result = await runner._run_agent(
        message="hi",
        context_prompt="",
        history=[],
        source=source,
        session_id=session_id,
        session_key="agent:main:telegram:group:-1001:17585",
    )
    return adapter, result
169  
170  
@pytest.mark.asyncio
async def test_baseline_non_interrupted_agent_renders_progress(monkeypatch, tmp_path):
    """Sanity check: when is_interrupted is False, tool-progress renders normally."""
    adapter, result = await _run_once(monkeypatch, tmp_path, PreInterruptAgent, "sess-baseline")
    assert result["final_response"] == "done"

    sent_text = " ".join(item["content"] for item in adapter.sent)
    edit_text = " ".join(item["content"] for item in adapter.edits)
    rendered = sent_text + " " + edit_text

    assert "first search" in rendered, (
        "baseline agent should render its tool-progress event — "
        "if this fails the test harness is broken, not the fix"
    )
183  
184  
@pytest.mark.asyncio
async def test_progress_suppressed_when_agent_is_interrupted(monkeypatch, tmp_path):
    """Post-interrupt tool.started events must not render as bubbles.

    This is Bug B from the screenshot: user sends `stop`, agent acks with
    ⚡ Interrupting, but 5 more 🔍 web_search bubbles still render because
    their tool.started events were already parsed from the LLM response.
    With the fix, progress_callback and the drain loop both check
    is_interrupted and skip these events.
    """
    adapter, result = await _run_once(
        monkeypatch, tmp_path, InterruptedAgent, "sess-interrupted"
    )
    assert result["final_response"] == "interrupted"

    sent_text = " ".join(item["content"] for item in adapter.sent)
    edit_text = " ".join(item["content"] for item in adapter.edits)
    rendered = sent_text + " " + edit_text

    # Every query in the post-interrupt batch must be absent from the UI.
    leaked_queries = (
        "cognee hermes",
        "McBee deer hunting",
        "kuzu graph db",
        "moonshot kimi api",
        "platform.moonshot.cn",
    )
    for leaked_query in leaked_queries:
        assert leaked_query not in rendered, (
            f"event '{leaked_query}' leaked into the UI after interrupt — "
            f"progress_callback / drain loop is not checking is_interrupted"
        )