/ tests / gateway / test_session_boundary_hooks.py
test_session_boundary_hooks.py
  1  """Tests that on_session_finalize and on_session_reset plugin hooks fire in the gateway."""
  2  from datetime import datetime
  3  from types import SimpleNamespace
  4  from unittest.mock import AsyncMock, MagicMock, patch
  5  
  6  import pytest
  7  
  8  from gateway.config import GatewayConfig, Platform, PlatformConfig
  9  from gateway.platforms.base import MessageEvent
 10  from gateway.session import SessionEntry, SessionSource, build_session_key
 11  
 12  
 13  def _make_source() -> SessionSource:
 14      return SessionSource(
 15          platform=Platform.TELEGRAM,
 16          user_id="u1",
 17          chat_id="c1",
 18          user_name="tester",
 19          chat_type="dm",
 20      )
 21  
 22  
 23  def _make_event(text: str) -> MessageEvent:
 24      return MessageEvent(text=text, source=_make_source(), message_id="m1")
 25  
 26  
 27  def _make_runner():
 28      from gateway.run import GatewayRunner
 29  
 30      runner = object.__new__(GatewayRunner)
 31      runner.config = GatewayConfig(
 32          platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
 33      )
 34      adapter = MagicMock()
 35      adapter.send = AsyncMock()
 36      runner.adapters = {Platform.TELEGRAM: adapter}
 37      runner._voice_mode = {}
 38      runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
 39      runner._session_model_overrides = {}
 40      runner._pending_model_notes = {}
 41      runner._background_tasks = set()
 42  
 43      session_key = build_session_key(_make_source())
 44      session_entry = SessionEntry(
 45          session_key=session_key,
 46          session_id="sess-old",
 47          created_at=datetime.now(),
 48          updated_at=datetime.now(),
 49          platform=Platform.TELEGRAM,
 50          chat_type="dm",
 51      )
 52      new_session_entry = SessionEntry(
 53          session_key=session_key,
 54          session_id="sess-new",
 55          created_at=datetime.now(),
 56          updated_at=datetime.now(),
 57          platform=Platform.TELEGRAM,
 58          chat_type="dm",
 59      )
 60      runner.session_store = MagicMock()
 61      runner.session_store.get_or_create_session.return_value = new_session_entry
 62      runner.session_store.reset_session.return_value = new_session_entry
 63      runner.session_store._entries = {session_key: session_entry}
 64      runner.session_store._generate_session_key.return_value = session_key
 65      runner._running_agents = {}
 66      runner._pending_messages = {}
 67      runner._pending_approvals = {}
 68      runner._session_db = None
 69      runner._agent_cache_lock = None
 70      runner._is_user_authorized = lambda _source: True
 71      runner._format_session_info = lambda: ""
 72  
 73      return runner
 74  
 75  
 76  @pytest.mark.asyncio
 77  @patch("hermes_cli.plugins.invoke_hook")
 78  async def test_reset_fires_finalize_hook(mock_invoke_hook):
 79      """/new must fire on_session_finalize with the OLD session id."""
 80      runner = _make_runner()
 81  
 82      await runner._handle_reset_command(_make_event("/new"))
 83  
 84      mock_invoke_hook.assert_any_call(
 85          "on_session_finalize", session_id="sess-old", platform="telegram"
 86      )
 87  
 88  
 89  @pytest.mark.asyncio
 90  @patch("hermes_cli.plugins.invoke_hook")
 91  async def test_reset_fires_reset_hook(mock_invoke_hook):
 92      """/new must fire on_session_reset with the NEW session id."""
 93      runner = _make_runner()
 94  
 95      await runner._handle_reset_command(_make_event("/new"))
 96  
 97      mock_invoke_hook.assert_any_call(
 98          "on_session_reset", session_id="sess-new", platform="telegram"
 99      )
100  
101  
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_finalize_before_reset(mock_invoke_hook):
    """on_session_finalize must fire before on_session_reset during /new.

    Only the two session-boundary hooks are compared; any other
    ``invoke_hook`` traffic is filtered out.  Calls made without positional
    arguments are skipped rather than indexed blindly, so an unrelated
    keyword-only call cannot turn this into an ``IndexError``.
    """
    runner = _make_runner()

    await runner._handle_reset_command(_make_event("/new"))

    boundary_hooks = ("on_session_finalize", "on_session_reset")
    hook_names = [
        c.args[0]
        for c in mock_invoke_hook.call_args_list
        if c.args and c.args[0] in boundary_hooks
    ]
    assert hook_names == ["on_session_finalize", "on_session_reset"]
114  
115  
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
    """Gateway stop() must fire on_session_finalize for each active agent.

    The runner is built via ``object.__new__`` with only the attributes
    ``stop()`` is expected to read; two running agents (``sess-a``,
    ``sess-b``) must each produce a finalize call.
    """
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    runner._running = True
    runner._background_tasks = set()
    runner._pending_messages = {}
    runner._pending_approvals = {}
    runner._shutdown_event = MagicMock()
    runner.adapters = {}
    runner._exit_reason = "test"
    runner._exit_code = None
    runner._draining = False
    runner._restart_requested = False
    runner._restart_task_started = False
    runner._restart_detached = False
    runner._restart_via_service = False
    runner._restart_drain_timeout = 0.0
    runner._stop_task = None
    runner._running_agents_ts = {}
    runner._update_runtime_status = MagicMock()

    agent1 = MagicMock()
    agent1.session_id = "sess-a"
    agent2 = MagicMock()
    agent2.session_id = "sess-b"
    runner._running_agents = {"key-a": agent1, "key-b": agent2}

    # Keep stop() from touching real pid/status files.
    with patch("gateway.status.remove_pid_file"), \
         patch("gateway.status.write_runtime_status"):
        await runner.stop()

    # Guarded lookups (same style as the idle-expiry test): a keyword-only
    # call or a finalize call missing ``session_id`` yields a readable set
    # mismatch instead of an IndexError/KeyError.
    finalize_calls = [
        c for c in mock_invoke_hook.call_args_list
        if c.args and c.args[0] == "on_session_finalize"
    ]
    session_ids = {c.kwargs.get("session_id") for c in finalize_calls}
    assert session_ids == {"sess-a", "sess-b"}
157  
158  
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook", side_effect=Exception("boom"))
async def test_hook_error_does_not_break_reset(mock_invoke_hook):
    """Plugin hook errors must not prevent /new from completing.

    Every ``invoke_hook`` call raises.  The test first confirms the hook was
    actually reached (otherwise the test would pass vacuously), then checks
    that /new still produced its success reply.
    """
    runner = _make_runner()

    result = await runner._handle_reset_command(_make_event("/new"))

    assert mock_invoke_hook.called, (
        "invoke_hook was never called; the error-isolation path was not exercised"
    )
    # Should still return a success message despite hook errors
    assert result is not None, "/new returned no reply after a hook failure"
    assert "Session reset" in result or "New session" in result
169  
170  
@pytest.mark.asyncio
@patch("hermes_cli.plugins.invoke_hook")
async def test_idle_expiry_fires_finalize_hook(mock_invoke_hook):
    """Regression test for #14981.

    When ``_session_expiry_watcher`` sweeps a session that has aged past
    its reset policy (idle timeout, scheduled reset), it must fire
    ``on_session_finalize`` so plugin providers get the same final-pass
    extraction opportunity they'd get from /new or CLI shutdown.  Before
    the fix, the expiry path evicted the agent but silently skipped the
    hook.

    The watcher loop is bounded by a sleep-count safety valve, so a
    regression shows up as an assertion failure rather than a hung test.
    """
    import asyncio
    from datetime import timedelta

    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    runner._running = True
    runner._running_agents = {}
    runner._agent_cache = {}
    runner._agent_cache_lock = None
    runner._last_session_store_prune_ts = 0.0

    session_key = "agent:main:telegram:dm:42"
    expired_entry = SessionEntry(
        session_key=session_key,
        session_id="sess-expired",
        created_at=datetime.now() - timedelta(hours=2),
        updated_at=datetime.now() - timedelta(hours=2),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    expired_entry.expiry_finalized = False

    runner.session_store = MagicMock()
    runner.session_store._ensure_loaded = MagicMock()
    runner.session_store._entries = {session_key: expired_entry}
    runner.session_store._is_session_expired = MagicMock(return_value=True)
    runner.session_store._lock = MagicMock()
    runner.session_store._lock.__enter__ = MagicMock(return_value=None)
    runner.session_store._lock.__exit__ = MagicMock(return_value=None)
    runner.session_store._save = MagicMock()

    runner._evict_cached_agent = MagicMock()
    runner._cleanup_agent_resources = MagicMock()
    runner._sweep_idle_cached_agents = MagicMock(return_value=0)

    # The watcher starts with `await asyncio.sleep(60)` and loops while
    # `self._running`.  Patch sleep so delays are instant, and make the hook
    # invocation flip `_running` false so the loop exits after one pass.
    #
    # Capture the real sleep BEFORE patching: `gateway.run.asyncio` is
    # presumably the asyncio module itself, in which case the patch below
    # replaces `asyncio.sleep` everywhere while the `with` block is active.
    real_sleep = asyncio.sleep
    max_sleeps = 1000
    sleep_calls = 0

    async def _fast_sleep(_):
        nonlocal sleep_calls
        sleep_calls += 1
        # Safety valve: only the hook flips `_running`, so if the hook never
        # fires (a regression of #14981) the watcher would spin forever.
        # Stop it after a generous number of sleeps so the assertion below
        # reports the failure instead.
        if sleep_calls >= max_sleeps:
            runner._running = False
        await real_sleep(0)

    def _hook_and_stop(*a, **kw):
        runner._running = False
        return None

    mock_invoke_hook.side_effect = _hook_and_stop

    with patch("gateway.run.asyncio.sleep", side_effect=_fast_sleep):
        await runner._session_expiry_watcher(interval=0)

    # Look for the finalize call targeting the expired session.
    finalize_calls = [
        c for c in mock_invoke_hook.call_args_list
        if c.args and c.args[0] == "on_session_finalize"
    ]
    session_ids = {c.kwargs.get("session_id") for c in finalize_calls}
    assert "sess-expired" in session_ids, (
        f"on_session_finalize was not fired during idle expiry; "
        f"got session_ids={session_ids} (regression of #14981)"
    )