test_interactive_interrupt.py
#!/usr/bin/env python3
"""Interactive interrupt test that mimics the exact CLI flow.

Starts an agent in a thread with a mock delegate_task that takes a while,
then simulates the user typing a message via _interrupt_queue.

Logs every step to stderr (which isn't affected by redirect_stdout)
so we can see exactly where the interrupt gets lost.
"""

import contextlib
import io
import json
import logging
import queue
import sys
import threading
import time
import os

# Force stderr logging so redirect_stdout doesn't swallow it
logging.basicConfig(level=logging.DEBUG, stream=sys.stderr,
                    format="%(asctime)s [%(threadName)s] %(message)s")
log = logging.getLogger("interrupt_test")

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from unittest.mock import MagicMock, patch
from run_agent import AIAgent, IterationBudget
from tools.interrupt import set_interrupt, is_interrupted

# How long to wait for the child agent's __init__ to complete.
_CHILD_START_TIMEOUT_S = 10
# How long to wait for the agent thread to exit after the poll loop ends.
_JOIN_TIMEOUT_S = 10
# The join must complete faster than this for the interrupt to count as
# having worked.
_MAX_INTERRUPT_LATENCY_S = 2.0
# Queue poll timeout, and how many empty polls pass between progress logs
# (20 polls * 0.1s = one log line every 2s).
_POLL_TIMEOUT_S = 0.1
_POLLS_PER_PROGRESS_LOG = 20


def make_slow_response(delay=2.0):
    """Build a stand-in for ``chat.completions.create`` that blocks for a while.

    Args:
        delay: Seconds the returned callable sleeps before answering.

    Returns:
        A callable accepting arbitrary keyword arguments that sleeps for
        ``delay`` seconds and then returns a ``MagicMock`` shaped like a
        completed chat response (content set, no tool calls, finish reason
        ``"stop"``, fixed usage counters).
    """
    def create(**kwargs):
        log.info(f" 🌐 Mock API call starting (will take {delay}s)...")
        time.sleep(delay)
        log.info(" 🌐 Mock API call completed")
        resp = MagicMock()
        resp.choices = [MagicMock()]
        resp.choices[0].message.content = "Done with the task"
        resp.choices[0].message.tool_calls = None
        resp.choices[0].message.refusal = None
        resp.choices[0].finish_reason = "stop"
        resp.usage.prompt_tokens = 100
        resp.usage.completion_tokens = 10
        resp.usage.total_tokens = 110
        resp.usage.prompt_tokens_details = None
        return resp
    return create


def _build_parent_agent():
    """Create a parent ``AIAgent`` without running its ``__init__``.

    The instance is allocated with ``__new__`` and the attributes this test
    relies on are assigned by hand.
    """
    parent = AIAgent.__new__(AIAgent)
    parent._interrupt_requested = False
    parent._interrupt_message = None
    parent._active_children = []
    parent._active_children_lock = threading.Lock()
    parent.quiet_mode = True
    parent.model = "test/model"
    parent.base_url = "http://localhost:1"
    parent.api_key = "test"
    parent.provider = "test"
    parent.api_mode = "chat_completions"
    parent.platform = "cli"
    parent.enabled_toolsets = ["terminal", "file"]
    parent.providers_allowed = None
    parent.providers_ignored = None
    parent.providers_order = None
    parent.provider_sort = None
    parent.max_tokens = None
    parent.reasoning_config = None
    parent.prefill_messages = None
    parent._session_db = None
    parent._delegate_depth = 0
    parent._delegate_spinner = None
    parent.tool_progress_callback = None
    parent.iteration_budget = IterationBudget(max_total=100)
    parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
    return parent


def _install_interrupt_logging(parent):
    """Shadow ``parent.interrupt`` with a wrapper that logs before and after.

    The wrapper delegates to the unpatched ``AIAgent.interrupt`` and then
    logs the parent's and each active child's ``_interrupt_requested`` flag.
    """
    _original_interrupt = AIAgent.interrupt

    def logged_interrupt(self, message=None):
        log.info(f"🔴 parent.interrupt() called with: {message!r}")
        log.info(f" _active_children count: {len(self._active_children)}")
        _original_interrupt(self, message)
        log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}")
        for i, child in enumerate(self._active_children):
            log.info(f" Child {i}._interrupt_requested={child._interrupt_requested}")

    # Instance-level attribute, so only this parent object is affected.
    parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)


def _run_scenario() -> int:
    """Run the interrupt scenario once and return the process exit code.

    Returns:
        0 when the child result has status ``"interrupted"`` and the agent
        thread joined in under ``_MAX_INTERRUPT_LATENCY_S`` seconds; 1
        otherwise (child never started, thread hung, no result, or wrong
        status / too slow).
    """
    # ─── Create parent agent ───
    parent = _build_parent_agent()
    _install_interrupt_logging(parent)

    # ─── Simulate the exact CLI flow ───
    interrupt_queue = queue.Queue()
    child_running = threading.Event()
    agent_result = [None]

    def agent_thread_func():
        """Simulates the agent_thread in cli.py's chat() method."""
        log.info("🟢 agent_thread starting")

        with patch("run_agent.OpenAI") as MockOpenAI:
            mock_client = MagicMock()
            mock_client.chat.completions.create = make_slow_response(delay=3.0)
            mock_client.close = MagicMock()
            MockOpenAI.return_value = mock_client

            from tools.delegate_tool import _run_single_child

            # Signal that child is about to start
            original_init = AIAgent.__init__

            def patched_init(self_agent, *a, **kw):
                log.info("🟡 Child AIAgent.__init__ called")
                original_init(self_agent, *a, **kw)
                child_running.set()
                log.info(
                    f"🟡 Child started, parent._active_children = {len(parent._active_children)}"
                )

            with patch.object(AIAgent, "__init__", patched_init):
                result = _run_single_child(
                    task_index=0,
                    goal="Do a slow thing",
                    context=None,
                    toolsets=["terminal"],
                    model="test/model",
                    max_iterations=3,
                    parent_agent=parent,
                    task_count=1,
                    override_provider="test",
                    override_base_url="http://localhost:1",
                    override_api_key="test",
                    override_api_mode="chat_completions",
                )
            agent_result[0] = result
            log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")

    # ─── Start agent thread (like chat() does) ───
    agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
    agent_thread.start()

    # ─── Wait for child to start ───
    if not child_running.wait(timeout=_CHILD_START_TIMEOUT_S):
        print("FAIL: Child never started", file=sys.stderr)
        return 1

    # Give child time to enter its main loop and start API call
    time.sleep(1.0)

    # ─── Simulate user typing a message (like handle_enter does) ───
    log.info("📝 Simulating user typing 'Hey stop that'")
    interrupt_queue.put("Hey stop that")

    # ─── Simulate chat() polling loop (like the real chat() method) ───
    log.info("📡 Starting interrupt queue polling (like chat())")
    poll_count = 0
    while agent_thread.is_alive():
        # Only the queue read is guarded; parent.interrupt() runs outside
        # the try so its own failures are not mistaken for an empty queue.
        try:
            interrupt_msg = interrupt_queue.get(timeout=_POLL_TIMEOUT_S)
        except queue.Empty:
            poll_count += 1
            if poll_count % _POLLS_PER_PROGRESS_LOG == 0:  # Log every 2s
                log.info(f" Still polling ({poll_count} iterations)...")
            continue
        if interrupt_msg:
            log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
            log.info(" Calling parent.interrupt()...")
            parent.interrupt(interrupt_msg)
            log.info(" parent.interrupt() returned. Breaking poll loop.")
            break

    # ─── Wait for agent to finish ───
    log.info("⏳ Waiting for agent_thread to join...")
    t0 = time.monotonic()
    agent_thread.join(timeout=_JOIN_TIMEOUT_S)
    elapsed = time.monotonic() - t0
    # join() returns silently on timeout, so confirm the thread really ended
    # before claiming it joined.
    if agent_thread.is_alive():
        log.info(f"⏰ agent_thread still running after {elapsed:.2f}s join timeout")
        print(
            f"❌ FAIL: agent_thread did not finish within {_JOIN_TIMEOUT_S}s",
            file=sys.stderr,
        )
        return 1
    log.info(f"✅ agent_thread joined after {elapsed:.2f}s")

    # ─── Check results ───
    result = agent_result[0]
    if not result:
        print("❌ FAIL: No result returned", file=sys.stderr)
        return 1

    # .get() keeps a malformed result from turning into a KeyError traceback.
    status = result.get("status")
    log.info(f"Result status: {status}")
    log.info(f"Result duration: {result.get('duration_seconds')}s")
    if status == "interrupted" and elapsed < _MAX_INTERRUPT_LATENCY_S:
        print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
        return 0
    print(f"❌ FAIL: status={status}, elapsed={elapsed:.2f}s", file=sys.stderr)
    return 1


def main() -> int:
    """Entry point: run the scenario with the global interrupt flag reset.

    The flag is cleared before the run and again afterwards in a ``finally``
    clause, so it is reset on every exit path, including exceptions.

    Returns:
        The exit code produced by ``_run_scenario`` (0 on pass, 1 on fail).
    """
    set_interrupt(False)
    try:
        return _run_scenario()
    finally:
        set_interrupt(False)


if __name__ == "__main__":
    sys.exit(main())