/ tests / gateway / test_cancel_background_drain.py
test_cancel_background_drain.py
  1  """Regression test: cancel_background_tasks must drain late-arrival tasks.
  2  
  3  During gateway shutdown, a message arriving while
  4  cancel_background_tasks is mid-await can spawn a fresh
  5  _process_message_background task via handle_message, which is added
  6  to self._background_tasks.  Without the re-drain loop, the subsequent
  7  _background_tasks.clear() drops the reference; the task runs
  8  untracked against a disconnecting adapter.
  9  """
 10  
 11  import asyncio
 12  from unittest.mock import AsyncMock
 13  
 14  import pytest
 15  
 16  from gateway.config import Platform, PlatformConfig
 17  from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
 18  from gateway.session import SessionSource, build_session_key
 19  
 20  
 21  class _StubAdapter(BasePlatformAdapter):
 22      async def connect(self):
 23          pass
 24  
 25      async def disconnect(self):
 26          pass
 27  
 28      async def send(self, chat_id, text, **kwargs):
 29          return None
 30  
 31      async def get_chat_info(self, chat_id):
 32          return {}
 33  
 34  
 35  def _make_adapter():
 36      adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
 37      adapter._send_with_retry = AsyncMock(return_value=None)
 38      return adapter
 39  
 40  
 41  def _event(text, cid="42"):
 42      return MessageEvent(
 43          text=text,
 44          message_type=MessageType.TEXT,
 45          source=SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm"),
 46      )
 47  
 48  
 49  @pytest.mark.asyncio
 50  async def test_cancel_background_tasks_drains_late_arrivals():
 51      """A message that arrives during the gather window must be picked
 52      up by the re-drain loop, not leaked as an untracked task."""
 53      adapter = _make_adapter()
 54      sk = build_session_key(
 55          SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
 56      )
 57  
 58      m1_started = asyncio.Event()
 59      m1_cleanup_running = asyncio.Event()
 60      m2_started = asyncio.Event()
 61      m2_cancelled = asyncio.Event()
 62  
 63      async def handler(event):
 64          if event.text == "M1":
 65              m1_started.set()
 66              try:
 67                  await asyncio.sleep(10)
 68              except asyncio.CancelledError:
 69                  m1_cleanup_running.set()
 70                  # Widen the gather window with a shielded cleanup
 71                  # delay so M2 can get injected during it.
 72                  await asyncio.shield(asyncio.sleep(0.2))
 73                  raise
 74          else:  # M2 — the late arrival
 75              m2_started.set()
 76              try:
 77                  await asyncio.sleep(10)
 78              except asyncio.CancelledError:
 79                  m2_cancelled.set()
 80                  raise
 81  
 82      adapter._message_handler = handler
 83  
 84      # Spawn M1.
 85      await adapter.handle_message(_event("M1"))
 86      await asyncio.wait_for(m1_started.wait(), timeout=1.0)
 87  
 88      # Kick off shutdown.  This will cancel M1 and await its cleanup.
 89      cancel_task = asyncio.create_task(adapter.cancel_background_tasks())
 90  
 91      # Wait until M1's cleanup is running (inside the shielded sleep).
 92      # This is the race window: cancel_task is awaiting gather, M1 is
 93      # shielded in cleanup, the _active_sessions entry has been cleared
 94      # by M1's own finally.
 95      await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)
 96  
 97      # Clear the active-session entry (M1's finally hasn't fully run yet,
 98      # but in production the platform dispatcher would deliver a new
 99      # message that takes the no-active-session spawn path).  For this
100      # repro, make it deterministic.
101      adapter._active_sessions.pop(sk, None)
102  
103      # Inject late arrival — spawns a fresh _process_message_background
104      # task and adds it to _background_tasks while cancel_task is still
105      # in gather.
106      await adapter.handle_message(_event("M2"))
107      await asyncio.wait_for(m2_started.wait(), timeout=1.0)
108  
109      # Let cancel_task finish.  Round 1's gather completes when M1's
110      # shielded cleanup finishes.  Round 2 should pick up M2.
111      await asyncio.wait_for(cancel_task, timeout=5.0)
112  
113      # Assert M2 was drained, not leaked.
114      assert m2_cancelled.is_set(), (
115          "Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
116          "the re-drain loop is missing and the task leaked"
117      )
118      assert adapter._background_tasks == set()
119  
120  
@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
    """Regression guard: with nothing tracked, cancelling neither hangs nor raises."""
    stub = _make_adapter()
    await stub.cancel_background_tasks()
    remaining = stub._background_tasks
    assert remaining == set()
127  
128  
@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
    """Regression guard: the drain loop is bounded — it does not spin
    forever even if late-arrival tasks keep getting spawned."""
    adapter = _make_adapter()

    # Baseline: a single well-behaved task that cancels cleanly must be
    # drained and the tracking set left empty.
    async def quick():
        await asyncio.sleep(10)

    task = asyncio.create_task(quick())
    adapter._background_tasks.add(task)

    await adapter.cancel_background_tasks()
    assert task.done()
    assert adapter._background_tasks == set()

    # Pathological case: every cancelled task spawns a replacement into
    # _background_tasks from its cancellation handler, so each drain round
    # always finds a new late arrival.  An unbounded loop would never
    # return; a bounded one must finish well within the timeout.
    spawned = []
    stop_respawning = asyncio.Event()

    async def respawner():
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            if not stop_respawning.is_set():
                _spawn()
            raise

    def _spawn():
        # Create one respawner task and track it like a background task.
        t = asyncio.create_task(respawner())
        spawned.append(t)
        adapter._background_tasks.add(t)

    _spawn()
    try:
        await asyncio.wait_for(adapter.cancel_background_tasks(), timeout=5.0)
    finally:
        # Break the respawn chain, then reap the task the bounded loop
        # necessarily leaves behind so nothing outlives the test.
        stop_respawning.set()
        pending = [t for t in spawned if not t.done()]
        for t in pending:
            t.cancel()
        await asyncio.gather(*pending, return_exceptions=True)