# test_memory_session_switch.py
"""Tests covering the ``on_session_switch`` hook and session_id propagation.

Regression coverage for #6672. ``AIAgent.session_id`` can rotate while the
process keeps running (via /resume, /branch, /reset, /new, or context
compression), and memory providers have to be told when that happens.
Providers that cache per-session state in ``initialize()`` -- Hindsight, and
any plugin that stores session_id for scoped writes -- would otherwise keep
writing into the old session's record.
"""

import json

import pytest

from agent.memory_manager import MemoryManager
from agent.memory_provider import MemoryProvider


class _RecordingProvider(MemoryProvider):
    """Test double that logs each lifecycle call so tests can assert on it.

    Every recorded hook appends one dict describing its arguments to the
    matching ``*_calls`` list.
    """

    def __init__(self, name="rec"):
        self._name = name
        # One list per recorded hook; entries are kept in arrival order.
        self.switch_calls: list[dict] = []
        self.sync_calls: list[dict] = []
        self.queue_calls: list[dict] = []
        self.initialize_calls: list[dict] = []

    @property
    def name(self) -> str:
        """Return the name supplied at construction time."""
        return self._name

    def is_available(self) -> bool:  # pragma: no cover - unused
        """Always report the provider as available."""
        return True

    def initialize(self, session_id, **kwargs):
        """Record the session id together with any extra keyword arguments."""
        entry = {"session_id": session_id}
        entry.update(kwargs)
        self.initialize_calls.append(entry)

    def get_tool_schemas(self):
        """Expose no tools."""
        return []

    def sync_turn(self, user_content, assistant_content, *, session_id=""):
        """Record one synced turn and the session id it arrived with."""
        entry = {
            "user": user_content,
            "asst": assistant_content,
            "session_id": session_id,
        }
        self.sync_calls.append(entry)

    def queue_prefetch(self, query, *, session_id=""):
        """Record one prefetch request and the session id it arrived with."""
        entry = {"query": query, "session_id": session_id}
        self.queue_calls.append(entry)

    def on_session_switch(
        self,
        new_session_id,
        *,
        parent_session_id="",
        reset=False,
        **kwargs,
    ):
        """Record a session switch; unknown keywords are kept under "extra"."""
        entry = {
            "new": new_session_id,
            "parent": parent_session_id,
            "reset": reset,
            "extra": kwargs,
        }
        self.switch_calls.append(entry)


# ---------------------------------------------------------------------------
# MemoryProvider ABC — default on_session_switch is a no-op
# ---------------------------------------------------------------------------


class _MinimalProvider(MemoryProvider):
    """Provider that leaves on_session_switch to the ABC default."""

    @property
    def name(self) -> str:
        """Return the fixed provider name."""
        return "minimal"

    def is_available(self) -> bool:
        """Always report the provider as available."""
        return True

    def initialize(self, session_id, **kwargs):  # pragma: no cover - unused
        """Do nothing; these tests never initialize this provider."""
        pass

    def get_tool_schemas(self):
        """Expose no tools."""
        return []


def test_abc_default_on_session_switch_is_noop():
    """A provider relying on the inherited hook must never raise."""
    provider = _MinimalProvider()
    # Each keyword combination below must be accepted without raising.
    call_styles = (
        {},
        {"parent_session_id": "old-id"},
        {"parent_session_id": "old-id", "reset": True},
        {"parent_session_id": "old-id", "reset": True, "reason": "new_session"},
    )
    for keywords in call_styles:
        provider.on_session_switch("new-id", **keywords)


# ---------------------------------------------------------------------------
# MemoryManager.on_session_switch — fan-out
# ---------------------------------------------------------------------------


def test_manager_fans_out_to_all_providers():
    """Every registered provider receives the same switch notification."""
    manager = MemoryManager()
    # Only one external provider is allowed, so the first recorder takes
    # the builtin slot.
    builtin = _RecordingProvider(name="builtin")
    external = _RecordingProvider(name="hindsight")
    for provider in (builtin, external):
        manager.add_provider(provider)

    manager.on_session_switch(
        "new-sid", parent_session_id="old-sid", reset=False, reason="resume"
    )

    for provider in (builtin, external):
        assert len(provider.switch_calls) == 1
    for provider in (builtin, external):
        recorded = provider.switch_calls[0]
        assert recorded["new"] == "new-sid"
        assert recorded["parent"] == "old-sid"
        assert recorded["reset"] is False
        assert recorded["extra"] == {"reason": "resume"}


def test_manager_ignores_empty_session_id():
    """A blank or missing session id must never reach the providers.

    This guards against accidental fires during shutdown, when
    self.session_id may be cleared. Providers expect a meaningful id to
    switch TO.
    """
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)
    for blank_id in ("", None):
        manager.on_session_switch(blank_id)  # type: ignore[arg-type]
    assert recorder.switch_calls == []


def test_manager_isolates_provider_failures():
    """One provider raising must not stop the others from being notified."""

    class _Broken(_RecordingProvider):
        def on_session_switch(self, *args, **kwargs):  # type: ignore[override]
            raise RuntimeError("boom")

    manager = MemoryManager()
    # MemoryManager rejects a second external provider, so the failing one
    # sits in the builtin slot next to a healthy external provider.
    failing = _Broken(name="builtin")
    healthy = _RecordingProvider(name="good")
    for provider in (failing, healthy):
        manager.add_provider(provider)

    # Must not raise — exceptions in one provider are swallowed + logged.
    manager.on_session_switch("new-sid", parent_session_id="old-sid")

    assert len(healthy.switch_calls) == 1
    assert healthy.switch_calls[0]["new"] == "new-sid"


def test_manager_reset_flag_preserved():
    """reset=True and extra keywords reach the provider unchanged."""
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    manager.on_session_switch("new-sid", reset=True, reason="new_session")

    recorded = recorder.switch_calls[0]
    assert recorded["reset"] is True
    assert recorded["extra"] == {"reason": "new_session"}


# ---------------------------------------------------------------------------
# MemoryManager.sync_all / queue_prefetch_all — session_id propagation
# ---------------------------------------------------------------------------


def test_sync_all_propagates_session_id_to_providers():
    """sync_all must hand the session_id it is given on to each provider.

    run_agent.py's sync_all call relies on this. A provider that updates
    _session_id defensively in sync_turn (as Hindsight does at
    hindsight/__init__.py:1199) would otherwise never see the new id and
    keep writing under the old one.
    """
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    manager.sync_all("hello", "world", session_id="sess-42")

    expected_call = {"user": "hello", "asst": "world", "session_id": "sess-42"}
    assert recorder.sync_calls == [expected_call]


def test_queue_prefetch_all_propagates_session_id_to_providers():
    """queue_prefetch_all must hand the session_id on to each provider."""
    manager = MemoryManager()
    recorder = _RecordingProvider()
    manager.add_provider(recorder)

    manager.queue_prefetch_all("next query", session_id="sess-42")

    expected_call = {"query": "next query", "session_id": "sess-42"}
    assert recorder.queue_calls == [expected_call]


# ---------------------------------------------------------------------------
# Hindsight reference implementation — state-flush semantics
# ---------------------------------------------------------------------------


def _make_hindsight_provider():
    """Return a HindsightMemoryProvider seeded by hand, with no network setup.

    __init__ is bypassed via object.__new__ so optional dependencies are not
    needed at class level; only the attributes that on_session_switch reads
    or writes are seeded. The calling test is skipped when the plugin module
    cannot be imported.
    """
    import queue as _queue
    import threading

    hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
    provider = object.__new__(hindsight_mod.HindsightMemoryProvider)

    # State of the "old" session that the tests switch away from.
    session_state = {
        "_session_id": "old-sid",
        "_parent_session_id": "",
        "_document_id": "old-sid-20260101_000000_000000",
        "_session_turns": ["turn-1", "turn-2"],
        "_turn_counter": 2,
        "_turn_index": 2,
    }
    for attr, value in session_state.items():
        setattr(provider, attr, value)

    # Attrs read by _build_metadata / _build_retain_kwargs when the
    # buffer-flush path on session switch fires. Empty strings keep the
    # metadata minimal but well-formed.
    blank_attrs = (
        "_retain_source",
        "_platform",
        "_user_id",
        "_user_name",
        "_chat_id",
        "_chat_name",
        "_chat_type",
        "_thread_id",
        "_agent_identity",
        "_agent_workspace",
    )
    for attr in blank_attrs:
        setattr(provider, attr, "")
    provider._retain_tags = []
    provider._retain_context = "test-context"
    provider._retain_async = False
    provider._bank_id = "test-bank"

    # Prefetch state the switch path drains/clears.
    provider._prefetch_thread = None
    provider._prefetch_lock = threading.Lock()
    provider._prefetch_result = ""

    # Sync thread tracking (legacy alias at the writer).
    provider._sync_thread = None

    def _do_nothing():
        return None

    def _skip_operation(_op):
        return None

    # Writer queue infra the flush-on-switch path enqueues onto.
    # _ensure_writer / _register_atexit are stubbed so no real thread is
    # spawned; tests exercising flush delivery live in
    # tests/plugins/memory/test_hindsight_provider.py where the full
    # writer-queue wiring is in place.
    provider._retain_queue = _queue.Queue()
    provider._shutting_down = threading.Event()
    provider._atexit_registered = True
    provider._ensure_writer = _do_nothing
    provider._register_atexit = _do_nothing

    # The network-touching helper is stubbed too, so any enqueued flush
    # closure is a no-op if it is ever drained in a unit test.
    provider._run_hindsight_operation = _skip_operation
    return provider


def _hindsight_state_snapshot(provider):
    """Capture the session fields that an ignored switch must leave alone."""
    return (
        provider._session_id,
        provider._document_id,
        list(provider._session_turns),
        provider._turn_counter,
    )


def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc():
    """Switching adopts the new ids and mints a new document id."""
    provider = _make_hindsight_provider()
    previous_doc_id = provider._document_id

    provider.on_session_switch(
        "new-sid", parent_session_id="old-sid", reset=False, reason="resume"
    )

    assert provider._session_id == "new-sid"
    assert provider._parent_session_id == "old-sid"
    # Document id MUST be fresh — else next retain overwrites old session doc.
    assert provider._document_id != previous_doc_id
    assert provider._document_id.startswith("new-sid-")


def test_hindsight_on_session_switch_clears_turn_buffers():
    """Buffered _session_turns must not leak into the next session.

    Hindsight batches turns under a single _document_id. If the buffer
    isn't cleared on switch, the next retain under the new _document_id
    flushes turns that belong to the previous session.
    """
    provider = _make_hindsight_provider()

    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._session_turns == []
    assert provider._turn_counter == 0
    assert provider._turn_index == 0


def test_hindsight_on_session_switch_clears_on_reset_true():
    """reset=True (from /new, /reset) must flush the buffers as well."""
    provider = _make_hindsight_provider()

    provider.on_session_switch("new-sid", reset=True, reason="new_session")

    assert provider._session_id == "new-sid"
    assert provider._session_turns == []
    assert provider._turn_counter == 0


def test_hindsight_on_session_switch_ignores_empty_id():
    """An empty new_session_id is a no-op, so state cannot be corrupted."""
    provider = _make_hindsight_provider()
    before = _hindsight_state_snapshot(provider)

    provider.on_session_switch("")
    provider.on_session_switch(None)  # type: ignore[arg-type]

    after = _hindsight_state_snapshot(provider)
    assert before == after


def test_hindsight_preserves_parent_across_empty_parent_arg():
    """Leaving out parent_session_id must NOT overwrite an existing one."""
    provider = _make_hindsight_provider()
    provider._parent_session_id = "original-parent"

    provider.on_session_switch("new-sid")  # no parent passed

    assert provider._parent_session_id == "original-parent"