# tests/gateway/test_interrupt_key_match.py
"""Tests verifying interrupt key consistency between adapter and gateway.

Regression test for a bug where monitor_for_interrupt() in _run_agent used
source.chat_id to query the adapter, while the adapter stores interrupts under
the full session key (the output of build_session_key()). Because of this
mismatch, interrupts were never detected, so subagents ignored new messages.
"""
  8  
  9  import asyncio
 10  
 11  import pytest
 12  
 13  from gateway.config import Platform, PlatformConfig
 14  from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType, SendResult
 15  from gateway.session import SessionSource, build_session_key
 16  
 17  
 18  class StubAdapter(BasePlatformAdapter):
 19      """Minimal adapter for interrupt tests."""
 20  
 21      def __init__(self):
 22          super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
 23  
 24      async def connect(self):
 25          return True
 26  
 27      async def disconnect(self):
 28          pass
 29  
 30      async def send(self, chat_id, content, reply_to=None, metadata=None):
 31          return SendResult(success=True, message_id="1")
 32  
 33      async def send_typing(self, chat_id, metadata=None):
 34          pass
 35  
 36      async def get_chat_info(self, chat_id):
 37          return {"id": chat_id}
 38  
 39  
 40  def _source(chat_id="123456", chat_type="dm", thread_id=None):
 41      return SessionSource(
 42          platform=Platform.TELEGRAM,
 43          chat_id=chat_id,
 44          chat_type=chat_type,
 45          thread_id=thread_id,
 46      )
 47  
 48  
 49  class TestInterruptKeyConsistency:
 50      """Ensure adapter interrupt methods are queried with session_key, not chat_id."""
 51  
 52      def test_session_key_differs_from_chat_id_for_dm(self):
 53          """Session key for a DM is namespaced and includes the DM chat_id."""
 54          source = _source("123456", "dm")
 55          session_key = build_session_key(source)
 56          assert session_key != source.chat_id
 57          assert session_key == "agent:main:telegram:dm:123456"
 58  
 59      def test_session_key_differs_from_chat_id_for_group(self):
 60          """Session key for a group chat includes prefix, unlike raw chat_id."""
 61          source = _source("-1001234", "group")
 62          session_key = build_session_key(source)
 63          assert session_key != source.chat_id
 64          assert "agent:main:" in session_key
 65          assert source.chat_id in session_key
 66  
 67      @pytest.mark.asyncio
 68      async def test_has_pending_interrupt_requires_session_key(self):
 69          """has_pending_interrupt returns True only when queried with session_key."""
 70          adapter = StubAdapter()
 71          source = _source("123456", "dm")
 72          session_key = build_session_key(source)
 73  
 74          # Simulate adapter storing interrupt under session_key
 75          interrupt_event = asyncio.Event()
 76          adapter._active_sessions[session_key] = interrupt_event
 77          interrupt_event.set()
 78  
 79          # Using session_key → found
 80          assert adapter.has_pending_interrupt(session_key) is True
 81  
 82          # Using chat_id → NOT found (this was the bug)
 83          assert adapter.has_pending_interrupt(source.chat_id) is False
 84  
 85      @pytest.mark.asyncio
 86      async def test_get_pending_message_requires_session_key(self):
 87          """get_pending_message returns the event only with session_key."""
 88          adapter = StubAdapter()
 89          source = _source("123456", "dm")
 90          session_key = build_session_key(source)
 91  
 92          event = MessageEvent(text="hello", source=source, message_id="42")
 93          adapter._pending_messages[session_key] = event
 94  
 95          # Using chat_id → None (the bug)
 96          assert adapter.get_pending_message(source.chat_id) is None
 97  
 98          # Using session_key → found
 99          result = adapter.get_pending_message(session_key)
100          assert result is event
101  
102      @pytest.mark.asyncio
103      async def test_handle_message_stores_under_session_key(self):
104          """handle_message stores pending messages under session_key, not chat_id."""
105          adapter = StubAdapter()
106          adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
107  
108          source = _source("-1001234", "group")
109          session_key = build_session_key(source)
110  
111          # Mark session as active
112          adapter._active_sessions[session_key] = asyncio.Event()
113  
114          # Send a second message while session is active
115          event = MessageEvent(text="interrupt!", source=source, message_id="2")
116          await adapter.handle_message(event)
117  
118          # Stored under session_key
119          assert session_key in adapter._pending_messages
120          # NOT stored under chat_id
121          assert source.chat_id not in adapter._pending_messages
122  
123          # Interrupt event was set
124          assert adapter._active_sessions[session_key].is_set()
125  
126      @pytest.mark.asyncio
127      async def test_photo_followup_is_queued_without_interrupt(self):
128          """Photo follow-ups should queue behind the active run instead of interrupting it."""
129          adapter = StubAdapter()
130          adapter.set_message_handler(lambda event: asyncio.sleep(0, result=None))
131  
132          source = _source("-1001234", "group")
133          session_key = build_session_key(source)
134          interrupt_event = asyncio.Event()
135          adapter._active_sessions[session_key] = interrupt_event
136  
137          event = MessageEvent(
138              text="caption",
139              source=source,
140              message_type=MessageType.PHOTO,
141              message_id="2",
142              media_urls=["/tmp/photo-a.jpg"],
143              media_types=["image/jpeg"],
144          )
145          await adapter.handle_message(event)
146  
147          queued = adapter._pending_messages[session_key]
148          assert queued is event
149          assert queued.media_urls == ["/tmp/photo-a.jpg"]
150          assert interrupt_event.is_set() is False