run_interrupt_test.py
#!/usr/bin/env python3
"""Run a real interrupt test with actual AIAgent + delegate child.

Not a pytest test — runs directly as a script for live testing.

The script builds a minimal parent ``AIAgent`` (bypassing ``__init__``) and
starts a delegate child in a background thread. The child's mocked
chat-completions call blocks for ``SLOW_CALL_SECONDS``. The script then calls
``parent.interrupt()`` and checks that the delegate thread finishes in under
``PASS_THRESHOLD_SECONDS``.

Exit status: 0 on PASS, 1 on FAIL or setup error.
"""

import os
import sys
import threading
import time
import traceback
from unittest.mock import MagicMock, patch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from run_agent import AIAgent, IterationBudget
from tools.delegate_tool import _run_single_child
from tools.interrupt import set_interrupt, is_interrupted

# How long the mocked chat-completions call blocks before returning "Done".
SLOW_CALL_SECONDS = 3
# The test passes if the delegate thread ends within this many seconds of
# parent.interrupt() being called.
PASS_THRESHOLD_SECONDS = 2.0
# Upper bounds for waiting on the child to start / the delegate thread to end.
CHILD_START_TIMEOUT = 10
JOIN_TIMEOUT = 10


def _make_parent() -> AIAgent:
    """Build a minimal parent AIAgent with hand-set attributes (no ``__init__``)."""
    parent = AIAgent.__new__(AIAgent)
    parent._interrupt_requested = False
    parent._interrupt_message = None
    parent._active_children = []
    parent._active_children_lock = threading.Lock()
    parent.quiet_mode = True
    parent.model = "test/model"
    parent.base_url = "http://localhost:1"
    parent.api_key = "test"
    parent.provider = "test"
    parent.api_mode = "chat_completions"
    parent.platform = "cli"
    parent.enabled_toolsets = ["terminal", "file"]
    parent.providers_allowed = None
    parent.providers_ignored = None
    parent.providers_order = None
    parent.provider_sort = None
    parent.max_tokens = None
    parent.reasoning_config = None
    parent.prefill_messages = None
    parent._session_db = None
    parent._delegate_depth = 0
    parent._delegate_spinner = None
    parent.tool_progress_callback = None
    parent.iteration_budget = IterationBudget(max_total=100)
    parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
    return parent


def _make_slow_client() -> MagicMock:
    """Return a mock OpenAI client whose ``chat.completions.create`` blocks, then returns "Done"."""
    mock_client = MagicMock()

    def slow_create(**kwargs):
        time.sleep(SLOW_CALL_SECONDS)
        resp = MagicMock()
        resp.choices = [MagicMock()]
        resp.choices[0].message.content = "Done"
        resp.choices[0].message.tool_calls = None
        resp.choices[0].message.refusal = None
        resp.choices[0].finish_reason = "stop"
        resp.usage.prompt_tokens = 100
        resp.usage.completion_tokens = 10
        resp.usage.total_tokens = 110
        resp.usage.prompt_tokens_details = None
        return resp

    mock_client.chat.completions.create = slow_create
    mock_client.close = MagicMock()
    return mock_client


def _run_delegate(parent: AIAgent, child_started: threading.Event, result_holder: list) -> None:
    """Thread target: run one delegate child against the mocked OpenAI client.

    Sets ``child_started`` once a child ``AIAgent.__init__`` has completed and
    stores the ``_run_single_child`` result in ``result_holder[0]``. On an
    exception, it prints the error and traceback and leaves ``result_holder[0]``
    as None.
    """
    with patch("run_agent.OpenAI") as MockOpenAI:
        MockOpenAI.return_value = _make_slow_client()

        original_init = AIAgent.__init__

        def patched_init(self_agent, *a, **kw):
            original_init(self_agent, *a, **kw)
            child_started.set()

        with patch.object(AIAgent, "__init__", patched_init):
            try:
                result_holder[0] = _run_single_child(
                    task_index=0,
                    goal="Test slow task",
                    context=None,
                    toolsets=["terminal"],
                    model="test/model",
                    max_iterations=5,
                    parent_agent=parent,
                    task_count=1,
                    override_provider="test",
                    override_base_url="http://localhost:1",
                    override_api_key="test",
                    override_api_mode="chat_completions",
                )
            except Exception as e:
                print(f"ERROR in delegate: {e}")
                traceback.print_exc()


def _snapshot_children(parent: AIAgent) -> list:
    """Copy ``parent._active_children`` while holding its lock.

    The delegate thread runs concurrently, so the list is read through the
    lock created alongside it rather than iterated live.
    """
    with parent._active_children_lock:
        return list(parent._active_children)


def main() -> int:
    """Run the live interrupt test.

    Returns:
        0 if the delegate thread ended within ``PASS_THRESHOLD_SECONDS`` of the
        interrupt, 1 on any failure (child never started, thread still running,
        no result, or interrupt too slow).
    """
    set_interrupt(False)
    try:
        parent = _make_parent()
        child_started = threading.Event()
        result_holder = [None]

        print("Starting agent thread...")
        agent_thread = threading.Thread(
            target=_run_delegate,
            args=(parent, child_started, result_holder),
            daemon=True,
        )
        agent_thread.start()

        if not child_started.wait(timeout=CHILD_START_TIMEOUT):
            print("ERROR: Child never started")
            return 1

        # Give the child a moment to get into its (slow) mocked LLM call.
        time.sleep(0.5)

        children = _snapshot_children(parent)
        print(f"Active children: {len(children)}")
        for i, c in enumerate(children):
            print(f" Child {i}: _interrupt_requested={c._interrupt_requested}")

        t0 = time.monotonic()
        parent.interrupt("User typed a new message")
        print("Called parent.interrupt()")

        for i, c in enumerate(_snapshot_children(parent)):
            print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
        print(f"Global is_interrupted: {is_interrupted()}")

        agent_thread.join(timeout=JOIN_TIMEOUT)
        elapsed = time.monotonic() - t0
        # join() with a timeout returns silently; check the thread really ended.
        if agent_thread.is_alive():
            print(f"❌ FAIL: Agent thread still running after {elapsed:.2f}s")
            return 1
        print(f"Agent thread finished in {elapsed:.2f}s")

        result = result_holder[0]
        if not result:
            print("❌ FAIL: No result!")
            return 1

        print(f"Status: {result['status']}")
        print(f"Duration: {result['duration_seconds']}s")
        if elapsed >= PASS_THRESHOLD_SECONDS:
            print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
            return 1

        print("✅ PASS: Interrupt detected quickly!")
        return 0
    finally:
        # Always clear the process-global interrupt flag, even on an unexpected
        # exception, so it does not leak into later work in this process.
        set_interrupt(False)


if __name__ == "__main__":
    sys.exit(main())