test_retaindb_plugin.py
"""Tests for the RetainDB memory plugin.

Covers: _Client HTTP client, _WriteQueue SQLite queue, _build_overlay formatter,
RetainDBMemoryProvider lifecycle/tools/prefetch, thread management, connection pooling.
"""

import json
import os
import sqlite3
import sys
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, patch, PropertyMock

import pytest


# ---------------------------------------------------------------------------
# Fixtures — environment isolation and fast production-code retry sleeps
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def _isolate_env(tmp_path, monkeypatch):
    """Ensure HERMES_HOME and RETAINDB vars are isolated."""
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.delenv("RETAINDB_API_KEY", raising=False)
    monkeypatch.delenv("RETAINDB_BASE_URL", raising=False)
    monkeypatch.delenv("RETAINDB_PROJECT", raising=False)


@pytest.fixture(autouse=True)
def _cap_retaindb_sleeps(monkeypatch):
    """Cap production-code sleeps so background-thread tests run fast.

    The retaindb ``_WriteQueue._flush_row`` does ``time.sleep(2)`` after
    errors. Across multiple tests that trigger the retry path, that adds
    up. Cap the module's bound ``time.sleep`` to 0.05s — tests don't care
    about the exact retry delay, only that it happens.

    Only the plugin module's ``time`` reference is swapped. The replacement
    is a copy of the real ``time`` module, so every other attribute
    (``time.time``, ``time.monotonic``, ...) keeps working unchanged. The
    test file's own ``time.sleep`` stays real since it uses a different
    reference.
    """
    try:
        from plugins.memory import retaindb as _retaindb
    except ImportError:
        return

    real_time = _retaindb.time
    real_sleep = real_time.sleep

    def _capped_sleep(seconds):
        return real_sleep(min(float(seconds), 0.05))

    import types as _types
    # Copy the whole module namespace rather than exposing only a couple of
    # attributes: production code may call time functions other than
    # sleep()/time(), which would otherwise raise AttributeError under test.
    fake_time = _types.ModuleType("time")
    fake_time.__dict__.update(vars(real_time))
    fake_time.sleep = _capped_sleep
    monkeypatch.setattr(_retaindb, "time", fake_time)


# ---------------------------------------------------------------------------
# Imports — plugins/memory lives outside the standard test path, so the repo
# root has to be put on sys.path before the plugin can be imported.
# ---------------------------------------------------------------------------

# We need the repo root on sys.path so the plugin can import agent.memory_provider
_repo_root = str(Path(__file__).resolve().parents[2])
if _repo_root not in sys.path:
    sys.path.insert(0, _repo_root)

from plugins.memory.retaindb import (  # noqa: E402 — requires the sys.path setup above
    _Client,
    _WriteQueue,
    _build_overlay,
    RetainDBMemoryProvider,
    _ASYNC_SHUTDOWN,
    _DEFAULT_BASE_URL,
)


# ---------------------------------------------------------------------------
# Shared helpers — wait on conditions instead of fixed sleeps
# ---------------------------------------------------------------------------

def _wait_until(predicate, timeout=2.0, interval=0.05):
    """Call *predicate* repeatedly until it returns a truthy value.

    Gives up once *timeout* seconds have elapsed. Returns the last value
    *predicate* produced. A falsy return therefore means the condition never
    became true within the timeout.
    """
    deadline = time.monotonic() + timeout
    while True:
        value = predicate()
        if value or time.monotonic() >= deadline:
            return value
        time.sleep(interval)


def _read_last_error(db_path):
    """Return the ``last_error`` of the first ``pending`` row in *db_path*.

    Returns ``None`` when there is no pending row or its ``last_error`` is
    empty. The connection is closed even if the query raises.
    """
    conn = sqlite3.connect(str(db_path))
    try:
        row = conn.execute("SELECT last_error FROM pending").fetchone()
    finally:
        conn.close()
    if row and row[0]:
        return row[0]
    return None


def _make_initialized_provider(tmp_path, monkeypatch):
    """Return a provider with an API key set, initialized as "test-session"."""
    monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test-key")
    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir(exist_ok=True)
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    p = RetainDBMemoryProvider()
    p.initialize("test-session", hermes_home=str(hermes_home))
    return p


# ===========================================================================
# _Client tests
# ===========================================================================

class TestClient:
    """Test the HTTP client with mocked requests."""

    def _make_client(self, api_key="rdb-test-key", base_url="https://api.retaindb.com", project="test"):
        return _Client(api_key, base_url, project)

    def test_base_url_trailing_slash_stripped(self):
        c = self._make_client(base_url="https://api.retaindb.com///")
        assert c.base_url == "https://api.retaindb.com"

    def test_headers_include_auth(self):
        c = self._make_client()
        h = c._headers("/v1/files")
        assert h["Authorization"] == "Bearer rdb-test-key"
        assert "X-API-Key" not in h

    def test_headers_include_api_key_for_memory_path(self):
        c = self._make_client()
        h = c._headers("/v1/memory/search")
        assert h["X-API-Key"] == "rdb-test-key"

    def test_headers_include_api_key_for_context_path(self):
        c = self._make_client()
        h = c._headers("/v1/context/query")
        assert h["X-API-Key"] == "rdb-test-key"

    def test_headers_strip_bearer_prefix(self):
        c = self._make_client(api_key="Bearer rdb-test-key")
        h = c._headers("/v1/memory/search")
        assert h["Authorization"] == "Bearer rdb-test-key"
        assert h["X-API-Key"] == "rdb-test-key"

    def test_add_memory_tries_fallback(self):
        c = self._make_client()
        call_count = 0

        def fake_request(method, path, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("404")
            return {"id": "mem-1"}

        with patch.object(c, "request", side_effect=fake_request):
            result = c.add_memory("u1", "s1", "test fact")
        assert result == {"id": "mem-1"}
        assert call_count == 2

    def test_delete_memory_tries_fallback(self):
        c = self._make_client()
        call_count = 0

        def fake_request(method, path, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("404")
            return {"deleted": True}

        with patch.object(c, "request", side_effect=fake_request):
            result = c.delete_memory("mem-123")
        assert result == {"deleted": True}
        assert call_count == 2


# ===========================================================================
# _WriteQueue tests
# ===========================================================================

class TestWriteQueue:
    """Test the SQLite-backed write queue with real SQLite."""

    def _make_queue(self, tmp_path, client=None):
        if client is None:
            client = MagicMock()
            client.ingest_session = MagicMock(return_value={"status": "ok"})
        db_path = tmp_path / "test_queue.db"
        return _WriteQueue(client, db_path), client, db_path

    def test_enqueue_creates_row(self, tmp_path):
        q, client, db_path = self._make_queue(tmp_path)
        q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        # shutdown() blocks until the writer thread drains the queue — no need
        # to pre-sleep (the old 1s sleep was a just-in-case wait, but shutdown
        # does the right thing).
        q.shutdown()
        # The drained row must have been handed to the client exactly once.
        client.ingest_session.assert_called_once()

    def test_enqueue_persists_to_sqlite(self, tmp_path):
        client = MagicMock()
        # Make ingest slow so the row is still in SQLite when we peek.
        # 0.5s is plenty — the test just needs the flush to still be in-flight.
        client.ingest_session = MagicMock(side_effect=lambda *a, **kw: time.sleep(0.5))
        db_path = tmp_path / "test_queue.db"
        q = _WriteQueue(client, db_path)
        q.enqueue("user1", "sess1", [{"role": "user", "content": "test"}])
        # Check SQLite directly — row should exist since flush is slow
        conn = sqlite3.connect(str(db_path))
        try:
            rows = conn.execute("SELECT user_id, session_id FROM pending").fetchall()
        finally:
            conn.close()
        assert len(rows) >= 1
        assert rows[0][0] == "user1"
        q.shutdown()

    def test_flush_deletes_row_on_success(self, tmp_path):
        q, client, db_path = self._make_queue(tmp_path)
        q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        q.shutdown()  # blocks until drain
        # Row should be gone
        conn = sqlite3.connect(str(db_path))
        try:
            rows = conn.execute("SELECT COUNT(*) FROM pending").fetchone()[0]
        finally:
            conn.close()
        assert rows == 0

    def test_flush_records_error_on_failure(self, tmp_path):
        client = MagicMock()
        client.ingest_session = MagicMock(side_effect=RuntimeError("API down"))
        db_path = tmp_path / "test_queue.db"
        q = _WriteQueue(client, db_path)
        q.enqueue("user1", "sess1", [{"role": "user", "content": "hi"}])
        # Poll for the error to be recorded (max 2s), instead of a fixed 3s wait.
        last_error = _wait_until(lambda: _read_last_error(db_path))
        q.shutdown()
        assert last_error is not None
        assert "API down" in last_error

    def test_thread_local_connection_reuse(self, tmp_path):
        q, _, _ = self._make_queue(tmp_path)
        # Same thread should get same connection
        conn1 = q._get_conn()
        conn2 = q._get_conn()
        assert conn1 is conn2
        q.shutdown()

    def test_crash_recovery_replays_pending(self, tmp_path):
        """Simulate crash: create rows, then new queue should replay them."""
        db_path = tmp_path / "recovery_test.db"
        # First: create a queue and insert rows, but don't let them flush
        client1 = MagicMock()
        client1.ingest_session = MagicMock(side_effect=RuntimeError("fail"))
        q1 = _WriteQueue(client1, db_path)
        q1.enqueue("user1", "sess1", [{"role": "user", "content": "lost turn"}])
        # Wait until the error is recorded (poll with short interval).
        _wait_until(lambda: _read_last_error(db_path))
        q1.shutdown()

        # Now create a new queue — it should replay the pending rows
        client2 = MagicMock()
        client2.ingest_session = MagicMock(return_value={"status": "ok"})
        q2 = _WriteQueue(client2, db_path)
        # Poll for the replay to happen.
        _wait_until(lambda: client2.ingest_session.called)
        q2.shutdown()

        # The replayed row should have been ingested via client2
        client2.ingest_session.assert_called_once()
        call_args = client2.ingest_session.call_args
        assert call_args[0][0] == "user1"  # user_id


# ===========================================================================
# _build_overlay tests
# ===========================================================================

class TestBuildOverlay:
    """Test the overlay formatter (pure function)."""

    def test_empty_inputs_returns_empty(self):
        assert _build_overlay({}, {}) == ""

    def test_empty_memories_returns_empty(self):
        assert _build_overlay({"memories": []}, {"results": []}) == ""

    def test_profile_items_included(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, {})
        assert "User likes Python" in result
        assert "[RetainDB Context]" in result

    def test_query_results_included(self):
        query_result = {"results": [{"content": "Previous discussion about Rust"}]}
        result = _build_overlay({}, query_result)
        assert "Previous discussion about Rust" in result

    def test_deduplication_removes_duplicates(self):
        profile = {"memories": [{"content": "User likes Python"}]}
        query_result = {"results": [{"content": "User likes Python"}]}
        result = _build_overlay(profile, query_result)
        assert result.count("User likes Python") == 1

    def test_local_entries_filter(self):
        profile = {"memories": [{"content": "Already known fact"}]}
        result = _build_overlay(profile, {}, local_entries=["Already known fact"])
        # The profile item matches a local entry, should be filtered
        assert result == ""

    def test_max_five_items_per_section(self):
        profile = {"memories": [{"content": f"Fact {i}"} for i in range(10)]}
        result = _build_overlay(profile, {})
        # Should only include first 5
        assert "Fact 0" in result
        assert "Fact 4" in result
        assert "Fact 5" not in result

    def test_none_content_handled(self):
        profile = {"memories": [{"content": None}, {"content": "Real fact"}]}
        result = _build_overlay(profile, {})
        assert "Real fact" in result

    def test_truncation_at_320_chars(self):
        long_content = "x" * 500
        profile = {"memories": [{"content": long_content}]}
        result = _build_overlay(profile, {})
        # Each item is compacted to 320 chars max and rendered as "- " + item.
        item_lines = [line for line in result.split("\n") if line.startswith("- ")]
        # Guard against a vacuous pass: the long item must actually be rendered.
        assert item_lines
        for line in item_lines:
            assert len(line) <= 322  # "- " + 320


# ===========================================================================
# RetainDBMemoryProvider tests
# ===========================================================================

class TestRetainDBMemoryProvider:
    """Test the main plugin class."""

    def _make_provider(self, tmp_path, monkeypatch, api_key="rdb-test-key"):
        monkeypatch.setenv("RETAINDB_API_KEY", api_key)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
        (tmp_path / ".hermes").mkdir(exist_ok=True)
        provider = RetainDBMemoryProvider()
        return provider

    def test_name(self):
        p = RetainDBMemoryProvider()
        assert p.name == "retaindb"

    def test_is_available_without_key(self):
        p = RetainDBMemoryProvider()
        assert p.is_available() is False

    def test_is_available_with_key(self, monkeypatch):
        monkeypatch.setenv("RETAINDB_API_KEY", "rdb-test")
        p = RetainDBMemoryProvider()
        assert p.is_available() is True

    def test_config_schema(self):
        p = RetainDBMemoryProvider()
        schema = p.get_config_schema()
        assert len(schema) == 3
        keys = [s["key"] for s in schema]
        assert "api_key" in keys
        assert "base_url" in keys
        assert "project" in keys

    def test_initialize_creates_client_and_queue(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client is not None
        assert p._queue is not None
        assert p._session_id == "test-session"
        p.shutdown()

    def test_initialize_default_project(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client.project == "default"
        p.shutdown()

    def test_initialize_explicit_project(self, tmp_path, monkeypatch):
        monkeypatch.setenv("RETAINDB_PROJECT", "my-project")
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        assert p._client.project == "my-project"
        p.shutdown()

    def test_initialize_profile_project(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        profile_home = str(tmp_path / "profiles" / "coder")
        p.initialize("test-session", hermes_home=profile_home)
        assert p._client.project == "hermes-coder"
        p.shutdown()

    def test_initialize_seeds_soul_md(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        soul_path = tmp_path / ".hermes" / "SOUL.md"
        soul_path.write_text("I am a helpful agent.")
        with patch.object(RetainDBMemoryProvider, "_seed_soul") as mock_seed:
            p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
            # Seeding happens on a background thread: poll (max 2s) instead of
            # a fixed sleep, and stay inside the patch so the mock is still in
            # place when that thread calls it.
            _wait_until(lambda: mock_seed.called)
        mock_seed.assert_called_once_with("I am a helpful agent.")
        p.shutdown()

    def test_system_prompt_block(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        block = p.system_prompt_block()
        assert "RetainDB Memory" in block
        assert "Active" in block
        p.shutdown()

    def test_handle_tool_call_not_initialized(self):
        p = RetainDBMemoryProvider()
        result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "error" in result
        assert "not initialized" in result["error"]

    def test_handle_tool_call_unknown_tool(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_nonexistent", {}))
        assert result == {"error": "Unknown tool: retaindb_nonexistent"}
        p.shutdown()

    def test_dispatch_profile(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "memories" in result
        p.shutdown()

    def test_dispatch_search_requires_query(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_search", {}))
        assert result == {"error": "query is required"}
        p.shutdown()

    def test_dispatch_search(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "search", return_value={"results": [{"content": "found"}]}):
            result = json.loads(p.handle_tool_call("retaindb_search", {"query": "test"}))
        assert "results" in result
        p.shutdown()

    def test_dispatch_search_top_k_capped(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "search") as mock_search:
            mock_search.return_value = {"results": []}
            p.handle_tool_call("retaindb_search", {"query": "test", "top_k": 100})
            # top_k should be capped at 20
            assert mock_search.call_args[1]["top_k"] == 20
        p.shutdown()

    def test_dispatch_remember(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}):
            result = json.loads(p.handle_tool_call("retaindb_remember", {"content": "test fact"}))
        assert result["id"] == "mem-1"
        p.shutdown()

    def test_dispatch_remember_requires_content(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_remember", {}))
        assert result == {"error": "content is required"}
        p.shutdown()

    def test_dispatch_forget(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "delete_memory", return_value={"deleted": True}):
            result = json.loads(p.handle_tool_call("retaindb_forget", {"memory_id": "mem-1"}))
        assert result["deleted"] is True
        p.shutdown()

    def test_dispatch_forget_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_forget", {}))
        assert result == {"error": "memory_id is required"}
        p.shutdown()

    def test_dispatch_context(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "query_context", return_value={"results": [{"content": "relevant"}]}), \
                patch.object(p._client, "get_profile", return_value={"memories": []}):
            result = json.loads(p.handle_tool_call("retaindb_context", {"query": "current task"}))
        assert "context" in result
        assert "raw" in result
        p.shutdown()

    def test_dispatch_file_list(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "list_files", return_value={"files": []}):
            result = json.loads(p.handle_tool_call("retaindb_list_files", {}))
        assert "files" in result
        p.shutdown()

    def test_dispatch_file_upload_missing_path(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {}))
        assert "error" in result
        # Release the write queue / background thread like every other test.
        p.shutdown()

    def test_dispatch_file_upload_not_found(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_upload_file", {"local_path": "/nonexistent/file.txt"}))
        assert "File not found" in result["error"]
        p.shutdown()

    def test_dispatch_file_read_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_read_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_ingest_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_ingest_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_dispatch_file_delete_requires_id(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        result = json.loads(p.handle_tool_call("retaindb_delete_file", {}))
        assert result == {"error": "file_id is required"}
        p.shutdown()

    def test_handle_tool_call_wraps_exception(self, tmp_path, monkeypatch):
        p = self._make_provider(tmp_path, monkeypatch)
        p.initialize("test-session", hermes_home=str(tmp_path / ".hermes"))
        with patch.object(p._client, "get_profile", side_effect=RuntimeError("API exploded")):
            result = json.loads(p.handle_tool_call("retaindb_profile", {}))
        assert "API exploded" in result["error"]
        p.shutdown()


# ===========================================================================
# Prefetch and thread management tests
# ===========================================================================

class TestPrefetch:
    """Test background prefetch and thread accumulation prevention."""

    def test_queue_prefetch_skips_without_client(self):
        p = RetainDBMemoryProvider()
        p.queue_prefetch("test")  # Should not raise

    def test_prefetch_returns_empty_when_nothing_cached(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        result = p.prefetch("test")
        assert result == ""
        p.shutdown()

    def test_prefetch_consumes_context_result(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        # Manually set the cached result
        with p._lock:
            p._context_result = "[RetainDB Context]\nProfile:\n- User likes tests"
        result = p.prefetch("test")
        assert "User likes tests" in result
        # Should be consumed
        assert p.prefetch("test") == ""
        p.shutdown()

    def test_prefetch_consumes_dialectic_result(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with p._lock:
            p._dialectic_result = "User is a software engineer who prefers Python."
        result = p.prefetch("test")
        assert "[RetainDB User Synthesis]" in result
        assert "software engineer" in result
        p.shutdown()

    def test_prefetch_consumes_agent_model(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with p._lock:
            p._agent_model = {
                "memory_count": 5,
                "persona": "Helpful coding assistant",
                "persistent_instructions": ["Be concise", "Use Python"],
                "working_style": "Direct and efficient",
            }
        result = p.prefetch("test")
        assert "[RetainDB Agent Self-Model]" in result
        assert "Helpful coding assistant" in result
        assert "Be concise" in result
        assert "Direct and efficient" in result
        p.shutdown()

    def test_prefetch_skips_empty_agent_model(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with p._lock:
            p._agent_model = {"memory_count": 0}
        result = p.prefetch("test")
        assert "Agent Self-Model" not in result
        p.shutdown()

    def test_thread_accumulation_guard(self, tmp_path, monkeypatch):
        """Verify old prefetch threads are joined before new ones spawn."""
        p = _make_initialized_provider(tmp_path, monkeypatch)
        # Mock the prefetch methods to be slow
        with patch.object(p, "_prefetch_context", side_effect=lambda q: time.sleep(0.5)), \
                patch.object(p, "_prefetch_dialectic", side_effect=lambda q: time.sleep(0.5)), \
                patch.object(p, "_prefetch_agent_model", side_effect=lambda: time.sleep(0.5)):
            p.queue_prefetch("query 1")
            first_threads = list(p._prefetch_threads)
            assert len(first_threads) == 3

            # Call again — should join first batch before spawning new
            p.queue_prefetch("query 2")
            second_threads = list(p._prefetch_threads)
            assert len(second_threads) == 3
            # Should be different thread objects
            for t in second_threads:
                assert t not in first_threads
        p.shutdown()

    def test_reasoning_level_short(self):
        assert RetainDBMemoryProvider._reasoning_level("hi") == "low"

    def test_reasoning_level_medium(self):
        assert RetainDBMemoryProvider._reasoning_level("x" * 200) == "medium"

    def test_reasoning_level_long(self):
        assert RetainDBMemoryProvider._reasoning_level("x" * 500) == "high"


# ===========================================================================
# sync_turn tests
# ===========================================================================

class TestSyncTurn:
    """Test turn synchronization via the write queue."""

    def test_sync_turn_enqueues(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._queue, "enqueue") as mock_enqueue:
            p.sync_turn("user msg", "assistant msg")
            mock_enqueue.assert_called_once()
            args = mock_enqueue.call_args[0]
            assert args[0] == "default"  # user_id
            assert args[1] == "test-session"  # session_id
            msgs = args[2]
            assert len(msgs) == 2
            assert msgs[0]["role"] == "user"
            assert msgs[1]["role"] == "assistant"
        p.shutdown()

    def test_sync_turn_skips_empty_user_content(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._queue, "enqueue") as mock_enqueue:
            p.sync_turn("", "assistant msg")
            mock_enqueue.assert_not_called()
        p.shutdown()


# ===========================================================================
# on_memory_write hook tests
# ===========================================================================

class TestOnMemoryWrite:
    """Test the built-in memory mirror hook."""

    def test_mirrors_add_action(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "user", "User prefers dark mode")
            mock_add.assert_called_once()
            assert mock_add.call_args[1]["memory_type"] == "preference"
        p.shutdown()

    def test_skips_non_add_action(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("remove", "user", "something")
            mock_add.assert_not_called()
        p.shutdown()

    def test_skips_empty_content(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory") as mock_add:
            p.on_memory_write("add", "user", "")
            mock_add.assert_not_called()
        p.shutdown()

    def test_memory_target_maps_to_type(self, tmp_path, monkeypatch):
        p = _make_initialized_provider(tmp_path, monkeypatch)
        with patch.object(p._client, "add_memory", return_value={"id": "mem-1"}) as mock_add:
            p.on_memory_write("add", "memory", "Some env fact")
            assert mock_add.call_args[1]["memory_type"] == "factual"
        p.shutdown()


# ===========================================================================
# register() test
# ===========================================================================

class TestRegister:
    def test_register_calls_register_memory_provider(self):
        from plugins.memory.retaindb import register
        ctx = MagicMock()
        register(ctx)
        ctx.register_memory_provider.assert_called_once()
        arg = ctx.register_memory_provider.call_args[0][0]
        assert isinstance(arg, RetainDBMemoryProvider)