/ tests / run_agent / test_interactive_interrupt.py
test_interactive_interrupt.py
  1  #!/usr/bin/env python3
  2  """Interactive interrupt test that mimics the exact CLI flow.
  3  
  4  Starts an agent in a thread with a mock delegate_task that takes a while,
  5  then simulates the user typing a message via _interrupt_queue.
  6  
  7  Logs every step to stderr (which isn't affected by redirect_stdout)
  8  so we can see exactly where the interrupt gets lost.
  9  """
 10  
 11  import contextlib
 12  import io
 13  import json
 14  import logging
 15  import queue
 16  import sys
 17  import threading
 18  import time
 19  import os
 20  
 21  # Force stderr logging so redirect_stdout doesn't swallow it
 22  logging.basicConfig(level=logging.DEBUG, stream=sys.stderr,
 23                      format="%(asctime)s [%(threadName)s] %(message)s")
 24  log = logging.getLogger("interrupt_test")
 25  
 26  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 27  
 28  from unittest.mock import MagicMock, patch
 29  from run_agent import AIAgent, IterationBudget
 30  from tools.interrupt import set_interrupt, is_interrupted
 31  
 32  def make_slow_response(delay=2.0):
 33      """API response that takes a while."""
 34      def create(**kwargs):
 35          log.info(f"   🌐 Mock API call starting (will take {delay}s)...")
 36          time.sleep(delay)
 37          log.info(f"   🌐 Mock API call completed")
 38          resp = MagicMock()
 39          resp.choices = [MagicMock()]
 40          resp.choices[0].message.content = "Done with the task"
 41          resp.choices[0].message.tool_calls = None
 42          resp.choices[0].message.refusal = None
 43          resp.choices[0].finish_reason = "stop"
 44          resp.usage.prompt_tokens = 100
 45          resp.usage.completion_tokens = 10
 46          resp.usage.total_tokens = 110
 47          resp.usage.prompt_tokens_details = None
 48          return resp
 49      return create
 50  
 51  
 52  def main() -> int:
 53      set_interrupt(False)
 54  
 55      # ─── Create parent agent ───
 56      parent = AIAgent.__new__(AIAgent)
 57      parent._interrupt_requested = False
 58      parent._interrupt_message = None
 59      parent._active_children = []
 60      parent._active_children_lock = threading.Lock()
 61      parent.quiet_mode = True
 62      parent.model = "test/model"
 63      parent.base_url = "http://localhost:1"
 64      parent.api_key = "test"
 65      parent.provider = "test"
 66      parent.api_mode = "chat_completions"
 67      parent.platform = "cli"
 68      parent.enabled_toolsets = ["terminal", "file"]
 69      parent.providers_allowed = None
 70      parent.providers_ignored = None
 71      parent.providers_order = None
 72      parent.provider_sort = None
 73      parent.max_tokens = None
 74      parent.reasoning_config = None
 75      parent.prefill_messages = None
 76      parent._session_db = None
 77      parent._delegate_depth = 0
 78      parent._delegate_spinner = None
 79      parent.tool_progress_callback = None
 80      parent.iteration_budget = IterationBudget(max_total=100)
 81      parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
 82  
 83      # Monkey-patch parent.interrupt to log
 84      _original_interrupt = AIAgent.interrupt
 85  
 86      def logged_interrupt(self, message=None):
 87          log.info(f"🔴 parent.interrupt() called with: {message!r}")
 88          log.info(f"   _active_children count: {len(self._active_children)}")
 89          _original_interrupt(self, message)
 90          log.info(f"   After interrupt: _interrupt_requested={self._interrupt_requested}")
 91          for i, child in enumerate(self._active_children):
 92              log.info(f"   Child {i}._interrupt_requested={child._interrupt_requested}")
 93  
 94      parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)
 95  
 96      # ─── Simulate the exact CLI flow ───
 97      interrupt_queue = queue.Queue()
 98      child_running = threading.Event()
 99      agent_result = [None]
100  
101      def agent_thread_func():
102          """Simulates the agent_thread in cli.py's chat() method."""
103          log.info("🟢 agent_thread starting")
104  
105          with patch("run_agent.OpenAI") as MockOpenAI:
106              mock_client = MagicMock()
107              mock_client.chat.completions.create = make_slow_response(delay=3.0)
108              mock_client.close = MagicMock()
109              MockOpenAI.return_value = mock_client
110  
111              from tools.delegate_tool import _run_single_child
112  
113              # Signal that child is about to start
114              original_init = AIAgent.__init__
115  
116              def patched_init(self_agent, *a, **kw):
117                  log.info("🟡 Child AIAgent.__init__ called")
118                  original_init(self_agent, *a, **kw)
119                  child_running.set()
120                  log.info(
121                      f"🟡 Child started, parent._active_children = {len(parent._active_children)}"
122                  )
123  
124              with patch.object(AIAgent, "__init__", patched_init):
125                  result = _run_single_child(
126                      task_index=0,
127                      goal="Do a slow thing",
128                      context=None,
129                      toolsets=["terminal"],
130                      model="test/model",
131                      max_iterations=3,
132                      parent_agent=parent,
133                      task_count=1,
134                      override_provider="test",
135                      override_base_url="http://localhost:1",
136                      override_api_key="test",
137                      override_api_mode="chat_completions",
138                  )
139                  agent_result[0] = result
140                  log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")
141  
142      # ─── Start agent thread (like chat() does) ───
143      agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
144      agent_thread.start()
145  
146      # ─── Wait for child to start ───
147      if not child_running.wait(timeout=10):
148          print("FAIL: Child never started", file=sys.stderr)
149          set_interrupt(False)
150          return 1
151  
152      # Give child time to enter its main loop and start API call
153      time.sleep(1.0)
154  
155      # ─── Simulate user typing a message (like handle_enter does) ───
156      log.info("📝 Simulating user typing 'Hey stop that'")
157      interrupt_queue.put("Hey stop that")
158  
159      # ─── Simulate chat() polling loop (like the real chat() method) ───
160      log.info("📡 Starting interrupt queue polling (like chat())")
161      interrupt_msg = None
162      poll_count = 0
163      while agent_thread.is_alive():
164          try:
165              interrupt_msg = interrupt_queue.get(timeout=0.1)
166              if interrupt_msg:
167                  log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
168                  log.info("   Calling parent.interrupt()...")
169                  parent.interrupt(interrupt_msg)
170                  log.info("   parent.interrupt() returned. Breaking poll loop.")
171                  break
172          except queue.Empty:
173              poll_count += 1
174              if poll_count % 20 == 0:  # Log every 2s
175                  log.info(f"   Still polling ({poll_count} iterations)...")
176  
177      # ─── Wait for agent to finish ───
178      log.info("⏳ Waiting for agent_thread to join...")
179      t0 = time.monotonic()
180      agent_thread.join(timeout=10)
181      elapsed = time.monotonic() - t0
182      log.info(f"✅ agent_thread joined after {elapsed:.2f}s")
183  
184      # ─── Check results ───
185      result = agent_result[0]
186      if result:
187          log.info(f"Result status: {result['status']}")
188          log.info(f"Result duration: {result['duration_seconds']}s")
189          if result["status"] == "interrupted" and elapsed < 2.0:
190              print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
191              set_interrupt(False)
192              return 0
193          print(f"❌ FAIL: status={result['status']}, elapsed={elapsed:.2f}s", file=sys.stderr)
194          set_interrupt(False)
195          return 1
196  
197      print("❌ FAIL: No result returned", file=sys.stderr)
198      set_interrupt(False)
199      return 1
200  
201  
# Script entry point: the return value of main() becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())