/ tests / agent / test_memory_session_switch.py
test_memory_session_switch.py
  1  """Tests for the on_session_switch hook and session_id propagation.
  2  
  3  Covers #6672: memory providers must be notified when AIAgent.session_id
  4  rotates mid-process (via /resume, /branch, /reset, /new, or context
  5  compression). Without the notification, providers that cache per-session
  6  state in initialize() (Hindsight, and any plugin that stores session_id
  7  for scoped writes) keep writing into the old session's record.
  8  """
  9  
 10  import json
 11  
 12  import pytest
 13  
 14  from agent.memory_manager import MemoryManager
 15  from agent.memory_provider import MemoryProvider
 16  
 17  
 18  class _RecordingProvider(MemoryProvider):
 19      """Provider that records every lifecycle call for assertion."""
 20  
 21      def __init__(self, name="rec"):
 22          self._name = name
 23          self.switch_calls: list[dict] = []
 24          self.sync_calls: list[dict] = []
 25          self.queue_calls: list[dict] = []
 26          self.initialize_calls: list[dict] = []
 27  
 28      @property
 29      def name(self) -> str:
 30          return self._name
 31  
 32      def is_available(self) -> bool:  # pragma: no cover - unused
 33          return True
 34  
 35      def initialize(self, session_id, **kwargs):
 36          self.initialize_calls.append({"session_id": session_id, **kwargs})
 37  
 38      def get_tool_schemas(self):
 39          return []
 40  
 41      def sync_turn(self, user_content, assistant_content, *, session_id=""):
 42          self.sync_calls.append(
 43              {"user": user_content, "asst": assistant_content, "session_id": session_id}
 44          )
 45  
 46      def queue_prefetch(self, query, *, session_id=""):
 47          self.queue_calls.append({"query": query, "session_id": session_id})
 48  
 49      def on_session_switch(
 50          self,
 51          new_session_id,
 52          *,
 53          parent_session_id="",
 54          reset=False,
 55          **kwargs,
 56      ):
 57          self.switch_calls.append(
 58              {
 59                  "new": new_session_id,
 60                  "parent": parent_session_id,
 61                  "reset": reset,
 62                  "extra": kwargs,
 63              }
 64          )
 65  
 66  
 67  # ---------------------------------------------------------------------------
 68  # MemoryProvider ABC — default on_session_switch is a no-op
 69  # ---------------------------------------------------------------------------
 70  
 71  
 72  class _MinimalProvider(MemoryProvider):
 73      """Provider that does NOT override on_session_switch — ABC default must no-op."""
 74  
 75      @property
 76      def name(self) -> str:
 77          return "minimal"
 78  
 79      def is_available(self) -> bool:
 80          return True
 81  
 82      def initialize(self, session_id, **kwargs):  # pragma: no cover - unused
 83          pass
 84  
 85      def get_tool_schemas(self):
 86          return []
 87  
 88  
 89  def test_abc_default_on_session_switch_is_noop():
 90      """Providers that don't override the hook must not raise."""
 91      p = _MinimalProvider()
 92      # All three call styles must be accepted without raising
 93      p.on_session_switch("new-id")
 94      p.on_session_switch("new-id", parent_session_id="old-id")
 95      p.on_session_switch("new-id", parent_session_id="old-id", reset=True)
 96      p.on_session_switch("new-id", parent_session_id="old-id", reset=True, reason="new_session")
 97  
 98  
 99  # ---------------------------------------------------------------------------
100  # MemoryManager.on_session_switch — fan-out
101  # ---------------------------------------------------------------------------
102  
103  
def test_manager_fans_out_to_all_providers():
    """Every registered provider receives the same on_session_switch call."""
    manager = MemoryManager()
    # MemoryManager admits only one external provider, so the first
    # recorder occupies the builtin slot.
    providers = [
        _RecordingProvider(name="builtin"),
        _RecordingProvider(name="hindsight"),
    ]
    for provider in providers:
        manager.add_provider(provider)

    manager.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume")

    assert [len(provider.switch_calls) for provider in providers] == [1, 1]
    for provider in providers:
        call = provider.switch_calls[0]
        assert call["new"] == "new-sid"
        assert call["parent"] == "old-sid"
        assert call["reset"] is False
        assert call["extra"] == {"reason": "resume"}
121  
122  
def test_manager_ignores_empty_session_id():
    """A blank session_id ("" or None) must not reach any provider hook.

    Guards against spurious switches during shutdown, when session_id may
    already have been cleared; providers need a real id to switch to.
    """
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    for blank in ("", None):
        manager.on_session_switch(blank)  # type: ignore[arg-type]

    assert not recorder.switch_calls
135  
136  
def test_manager_isolates_provider_failures():
    """An exception from one provider's hook must not stop the others."""

    class _Broken(_RecordingProvider):
        def on_session_switch(self, *args, **kwargs):  # type: ignore[override]
            raise RuntimeError("boom")

    manager = MemoryManager()
    # Only one external provider may be registered, so the failing provider
    # takes the builtin slot and the healthy one is the external provider.
    failing = _Broken(name="builtin")
    healthy = _RecordingProvider(name="good")
    for provider in (failing, healthy):
        manager.add_provider(provider)

    # Must not raise: the manager is expected to swallow and log the error.
    manager.on_session_switch("new-sid", parent_session_id="old-sid")

    assert len(healthy.switch_calls) == 1
    assert healthy.switch_calls[0]["new"] == "new-sid"
156  
157  
def test_manager_reset_flag_preserved():
    """reset=True and extra keyword context must reach the provider unchanged."""
    mm = MemoryManager()
    p = _RecordingProvider()
    mm.add_provider(p)
    mm.on_session_switch("new-sid", reset=True, reason="new_session")
    # Check the call count first so a missing or duplicated dispatch fails
    # with a clear assertion instead of an IndexError.
    assert len(p.switch_calls) == 1
    call = p.switch_calls[0]
    assert call["new"] == "new-sid"
    assert call["reset"] is True
    assert call["extra"] == {"reason": "new_session"}
165  
166  
167  # ---------------------------------------------------------------------------
168  # MemoryManager.sync_all / queue_prefetch_all — session_id propagation
169  # ---------------------------------------------------------------------------
170  
171  
def test_sync_all_propagates_session_id_to_providers():
    """sync_all must forward session_id to every provider's sync_turn.

    run_agent.py relies on this. Providers such as Hindsight refresh their
    cached _session_id from sync_turn, so a dropped id means they keep
    writing under the previous session.
    """
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    manager.sync_all("hello", "world", session_id="sess-42")

    expected = {"user": "hello", "asst": "world", "session_id": "sess-42"}
    assert recorder.sync_calls == [expected]
186  
187  
def test_queue_prefetch_all_propagates_session_id_to_providers():
    """queue_prefetch_all must forward session_id to each provider's queue_prefetch."""
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    manager.queue_prefetch_all("next query", session_id="sess-42")

    expected = {"query": "next query", "session_id": "sess-42"}
    assert recorder.queue_calls == [expected]
194  
195  
196  # ---------------------------------------------------------------------------
197  # Hindsight reference implementation — state-flush semantics
198  # ---------------------------------------------------------------------------
199  
200  
def _make_hindsight_provider():
    """Return a HindsightMemoryProvider with no network or thread setup.

    ``__init__`` is bypassed via ``object.__new__``. Only the attributes that
    ``on_session_switch`` (and the buffer-flush path it may trigger) reads or
    writes are seeded, which keeps these tests hermetic. The calling test is
    skipped when the Hindsight plugin cannot be imported.
    """
    import queue as _queue
    import threading

    hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
    provider = object.__new__(hindsight_mod.HindsightMemoryProvider)

    # Attribute seeds, applied in order below.
    seeded_state = {
        # Session identity and the batched turns for the current document.
        "_session_id": "old-sid",
        "_parent_session_id": "",
        "_document_id": "old-sid-20260101_000000_000000",
        "_session_turns": ["turn-1", "turn-2"],
        "_turn_counter": 2,
        "_turn_index": 2,
        # Read by _build_metadata / _build_retain_kwargs when the switch
        # flushes the turn buffer. Empty strings keep the metadata minimal
        # but well-formed.
        "_retain_source": "",
        "_platform": "",
        "_user_id": "",
        "_user_name": "",
        "_chat_id": "",
        "_chat_name": "",
        "_chat_type": "",
        "_thread_id": "",
        "_agent_identity": "",
        "_agent_workspace": "",
        "_retain_tags": [],
        "_retain_context": "test-context",
        "_retain_async": False,
        "_bank_id": "test-bank",
        # Prefetch state the switch path drains/clears.
        "_prefetch_thread": None,
        "_prefetch_lock": threading.Lock(),
        "_prefetch_result": "",
        # Sync thread tracking (legacy alias at the writer).
        "_sync_thread": None,
        # Writer-queue infrastructure the flush-on-switch path enqueues onto.
        # _ensure_writer / _register_atexit are stubbed so no real thread is
        # spawned. Flush delivery is covered in
        # tests/plugins/memory/test_hindsight_provider.py.
        "_retain_queue": _queue.Queue(),
        "_shutting_down": threading.Event(),
        "_atexit_registered": True,
        "_ensure_writer": lambda: None,
        "_register_atexit": lambda: None,
        # Network helper stubbed so any enqueued flush closure is a no-op.
        "_run_hindsight_operation": lambda _op: None,
    }
    for attr, value in seeded_state.items():
        setattr(provider, attr, value)
    return provider
255  
256  
def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc():
    """A switch stores the new and parent ids and starts a fresh document id."""
    hs = _make_hindsight_provider()
    previous_doc = hs._document_id

    hs.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume")

    assert (hs._session_id, hs._parent_session_id) == ("new-sid", "old-sid")
    # Reusing the old document id would let the next retain overwrite the
    # previous session's document.
    assert hs._document_id != previous_doc
    assert hs._document_id.startswith("new-sid-")
270  
271  
def test_hindsight_on_session_switch_clears_turn_buffers():
    """Buffered _session_turns must not carry over into the next session.

    Hindsight batches turns under a single _document_id. An uncleared buffer
    would be flushed by the next retain under the new _document_id, which
    attributes the previous session's turns to the new one.
    """
    hs = _make_hindsight_provider()

    hs.on_session_switch("new-sid", parent_session_id="old-sid")

    assert hs._session_turns == []
    assert (hs._turn_counter, hs._turn_index) == (0, 0)
284  
285  
def test_hindsight_on_session_switch_clears_on_reset_true():
    """A reset=True switch (/new, /reset) must also empty the turn buffers."""
    hs = _make_hindsight_provider()

    hs.on_session_switch("new-sid", reset=True, reason="new_session")

    assert hs._session_id == "new-sid"
    assert hs._session_turns == []
    assert hs._turn_counter == 0
293  
294  
def test_hindsight_on_session_switch_ignores_empty_id():
    """A blank new_session_id ("" or None) must leave provider state untouched."""
    hs = _make_hindsight_provider()

    def snapshot():
        # Copy the turn list so later in-place mutation can't mask a change.
        return (
            hs._session_id,
            hs._document_id,
            list(hs._session_turns),
            hs._turn_counter,
        )

    baseline = snapshot()
    for blank in ("", None):
        hs.on_session_switch(blank)  # type: ignore[arg-type]
    assert snapshot() == baseline
313  
314  
def test_hindsight_preserves_parent_across_empty_parent_arg():
    """A switch without parent_session_id keeps the previously stored parent."""
    hs = _make_hindsight_provider()
    hs._parent_session_id = "original-parent"

    hs.on_session_switch("new-sid")  # parent_session_id deliberately omitted

    assert hs._parent_session_id == "original-parent"