/ tests / cli / test_cprint_bg_thread.py
test_cprint_bg_thread.py
  1  """Tests for cli._cprint's bg-thread cooperation with prompt_toolkit.
  2  
  3  Background: when a prompt_toolkit Application is running, a bg thread that
  4  calls ``_pt_print`` directly can race with the input-area redraw and the
  5  printed line can end up visually buried behind the prompt.  ``_cprint`` now
  6  routes cross-thread prints through ``run_in_terminal`` via
  7  ``loop.call_soon_threadsafe`` so the self-improvement background review's
  8  ``💾 Self-improvement review: …`` summary actually surfaces to the user.
  9  
 10  These tests verify the routing logic without spinning up a real PT app.
 11  """
 12  
 13  from __future__ import annotations
 14  
 15  import sys
 16  import types
 17  from types import SimpleNamespace
 18  
 19  import cli
 20  
 21  
 22  def test_cprint_no_app_direct_print(monkeypatch):
 23      """No active app → direct _pt_print, no run_in_terminal involvement."""
 24      calls = []
 25      monkeypatch.setattr(cli, "_pt_print", lambda x: calls.append(("pt_print", x)))
 26      monkeypatch.setattr(cli, "_PT_ANSI", lambda t: ("ANSI", t))
 27  
 28      # Patch the prompt_toolkit import the function performs internally.
 29      fake_pt_app = types.ModuleType("prompt_toolkit.application")
 30      fake_pt_app.get_app_or_none = lambda: None
 31      fake_pt_app.run_in_terminal = lambda *a, **kw: calls.append(("run_in_terminal",))
 32      monkeypatch.setitem(sys.modules, "prompt_toolkit.application", fake_pt_app)
 33  
 34      cli._cprint("hello")
 35  
 36      assert calls == [("pt_print", ("ANSI", "hello"))]
 37  
 38  
 39  def test_cprint_app_not_running_direct_print(monkeypatch):
 40      """App exists but not running (e.g. teardown) → direct print."""
 41      calls = []
 42      monkeypatch.setattr(cli, "_pt_print", lambda x: calls.append(("pt_print", x)))
 43      monkeypatch.setattr(cli, "_PT_ANSI", lambda t: t)
 44  
 45      fake_app = SimpleNamespace(_is_running=False, loop=None)
 46      fake_pt_app = types.ModuleType("prompt_toolkit.application")
 47      fake_pt_app.get_app_or_none = lambda: fake_app
 48      fake_pt_app.run_in_terminal = lambda *a, **kw: calls.append(("run_in_terminal",))
 49      monkeypatch.setitem(sys.modules, "prompt_toolkit.application", fake_pt_app)
 50  
 51      cli._cprint("x")
 52  
 53      assert calls == [("pt_print", "x")]
 54  
 55  
 56  def test_cprint_bg_thread_schedules_on_app_loop(monkeypatch):
 57      """App running + different thread → schedules via call_soon_threadsafe."""
 58      scheduled = []
 59      direct_prints = []
 60  
 61      monkeypatch.setattr(cli, "_pt_print", lambda x: direct_prints.append(x))
 62      monkeypatch.setattr(cli, "_PT_ANSI", lambda t: t)
 63  
 64      class FakeLoop:
 65          def is_running(self):
 66              return True
 67  
 68          def call_soon_threadsafe(self, cb, *args):
 69              scheduled.append(cb)
 70  
 71      fake_loop = FakeLoop()
 72  
 73      # Install a fake "current loop" that is NOT the app's loop, so the
 74      # cross-thread branch is taken.
 75      fake_current_loop = SimpleNamespace(is_running=lambda: True)
 76      fake_asyncio = types.ModuleType("asyncio")
 77  
 78      class _Policy:
 79          def get_event_loop(self):
 80              return fake_current_loop
 81  
 82      fake_asyncio.get_event_loop_policy = lambda: _Policy()
 83      monkeypatch.setitem(sys.modules, "asyncio", fake_asyncio)
 84  
 85      fake_app = SimpleNamespace(_is_running=True, loop=fake_loop)
 86      fake_pt_app = types.ModuleType("prompt_toolkit.application")
 87      fake_pt_app.get_app_or_none = lambda: fake_app
 88  
 89      run_in_terminal_calls = []
 90  
 91      def _fake_run_in_terminal(func, **kw):
 92          run_in_terminal_calls.append(func)
 93          # Simulate run_in_terminal actually calling func (as the real PT
 94          # impl would once the app loop tick picks it up).
 95          func()
 96          return None
 97  
 98      fake_pt_app.run_in_terminal = _fake_run_in_terminal
 99      monkeypatch.setitem(sys.modules, "prompt_toolkit.application", fake_pt_app)
100  
101      cli._cprint("💾 Self-improvement review: Skill updated")
102  
103      # call_soon_threadsafe must have been called with a scheduling cb.
104      assert len(scheduled) == 1
105  
106      # Invoking the scheduled callback should hit run_in_terminal.
107      scheduled[0]()
108      assert len(run_in_terminal_calls) == 1
109  
110      # And run_in_terminal's inner func should have emitted a pt_print.
111      assert direct_prints == ["💾 Self-improvement review: Skill updated"]
112  
113  
def test_cprint_same_thread_as_app_loop_direct_print(monkeypatch):
    """When the caller is already on the app's loop, print immediately.

    The stubbed asyncio policy reports the app's own loop as the current
    loop, so _cprint must print directly.  Any attempt to schedule onto the
    loop fails the test.
    """
    printed = []

    def capture(formatted):
        printed.append(formatted)

    monkeypatch.setattr(cli, "_PT_ANSI", lambda text: text)
    monkeypatch.setattr(cli, "_pt_print", capture)

    class SameThreadLoop:
        """App loop that fails the test if anything is scheduled on it."""

        def is_running(self):
            return True

        def call_soon_threadsafe(self, callback, *args):
            raise AssertionError(
                "call_soon_threadsafe must not be used on the app's own thread"
            )

    app_loop = SameThreadLoop()

    class _SameLoopPolicy:
        def get_event_loop(self):
            # Deliberately the very same object the app exposes as .loop.
            return app_loop

    asyncio_stub = types.ModuleType("asyncio")
    asyncio_stub.get_event_loop_policy = lambda: _SameLoopPolicy()
    monkeypatch.setitem(sys.modules, "asyncio", asyncio_stub)

    running_app = SimpleNamespace(_is_running=True, loop=app_loop)
    pt_stub = types.ModuleType("prompt_toolkit.application")
    pt_stub.get_app_or_none = lambda: running_app
    pt_stub.run_in_terminal = lambda *_a, **_kw: None
    monkeypatch.setitem(sys.modules, "prompt_toolkit.application", pt_stub)

    cli._cprint("x")

    assert printed == ["x"]
148  
149  
def test_cprint_swallows_app_loop_attr_error(monkeypatch):
    """A failure while reading ``app.loop`` must not crash _cprint.

    The fake app claims to be running, but accessing its ``loop`` property
    raises ``RuntimeError``.  _cprint is expected to fall back to a plain
    direct print.
    """
    printed = []

    def capture(formatted):
        printed.append(formatted)

    monkeypatch.setattr(cli, "_PT_ANSI", lambda text: text)
    monkeypatch.setattr(cli, "_pt_print", capture)

    class LooplessApp:
        _is_running = True

        @property
        def loop(self):
            raise RuntimeError("no loop for you")

    stub = types.ModuleType("prompt_toolkit.application")
    stub.run_in_terminal = lambda *_a, **_kw: None
    stub.get_app_or_none = lambda: LooplessApp()
    monkeypatch.setitem(sys.modules, "prompt_toolkit.application", stub)

    cli._cprint("fallback")

    assert printed == ["fallback"]
171  
172  
def test_cprint_swallows_prompt_toolkit_import_error(monkeypatch):
    """If prompt_toolkit.application itself fails to import, fall back.

    ``_cprint`` must still print the message directly when the lazy
    ``prompt_toolkit.application`` import raises ``ImportError``.
    """
    direct_prints = []
    monkeypatch.setattr(cli, "_pt_print", lambda x: direct_prints.append(x))
    monkeypatch.setattr(cli, "_PT_ANSI", lambda t: t)

    # A None entry in sys.modules is the import system's documented way to
    # block a module.  Importing that name raises ModuleNotFoundError (an
    # ImportError subclass) without consulting any finder.  monkeypatch
    # restores the real entry, or its absence, after the test.  No custom
    # meta-path finder or manual try/finally cleanup is needed.
    monkeypatch.setitem(sys.modules, "prompt_toolkit.application", None)

    cli._cprint("fallback2")

    assert direct_prints == ["fallback2"]