/ tests / gateway / test_pre_gateway_dispatch.py
test_pre_gateway_dispatch.py
  1  """Tests for the pre_gateway_dispatch plugin hook.
  2  
  3  The hook allows plugins to intercept incoming messages before auth and
  4  agent dispatch. It runs in _handle_message and acts on returned action
  5  dicts: {"action": "skip"|"rewrite"|"allow"}.
  6  """
  7  
  8  from types import SimpleNamespace
  9  from unittest.mock import AsyncMock, MagicMock
 10  
 11  import pytest
 12  
 13  from gateway.config import GatewayConfig, Platform, PlatformConfig
 14  from gateway.platforms.base import MessageEvent
 15  from gateway.session import SessionSource
 16  
 17  
 18  def _clear_auth_env(monkeypatch) -> None:
 19      for key in (
 20          "TELEGRAM_ALLOWED_USERS",
 21          "WHATSAPP_ALLOWED_USERS",
 22          "GATEWAY_ALLOWED_USERS",
 23          "TELEGRAM_ALLOW_ALL_USERS",
 24          "WHATSAPP_ALLOW_ALL_USERS",
 25          "GATEWAY_ALLOW_ALL_USERS",
 26      ):
 27          monkeypatch.delenv(key, raising=False)
 28  
 29  
 30  def _make_event(text: str = "hello", platform: Platform = Platform.WHATSAPP) -> MessageEvent:
 31      return MessageEvent(
 32          text=text,
 33          message_id="m1",
 34          source=SessionSource(
 35              platform=platform,
 36              user_id="15551234567@s.whatsapp.net",
 37              chat_id="15551234567@s.whatsapp.net",
 38              user_name="tester",
 39              chat_type="dm",
 40          ),
 41      )
 42  
 43  
 44  def _make_runner(platform: Platform):
 45      from gateway.run import GatewayRunner
 46  
 47      config = GatewayConfig(
 48          platforms={platform: PlatformConfig(enabled=True)},
 49      )
 50      runner = object.__new__(GatewayRunner)
 51      runner.config = config
 52      adapter = SimpleNamespace(send=AsyncMock())
 53      runner.adapters = {platform: adapter}
 54      runner.pairing_store = MagicMock()
 55      runner.pairing_store.is_approved.return_value = False
 56      runner.pairing_store._is_rate_limited.return_value = False
 57      runner.session_store = MagicMock()
 58      runner._running_agents = {}
 59      runner._update_prompt_pending = {}
 60      return runner, adapter
 61  
 62  
 63  @pytest.mark.asyncio
 64  async def test_hook_skip_short_circuits_dispatch(monkeypatch):
 65      """A plugin returning {'action': 'skip'} drops the message before auth."""
 66      _clear_auth_env(monkeypatch)
 67  
 68      def _fake_hook(name, **kwargs):
 69          if name == "pre_gateway_dispatch":
 70              return [{"action": "skip", "reason": "plugin-handled"}]
 71          return []
 72  
 73      monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_hook)
 74  
 75      runner, adapter = _make_runner(Platform.WHATSAPP)
 76  
 77      result = await runner._handle_message(_make_event("hi"))
 78  
 79      assert result is None
 80      adapter.send.assert_not_awaited()
 81      runner.pairing_store.generate_code.assert_not_called()
 82  
 83  
 84  @pytest.mark.asyncio
 85  async def test_hook_rewrite_replaces_event_text(monkeypatch):
 86      """A plugin returning {'action': 'rewrite', 'text': ...} mutates event.text."""
 87      _clear_auth_env(monkeypatch)
 88      monkeypatch.setenv("WHATSAPP_ALLOWED_USERS", "*")
 89  
 90      seen_text = {}
 91  
 92      def _fake_hook(name, **kwargs):
 93          if name == "pre_gateway_dispatch":
 94              return [{"action": "rewrite", "text": "REWRITTEN"}]
 95          return []
 96  
 97      async def _capture(event, source, _quick_key, _run_generation):
 98          seen_text["value"] = event.text
 99          return "ok"
100  
101      monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_hook)
102  
103      runner, _adapter = _make_runner(Platform.WHATSAPP)
104      runner._handle_message_with_agent = _capture  # noqa: SLF001
105  
106      await runner._handle_message(_make_event("original"))
107  
108      assert seen_text.get("value") == "REWRITTEN"
109  
110  
@pytest.mark.asyncio
async def test_hook_allow_falls_through_to_auth(monkeypatch):
    """An ``allow`` verdict hands the message on to the normal auth chain.

    With no allow-list configured, auth fails and the pairing flow runs. That
    is observable as exactly one ``generate_code`` call on the pairing store.
    """
    # _clear_auth_env leaves WHATSAPP_ALLOWED_USERS (and every other
    # allow-list) unset, so auth fails and pairing kicks in.
    _clear_auth_env(monkeypatch)

    def _fake_hook(name, **_kwargs):
        return [{"action": "allow"}] if name == "pre_gateway_dispatch" else []

    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_hook)

    runner, _adapter = _make_runner(Platform.WHATSAPP)
    runner.pairing_store.generate_code.return_value = "12345"

    outcome = await runner._handle_message(_make_event("hi"))

    # Auth ran and fell back to pairing: nothing returned, one code generated.
    assert outcome is None
    runner.pairing_store.generate_code.assert_called_once()
133  
134  
@pytest.mark.asyncio
async def test_hook_exception_does_not_break_dispatch(monkeypatch):
    """A raising pre_gateway_dispatch hook is ignored and dispatch continues.

    The plugin's exception must not propagate out of ``_handle_message``. It
    must also not be treated as an implicit ``skip``: the message still reaches
    the auth chain, which (with no allow-list configured) starts the pairing
    flow exactly as it does for an ``allow`` verdict.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.delenv("WHATSAPP_ALLOWED_USERS", raising=False)

    def _fake_hook(name, **kwargs):
        # Only the hook under test blows up, mirroring the other tests. Any
        # other hook the gateway may invoke gets the normal "no results" reply,
        # so a failure here can only come from pre_gateway_dispatch handling.
        if name == "pre_gateway_dispatch":
            raise RuntimeError("plugin blew up")
        return []

    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_hook)

    runner, _adapter = _make_runner(Platform.WHATSAPP)
    runner.pairing_store.generate_code.return_value = None

    # Should not raise; falls through to auth chain.
    result = await runner._handle_message(_make_event("hi"))
    assert result is None
    # Prove fall-through rather than an early return: the pairing flow was
    # reached, same as in test_hook_allow_falls_through_to_auth.
    runner.pairing_store.generate_code.assert_called_once()
152  
153  
@pytest.mark.asyncio
async def test_internal_events_bypass_hook(monkeypatch):
    """Internal events (event.internal=True) skip the plugin hook entirely.

    The hook is rigged to answer ``skip``. An internal event must not consult
    it, and therefore must not be dropped by it: it has to reach the (stubbed)
    agent handler.
    """
    _clear_auth_env(monkeypatch)
    monkeypatch.setenv("WHATSAPP_ALLOWED_USERS", "*")

    called = {"count": 0}
    dispatched = {"count": 0}

    def _fake_hook(name, **kwargs):
        # Count only the hook under test so that unrelated hook invocations
        # cannot fail this test for the wrong reason.
        if name == "pre_gateway_dispatch":
            called["count"] += 1
            return [{"action": "skip"}]
        return []

    async def _capture(event, source, _quick_key, _run_generation):
        dispatched["count"] += 1
        return "ok"

    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_hook)

    runner, _adapter = _make_runner(Platform.WHATSAPP)
    runner._handle_message_with_agent = _capture  # noqa: SLF001

    event = _make_event("hi")
    event.internal = True

    # Even though the hook would say skip, internal events bypass it.
    await runner._handle_message(event)
    assert called["count"] == 0
    # ...and the event was actually dispatched rather than silently dropped.
    assert dispatched["count"] == 1