/ tests / run_interrupt_test.py
run_interrupt_test.py
  1  #!/usr/bin/env python3
  2  """Run a real interrupt test with actual AIAgent + delegate child.
  3  
  4  Not a pytest test — runs directly as a script for live testing.
  5  """
  6  
  7  import threading
  8  import time
  9  import sys
 10  import os
 11  
 12  sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 13  
 14  from unittest.mock import MagicMock, patch
 15  from run_agent import AIAgent, IterationBudget
 16  from tools.delegate_tool import _run_single_child
 17  from tools.interrupt import set_interrupt, is_interrupted
 18  
 19  def main() -> int:
 20      set_interrupt(False)
 21  
 22      # Create parent agent (minimal)
 23      parent = AIAgent.__new__(AIAgent)
 24      parent._interrupt_requested = False
 25      parent._interrupt_message = None
 26      parent._active_children = []
 27      parent._active_children_lock = threading.Lock()
 28      parent.quiet_mode = True
 29      parent.model = "test/model"
 30      parent.base_url = "http://localhost:1"
 31      parent.api_key = "test"
 32      parent.provider = "test"
 33      parent.api_mode = "chat_completions"
 34      parent.platform = "cli"
 35      parent.enabled_toolsets = ["terminal", "file"]
 36      parent.providers_allowed = None
 37      parent.providers_ignored = None
 38      parent.providers_order = None
 39      parent.provider_sort = None
 40      parent.max_tokens = None
 41      parent.reasoning_config = None
 42      parent.prefill_messages = None
 43      parent._session_db = None
 44      parent._delegate_depth = 0
 45      parent._delegate_spinner = None
 46      parent.tool_progress_callback = None
 47      parent.iteration_budget = IterationBudget(max_total=100)
 48      parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
 49  
 50      child_started = threading.Event()
 51      result_holder = [None]
 52  
 53      def run_delegate():
 54          with patch("run_agent.OpenAI") as MockOpenAI:
 55              mock_client = MagicMock()
 56  
 57              def slow_create(**kwargs):
 58                  time.sleep(3)
 59                  resp = MagicMock()
 60                  resp.choices = [MagicMock()]
 61                  resp.choices[0].message.content = "Done"
 62                  resp.choices[0].message.tool_calls = None
 63                  resp.choices[0].message.refusal = None
 64                  resp.choices[0].finish_reason = "stop"
 65                  resp.usage.prompt_tokens = 100
 66                  resp.usage.completion_tokens = 10
 67                  resp.usage.total_tokens = 110
 68                  resp.usage.prompt_tokens_details = None
 69                  return resp
 70  
 71              mock_client.chat.completions.create = slow_create
 72              mock_client.close = MagicMock()
 73              MockOpenAI.return_value = mock_client
 74  
 75              original_init = AIAgent.__init__
 76  
 77              def patched_init(self_agent, *a, **kw):
 78                  original_init(self_agent, *a, **kw)
 79                  child_started.set()
 80  
 81              with patch.object(AIAgent, "__init__", patched_init):
 82                  try:
 83                      result = _run_single_child(
 84                          task_index=0,
 85                          goal="Test slow task",
 86                          context=None,
 87                          toolsets=["terminal"],
 88                          model="test/model",
 89                          max_iterations=5,
 90                          parent_agent=parent,
 91                          task_count=1,
 92                          override_provider="test",
 93                          override_base_url="http://localhost:1",
 94                          override_api_key="test",
 95                          override_api_mode="chat_completions",
 96                      )
 97                      result_holder[0] = result
 98                  except Exception as e:
 99                      print(f"ERROR in delegate: {e}")
100                      import traceback
101                      traceback.print_exc()
102  
103      print("Starting agent thread...")
104      agent_thread = threading.Thread(target=run_delegate, daemon=True)
105      agent_thread.start()
106  
107      started = child_started.wait(timeout=10)
108      if not started:
109          print("ERROR: Child never started")
110          set_interrupt(False)
111          return 1
112  
113      time.sleep(0.5)
114  
115      print(f"Active children: {len(parent._active_children)}")
116      for i, c in enumerate(parent._active_children):
117          print(f"  Child {i}: _interrupt_requested={c._interrupt_requested}")
118  
119      t0 = time.monotonic()
120      parent.interrupt("User typed a new message")
121      print("Called parent.interrupt()")
122  
123      for i, c in enumerate(parent._active_children):
124          print(f"  Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
125      print(f"Global is_interrupted: {is_interrupted()}")
126  
127      agent_thread.join(timeout=10)
128      elapsed = time.monotonic() - t0
129      print(f"Agent thread finished in {elapsed:.2f}s")
130  
131      result = result_holder[0]
132      if result:
133          print(f"Status: {result['status']}")
134          print(f"Duration: {result['duration_seconds']}s")
135          if elapsed < 2.0:
136              print("✅ PASS: Interrupt detected quickly!")
137          else:
138              print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
139      else:
140          print("❌ FAIL: No result!")
141  
142      set_interrupt(False)
143      return 0
144  
145  
if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())