test_cancel_background_drain.py
"""Regression test: cancel_background_tasks must drain late-arrival tasks.

During gateway shutdown, a message arriving while
cancel_background_tasks is mid-await can spawn a fresh
_process_message_background task via handle_message, which is added
to self._background_tasks. Without the re-drain loop, the subsequent
_background_tasks.clear() drops the reference; the task runs
untracked against a disconnecting adapter.
"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType
from gateway.session import SessionSource, build_session_key


class _StubAdapter(BasePlatformAdapter):
    """Minimal concrete adapter whose transport methods all do nothing."""

    async def connect(self):
        pass

    async def disconnect(self):
        pass

    async def send(self, chat_id, text, **kwargs):
        return None

    async def get_chat_info(self, chat_id):
        return {}


def _make_adapter():
    """Build a Telegram stub adapter with ``_send_with_retry`` mocked out."""
    adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM)
    adapter._send_with_retry = AsyncMock(return_value=None)
    return adapter


def _event(text, cid="42"):
    """Build a plain-text Telegram DM ``MessageEvent`` for chat *cid*."""
    return MessageEvent(
        text=text,
        message_type=MessageType.TEXT,
        source=SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm"),
    )


@pytest.mark.asyncio
async def test_cancel_background_tasks_drains_late_arrivals():
    """A message that arrives during the gather window must be picked
    up by the re-drain loop, not leaked as an untracked task."""
    adapter = _make_adapter()
    sk = build_session_key(
        SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm")
    )

    m1_started = asyncio.Event()
    m1_cleanup_running = asyncio.Event()
    m2_started = asyncio.Event()
    m2_cancelled = asyncio.Event()

    async def handler(event):
        if event.text == "M1":
            m1_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m1_cleanup_running.set()
                # Widen the gather window with a shielded cleanup
                # delay so M2 can get injected during it.
                await asyncio.shield(asyncio.sleep(0.2))
                raise
        else:  # M2 — the late arrival
            m2_started.set()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                m2_cancelled.set()
                raise

    adapter._message_handler = handler

    # Spawn M1.
    await adapter.handle_message(_event("M1"))
    await asyncio.wait_for(m1_started.wait(), timeout=1.0)

    # Kick off shutdown. This will cancel M1 and await its cleanup.
    cancel_task = asyncio.create_task(adapter.cancel_background_tasks())

    # Wait until M1's cleanup is running (inside the shielded sleep).
    # This is the race window: cancel_task is awaiting gather and M1 is
    # still shielded in its cleanup, so M1's finally (which clears its
    # _active_sessions entry) has not run yet.
    await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0)

    # In production, once that entry is gone, the platform dispatcher can
    # deliver a new message that takes the no-active-session spawn path.
    # Drop the entry by hand so this repro is deterministic.
    adapter._active_sessions.pop(sk, None)

    # Inject late arrival — spawns a fresh _process_message_background
    # task and adds it to _background_tasks while cancel_task is still
    # in gather.
    await adapter.handle_message(_event("M2"))
    await asyncio.wait_for(m2_started.wait(), timeout=1.0)

    # Let cancel_task finish. Round 1's gather completes when M1's
    # shielded cleanup finishes. Round 2 should pick up M2.
    await asyncio.wait_for(cancel_task, timeout=5.0)

    # Assert M2 was drained, not leaked.
    assert m2_cancelled.is_set(), (
        "Late-arrival M2 was NOT cancelled by cancel_background_tasks — "
        "the re-drain loop is missing and the task leaked"
    )
    assert adapter._background_tasks == set()


@pytest.mark.asyncio
async def test_cancel_background_tasks_handles_no_tasks():
    """Regression guard: no tasks, no hang, no error."""
    adapter = _make_adapter()
    await adapter.cancel_background_tasks()
    assert adapter._background_tasks == set()


@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_rounds():
    """Baseline: a single well-behaved task is cancelled, awaited, and
    removed, and the drain loop terminates."""
    adapter = _make_adapter()

    async def quick():
        # Cancellation propagates straight out of the sleep.
        await asyncio.sleep(10)

    task = asyncio.create_task(quick())
    adapter._background_tasks.add(task)

    await adapter.cancel_background_tasks()
    assert task.done()
    assert adapter._background_tasks == set()


@pytest.mark.asyncio
async def test_cancel_background_tasks_bounded_under_respawn():
    """Regression guard: the drain loop is bounded — it does not spin
    forever even if every cancelled task spawns a fresh late arrival
    into _background_tasks."""
    adapter = _make_adapter()
    spawned = []
    respawn = True  # flipped off before our own cleanup cancels leftovers

    async def stubborn():
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            # Each cancellation injects a replacement, mimicking an
            # endless stream of late arrivals during shutdown.
            if respawn:
                spawn()
            raise

    def spawn():
        task = asyncio.create_task(stubborn())
        spawned.append(task)
        adapter._background_tasks.add(task)

    spawn()
    try:
        # An unbounded re-drain loop would chase replacements forever;
        # the timeout turns that hang into a test failure.
        await asyncio.wait_for(adapter.cancel_background_tasks(), timeout=5.0)
    finally:
        # The last replacement may have been dropped untracked once the
        # round limit was hit; stop respawning and reap every task here
        # so nothing outlives the test.
        respawn = False
        for task in spawned:
            task.cancel()
        await asyncio.gather(*spawned, return_exceptions=True)

    # The scenario was actually exercised: the first task was cancelled
    # and produced at least one late-arrival replacement.
    assert len(spawned) >= 2