# tests/plugins/test_retaindb_plugin.py
  1  """Tests for the RetainDB memory plugin.
  2  
  3  Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter,
  4  RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling.
  5  """
  6  
  7  import json
  8  import os
  9  import sqlite3
 10  import tempfile
 11  import threading
 12  import time
 13  from pathlib import Path
 14  from unittest.mock import MagicMock, patch, PropertyMock
 15  
 16  import pytest
 17  
 18  
 19  # ---------------------------------------------------------------------------
 20  # Imports — guarded since plugins/memory lives outside the standard test path
 21  # ---------------------------------------------------------------------------
 22  
 23  @pytest.fixture(autouse=True)
 24  def _isolate_env(tmp_path, monkeypatch):
 25      """Ensure HERMES_HOME and RETAINDB vars are isolated."""
 26      hermes_home = tmp_path / ".hermes"
 27      hermes_home.mkdir()
 28      monkeypatch.setenv("HERMES_HOME", str(hermes_home))
 29      monkeypatch.delenv("RETAINDB_API_KEY", raising=False)
 30      monkeypatch.delenv("RETAINDB_BASE_URL", raising=False)
 31      monkeypatch.delenv("RETAINDB_PROJECT", raising=False)
 32  
 33  
 34  @pytest.fixture(autouse=True)
 35  def _cap_retaindb_sleeps(monkeypatch):
 36      """Cap production-code sleeps so background-thread tests run fast.
 37  
 38      The retaindb ``_WriteQueue._flush_row`` does ``time.sleep(2)`` after
 39      errors. Across multiple tests that trigger the retry path, that adds
 40      up. Cap the module's bound ``time.sleep`` to 0.05s — tests don't care
 41      about the exact retry delay, only that it happens. The test file's
 42      own ``time.sleep`` stays real since it uses a different reference.
 43      """
 44      try:
 45          from plugins.memory import retaindb as _retaindb
 46      except ImportError:
 47          return
 48  
 49      real_sleep = _retaindb.time.sleep
 50  
 51      def _capped_sleep(seconds):
 52          return real_sleep(min(float(seconds), 0.05))
 53  
 54      import types as _types
 55      fake_time = _types.SimpleNamespace(sleep=_capped_sleep, time=_retaindb.time.time)
 56      monkeypatch.setattr(_retaindb, "time", fake_time)
 57  
 58  
 59  # We need the repo root on sys.path so the plugin can import agent.memory_provider
 60  import sys
 61  _repo_root = str(Path(__file__).resolve().parents[2])
 62  if _repo_root not in sys.path:
 63      sys.path.insert(0, _repo_root)
 64  
 65  from plugins.memory.retaindb import (
 66      _Client,
 67      _WriteQueue,
 68      _build_overlay,
 69      RetainDBMemoryProvider,
 70      _ASYNC_SHUTDOWN,
 71      _DEFAULT_BASE_URL,
 72  )
 73  
 74  
 75  # ===========================================================================
 76  # _Client tests
 77  # ===========================================================================
 78  
 79  class TestClient:
 80      """Test the HTTP client with mocked requests."""
 81  
 82      def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
 83          return _Client(api_key, base_url, project)
 84  
 85      def test_base_url_trailing_slash_stripped(self):
 86          c = self._make_client(base_url="https://api.retaindb.com///")
 87          assert c.base_url == "https://api.retaindb.com"
 88  
 89      def test_headers_include_auth(self):
 90          c = self._make_client()
 91          h = c._headers("/v1/files")
 92          assert h["Authorization"] == "Bearer rdb-test-key"
 93          assert "X-API-Key" not in h
 94  
 95      def test_headers_include_api_key_for_memory_path(self):
 96          c = self._make_client()
 97          h = c._headers("/v1/memory/search")
 98          assert h["X-API-Key"] == "rdb-test-key"
 99  
100      def test_headers_include_api_key_for_context_path(self):
101          c = self._make_client()
102          h = c._headers("/v1/context/query")
103          assert h["X-API-Key"] == "rdb-test-key"
104  
105      def test_headers_strip_bearer_prefix(self):
106          c = self._make_client(api_key="Bearer rdb-test-key")
107          h = c._headers("/v1/memory/search")
108          assert h["Authorization"] == "Bearer rdb-test-key"
109          assert h["X-API-Key"] == "rdb-test-key"
110  
111      def test_add_memory_tries_fallback(self):
112          c = self._make_client()
113          call_count = 0
114          def fake_request(method, path, **kwargs):
115              nonlocal call_count
116              call_count += 1
117              if call_count == 1:
118                  raise RuntimeError("404")
119              return {"id": "mem-1"}
120  
121          with patch.object(c, "request", side_effect=fake_request):
122              result = c.add_memory("u1", "s1", "test fact")
123              assert result == {"id": "mem-1"}
124              assert call_count == 2
125  
126      def test_delete_memory_tries_fallback(self):
127          c = self._make_client()
128          call_count = 0
129          def fake_request(method, path, **kwargs):
130              nonlocal call_count
131              call_count += 1
132              if call_count == 1:
133                  raise RuntimeError("404")
134              return {"deleted": True}
135  
136          with patch.object(c, "request", side_effect=fake_request):
137              result = c.delete_memory("mem-123")
138              assert result == {"deleted": True}
139              assert call_count == 2
140  
141  # ===========================================================================
142  # _WriteQueue tests
143  # ===========================================================================
144  
class TestWriteQueue:
    """Exercise the SQLite-backed write queue against a real database file."""

    def _make_queue(self, tmp_path, client=None):
        if client is None:
            client = MagicMock()
            client.ingest_session = MagicMock(return_value={"status": "ok"})
        db_path = tmp_path / "test_queue.db"
        return _WriteQueue(client, db_path), client, db_path

    @staticmethod
    def _poll_last_error(db_path, timeout=2.0):
        """Poll the pending table until a last_error appears; None on timeout."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            conn = sqlite3.connect(str(db_path))
            row = conn.execute("SELECT last_error FROM pending").fetchone()
            conn.close()
            if row and row[0]:
                return row[0]
            time.sleep(0.05)
        return None

    def test_enqueue_creates_row(self, tmp_path):
        queue, client, _ = self._make_queue(tmp_path)
        queue.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        # shutdown() blocks until the writer thread drains the queue, so no
        # extra sleep is needed before asserting the ingest happened.
        queue.shutdown()
        client.ingest_session.assert_called_once()

    def test_enqueue_persists_to_sqlite(self, tmp_path):
        client = MagicMock()
        # A slow ingest (0.5s) keeps the row in SQLite long enough to peek.
        client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5))
        db_path = tmp_path / "test_queue.db"
        queue = _WriteQueue(client, db_path)
        queue.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
        conn = sqlite3.connect(str(db_path))
        rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall()
        conn.close()
        assert len(rows) >= 1
        assert rows[0][0] == "user1"
        queue.shutdown()

    def test_flush_deletes_row_on_success(self, tmp_path):
        queue, _, db_path = self._make_queue(tmp_path)
        queue.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        queue.shutdown()  # blocks until drain
        # A successful flush must remove the persisted row.
        conn = sqlite3.connect(str(db_path))
        remaining = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
        conn.close()
        assert remaining == 0

    def test_flush_records_error_on_failure(self, tmp_path):
        client = MagicMock()
        client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
        db_path = tmp_path / "test_queue.db"
        queue = _WriteQueue(client, db_path)
        queue.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        last_error = self._poll_last_error(db_path)
        queue.shutdown()
        assert last_error is not None
        assert "API down" in last_error

    def test_thread_local_connection_reuse(self, tmp_path):
        queue, _, _ = self._make_queue(tmp_path)
        # The same thread must always be handed the same connection object.
        assert queue._get_conn() is queue._get_conn()
        queue.shutdown()

    def test_crash_recovery_replays_pending(self, tmp_path):
        """Simulate crash: create rows, then new queue should replay them."""
        db_path = tmp_path / "recovery_test.db"
        # Phase 1: a failing client leaves the row stuck in the queue.
        failing_client = MagicMock()
        failing_client.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
        q1 = _WriteQueue(failing_client, db_path)
        q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
        self._poll_last_error(db_path)
        q1.shutdown()

        # Phase 2: a fresh queue against the same DB replays the pending row.
        ok_client = MagicMock()
        ok_client.ingest_session = MagicMock(return_value={"status": "ok"})
        q2 = _WriteQueue(ok_client, db_path)
        deadline = time.time() + 2.0
        while time.time() < deadline and not ok_client.ingest_session.called:
            time.sleep(0.05)
        q2.shutdown()

        ok_client.ingest_session.assert_called_once()
        assert ok_client.ingest_session.call_args[0][0] == "user1"  # user_id
255  
256  
257  # ===========================================================================
258  # _build_overlay tests
259  # ===========================================================================
260  
class TestBuildOverlay:
    """Test the overlay formatter (pure function)."""

    def test_empty_inputs_returns_empty(self):
        assert _build_overlay({}, {}) == ""

    def test_empty_memories_returns_empty(self):
        assert _build_overlay({"memories": []}, {"results": []}) == ""

    def test_profile_items_included(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, {})
        assert "User likes Python" in result
        assert "[RetainDB Context]" in result

    def test_query_results_included(self):
        query_result = {"results": [{"content": "Previous discussion about Rust"}]}
        result = _build_overlay({}, query_result)
        assert "Previous discussion about Rust" in result

    def test_deduplication_removes_duplicates(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        query_result = {"results": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, query_result)
        assert result.count("User likes Python") == 1

    def test_local_entries_filter(self):
        profile = {"memories": [{"content": "Already known fact"}]}
        result = _build_overlay(profile, {}, local_entries=["Already known fact"])
        # The profile item matches a local entry, should be filtered
        assert result == ""

    def test_max_five_items_per_section(self):
        profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
        result = _build_overlay(profile, {})
        # Should only include first 5
        assert "Fact 0" in result
        assert "Fact 4" in result
        assert "Fact 5" not in result

    def test_none_content_handled(self):
        profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
        result = _build_overlay(profile, {})
        assert "Real fact" in result

    def test_truncation_at_320_chars(self):
        long_content = "x" * 500
        profile = {"memories": [{"content": long_content}]}
        result = _build_overlay(profile, {})
        # Each item is compacted to 320 chars max. Collect the bullet lines
        # first and require at least one: previously the loop asserted only
        # on lines starting with "- ", so if the formatter rendered no such
        # line at all the test passed vacuously.
        bullet_lines = [ln for ln in result.split("\n") if ln.startswith("- ")]
        assert bullet_lines, "expected at least one rendered overlay item"
        for line in bullet_lines:
            assert len(line) <= 322  # "- " + 320
314  
315  
316  # ===========================================================================
317  # RetainDBMemoryProvider tests
318  # ===========================================================================
319  
class TestRetainDBMemoryProvider:
    """Test the main plugin class."""

    def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
        """Return an un-initialized provider with env pointed at tmp_path."""
        monkeypatch.setenv("RETAINDB_API_KEY", api_key)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
        (tmp_path / ".hermes").mkdir(exist_ok=True)
        provider = RetainDBMemoryProvider()
        return provider

    def test_name(self):
        p = RetainDBMemoryProvider()
        assert p.name == "retaindb"

    def test_is_available_without_key(self):
        p = RetainDBMemoryProvider()
        assert p.is_available() is False

    def test_is_available_with_key(self, monkeypatch):
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
        p = RetainDBMemoryProvider()
        assert p.is_available() is True

    def test_config_schema(self):
        p = RetainDBMemoryProvider()
        schema = p.get_config_schema()
        assert len(schema) == 3
        keys = [s["key"] for s in schema]
        assert "api_key" in keys
        assert "base_url" in keys
        assert "project" in keys

    def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client is not None
        assert p._queue is not None
        assert p._session_id == "test-session"
        p.shutdown()

    def test_initialize_default_project(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client.project == "default"
        p.shutdown()

    def test_initialize_explicit_project(self, tmp_path, monkeypatch):
        monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client.project == "my-project"
        p.shutdown()

    def test_initialize_profile_project(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        profile_home = str(tmp_path / "profiles" / "coder")
        p.initialize("test-session", hermes_home=profile_home)
        assert p._client.project == "hermes-coder"
        p.shutdown()

    def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        soul_path = tmp_path / ".hermes" / "SOUL.md"
        soul_path.write_text("I am a helpful agent.")
        with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
            p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
            # Give the background seeding thread time to start
            time.sleep(0.5)
            mock_seed.assert_called_once_with("I am a helpful agent.")
        p.shutdown()

    def test_system_prompt_block(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        block = p.system_prompt_block()
        assert "RetainDB Memory" in block
        assert "Active" in block
        p.shutdown()

    def test_handle_tool_call_not_initialized(self):
        p = RetainDBMemoryProvider()
        result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "error" in result
        assert "not initialized" in result["error"]

    def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
        assert result == {"error": "Unknown tool: retaindb_nonexistent"}
        p.shutdown()

    def test_dispatch_profile(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
            assert "memories" in result
        p.shutdown()

    def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_search", {}))
        assert result == {"error": "query is required"}
        p.shutdown()

    def test_dispatch_search(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
            result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
            assert "results" in result
        p.shutdown()

    def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "search") as mock_search:
            mock_search.return_value = {"results": []}
            p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
            # top_k should be capped at 20
            assert mock_search.call_args[1]["top_k"] == 20
        p.shutdown()

    def test_dispatch_remember(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
            result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
            assert result["id"] == "mem-1"
        p.shutdown()

    def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_remember", {}))
        assert result == {"error": "content is required"}
        p.shutdown()

    def test_dispatch_forget(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
            result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
            assert result["deleted"] is True
        p.shutdown()

    def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_forget", {}))
        assert result == {"error": "memory_id is required"}
        p.shutdown()

    def test_dispatch_context(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
             patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
            assert "context" in result
            assert "raw" in result
        p.shutdown()

    def test_dispatch_file_list(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "list_files", return_value={"files": []}):
            result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
            assert "files" in result
        p.shutdown()

    def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
        assert "error" in result
        # Fix: this test previously leaked its provider (no shutdown), leaving
        # the write queue and its background thread alive across later tests.
        p.shutdown()

    def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
        assert "File not found" in result["error"]
        p.shutdown()

    def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
            assert "API exploded" in result["error"]
        p.shutdown()
534  
535  
536  # ===========================================================================
537  # Prefetch and thread management tests
538  # ===========================================================================
539  
class TestPrefetch:
    """Test background prefetch and thread accumulation prevention."""

    def _make_initialized_provider(self, tmp_path, monkeypatch):
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir(exist_ok=True)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        provider = RetainDBMemoryProvider()
        provider.initialize("test-session", hermes_home=str(hermes_home))
        return provider

    def test_queue_prefetch_skips_without_client(self):
        # An uninitialized provider has no client; prefetch must be a no-op.
        RetainDBMemoryProvider().queue_prefetch("test")

    def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        assert provider.prefetch("test") == ""
        provider.shutdown()

    def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        # Seed the cached result directly under the provider's lock.
        with provider._lock:
            provider._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
        first = provider.prefetch("test")
        assert "User likes tests" in first
        # The cached value is one-shot: a second call comes back empty.
        assert provider.prefetch("test") == ""
        provider.shutdown()

    def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        with provider._lock:
            provider._dialectic_result = "User is a software engineer who prefers Python."
        overlay = provider.prefetch("test")
        assert "[RetainDB User Synthesis]" in overlay
        assert "software engineer" in overlay
        provider.shutdown()

    def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        with provider._lock:
            provider._agent_model = {
                "memory_count": 5,
                "persona": "Helpful coding assistant",
                "persistent_instructions": ["Be concise", "Use Python"],
                "working_style": "Direct and efficient",
            }
        overlay = provider.prefetch("test")
        assert "[RetainDB Agent Self-Model]" in overlay
        assert "Helpful coding assistant" in overlay
        assert "Be concise" in overlay
        assert "Direct and efficient" in overlay
        provider.shutdown()

    def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
        provider = self._make_initialized_provider(tmp_path, monkeypatch)
        with provider._lock:
            provider._agent_model = {"memory_count": 0}
        assert "Agent Self-Model" not in provider.prefetch("test")
        provider.shutdown()

    def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
        """Verify old prefetch threads are joined before new ones spawn."""
        provider = self._make_initialized_provider(tmp_path, monkeypatch)

        def slow(*_args):
            # Stand-in for a slow prefetch so threads stay identifiable.
            time.sleep(0.5)

        with patch.object(provider, "_prefetch_context", side_effect=slow), \
             patch.object(provider, "_prefetch_dialectic", side_effect=slow), \
             patch.object(provider, "_prefetch_agent_model", side_effect=slow):
            provider.queue_prefetch("query 1")
            first_batch = list(provider._prefetch_threads)
            assert len(first_batch) == 3

            # A second call must join the first batch and spawn fresh threads.
            provider.queue_prefetch("query 2")
            second_batch = list(provider._prefetch_threads)
            assert len(second_batch) == 3
            assert not set(second_batch) & set(first_batch)
        provider.shutdown()

    def test_reasoning_level_short(self):
        assert RetainDBMemoryProvider._reasoning_level("hi") == "low"

    def test_reasoning_level_medium(self):
        assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium"

    def test_reasoning_level_long(self):
        assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high"
634  
635  
636  # ===========================================================================
637  # sync_turn tests
638  # ===========================================================================
639  
class TestSyncTurn:
    """Test turn synchronization via the write queue."""

    def _make_provider(self, tmp_path, monkeypatch):
        """Build an initialized provider under an isolated HERMES_HOME.

        Factors out the env/bootstrap boilerplate previously duplicated in
        every test method of this class.
        """
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir(exist_ok=True)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        p = RetainDBMemoryProvider()
        p.initialize("test-session", hermes_home=str(hermes_home))
        return p

    def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._queue, "enqueue") as mock_enqueue:
            p.sync_turn("user msg", "assistant msg")
            mock_enqueue.assert_called_once()
            args = mock_enqueue.call_args[0]
            assert args[0] == "default"  # user_id
            assert args[1] == "test-session"  # session_id
            msgs = args[2]
            assert len(msgs) == 2
            assert msgs[0]["role"] == "user"
            assert msgs[1]["role"] == "assistant"
        p.shutdown()

    def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._queue, "enqueue") as mock_enqueue:
            p.sync_turn("", "assistant msg")
            mock_enqueue.assert_not_called()
        p.shutdown()
673  
674  
675  # ===========================================================================
676  # on_memory_write hook tests
677  # ===========================================================================
678  
class TestOnMemoryWrite:
    """Test the built-in memory mirror hook."""

    def _make_provider(self, tmp_path, monkeypatch):
        """Build an initialized provider under an isolated HERMES_HOME.

        Factors out the env/bootstrap boilerplate previously duplicated in
        every test method of this class.
        """
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
        hermes_home = tmp_path / ".hermes"
        hermes_home.mkdir(exist_ok=True)
        monkeypatch.setenv("HERMES_HOME", str(hermes_home))
        p = RetainDBMemoryProvider()
        p.initialize("test-session", hermes_home=str(hermes_home))
        return p

    def test_mirrors_add_action(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "user", "User prefers dark mode")
            mock_add.assert_called_once()
            # "user" target mirrors with the "preference" memory type.
            assert mock_add.call_args[1]["memory_type"] == "preference"
        p.shutdown()

    def test_skips_non_add_action(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("remove", "user", "something")
            mock_add.assert_not_called()
        p.shutdown()

    def test_skips_empty_content(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("add", "user", "")
            mock_add.assert_not_called()
        p.shutdown()

    def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "memory", "Some env fact")
            # "memory" target mirrors with the "factual" memory type.
            assert mock_add.call_args[1]["memory_type"] == "factual"
        p.shutdown()
730  
731  
732  # ===========================================================================
733  # register() test
734  # ===========================================================================
735  
class TestRegister:
    def test_register_calls_register_memory_provider(self):
        from plugins.memory.retaindb import register

        ctx = MagicMock()
        register(ctx)
        ctx.register_memory_provider.assert_called_once()
        # The single positional argument must be a provider instance.
        (provider_arg,), _kwargs = ctx.register_memory_provider.call_args
        assert isinstance(provider_arg, RetainDBMemoryProvider)