memory_bank.py
1 """Project-wide memory bank: shared, compressed conversation context. 2 3 When a project has `memory_bank_enabled=True` in its options, every chat 4 conversation contributes a short LLM-generated summary to the project's 5 shared memory bank. The bank is then injected into the system prompt of 6 every subsequent chat in that project, giving the agent context across 7 users and across sessions. 8 9 Design notes: 10 11 - **Source of truth** is `OutputDatabase` (one row per inference). It is 12 authoritative across multi-worker deployments, survives Redis TTLs, and 13 stays bound to a `project_id`. Chat sessions in Redis are *not* used 14 here — they aren't enumerable by project and would require a SCAN. 15 - **Summaries are produced by the System LLM** (the global setting also 16 used by Smart Search / Prompt AI). When no System LLM is configured the 17 cron is a no-op. 18 - **Compression ladder**: conversation → day → week → month. The cron 19 rolls older entries up to coarser granularities until the rendered 20 block fits within `memory_bank_max_tokens`. Entries that don't fit 21 even after the coarsest rollup are dropped (oldest first). 22 - **Privacy**: every project member sees summaries derived from every 23 other member's conversations — the project edit form surfaces a 24 disclaimer for this reason. 25 """ 26 from __future__ import annotations 27 28 import logging 29 from datetime import datetime, timedelta, timezone 30 from typing import Any, Iterable, Optional 31 32 from sqlalchemy import func 33 34 from restai.models.databasemodels import ( 35 OutputDatabase, 36 ProjectDatabase, 37 ProjectMemoryBankEntryDatabase, 38 ) 39 from restai.tools import tokens_from_string 40 41 logger = logging.getLogger(__name__) 42 43 44 # Conversations idle for at least this long are considered "settled" and 45 # eligible for summarization. Avoids re-summarizing an active chat between 46 # every turn. 47 CONVERSATION_IDLE_MINUTES = 10 48 49 # Per-summary cap. Keeps any single LLM call deterministic in cost regardless 50 # of how chatty a conversation got. 51 MAX_TURNS_PER_SUMMARY = 40 52 53 # Token budget overrun headroom before the cron triggers compression. We 54 # don't compress on every cron tick — only when the bank is meaningfully 55 # over-budget, to avoid burning System LLM tokens on rolling up entries 56 # that fit within a small overshoot. 57 COMPRESSION_HEADROOM = 1.25 58 59 60 # --------------------------------------------------------------------- helpers 61 62 63 def _now() -> datetime: 64 return datetime.now(timezone.utc).replace(tzinfo=None) 65 66 67 def _day_key(dt: datetime) -> str: 68 return dt.strftime("%Y-%m-%d") 69 70 71 def _week_key(dt: datetime) -> str: 72 iso_year, iso_week, _ = dt.isocalendar() 73 return f"{iso_year:04d}-W{iso_week:02d}" 74 75 76 def _month_key(dt: datetime) -> str: 77 return dt.strftime("%Y-%m") 78 79 80 def _system_llm_complete(brain: Any, db: Any, prompt: str) -> Optional[str]: 81 """Run a one-shot completion via the System LLM. Returns None on any 82 failure so callers can degrade gracefully (skip this entry, not crash).""" 83 llm = brain.get_system_llm(db) 84 if llm is None: 85 return None 86 try: 87 result = llm.llm.complete(prompt) 88 text = result.text if hasattr(result, "text") else str(result) 89 return (text or "").strip() or None 90 except Exception as e: 91 logger.warning("memory_bank: System LLM completion failed: %s", e) 92 return None 93 94 95 # --------------------------------------------------------------------- summarize 96 97 98 _SUMMARY_INSTRUCTIONS = ( 99 "Summarize the conversation below into 1-3 short bullet points. " 100 "Capture: (a) the topic the user was working on, (b) any concrete " 101 "facts, names, IDs, or decisions that emerged, (c) any unresolved " 102 "questions or follow-ups. Skip pleasantries, system messages, and " 103 "tool reasoning. Output bullets only — no preamble, no headings, " 104 "no quoting the user verbatim. Keep it under 80 words." 105 ) 106 107 _DIGEST_INSTRUCTIONS = ( 108 "You are merging multiple conversation summaries into a single short " 109 "digest. Preserve the most important facts, names, IDs, decisions, " 110 "and outstanding questions across all summaries. Output 2-4 short " 111 "bullets — no preamble, no headings. Keep it under 120 words." 112 ) 113 114 115 def _format_messages_for_summary(rows: list[OutputDatabase]) -> str: 116 """Render OutputDatabase rows as a chat transcript for the summarizer.""" 117 lines = [] 118 for r in rows[:MAX_TURNS_PER_SUMMARY]: 119 q = (r.question or "").strip() 120 a = (r.answer or "").strip() 121 if q: 122 lines.append(f"User: {q}") 123 if a: 124 lines.append(f"Assistant: {a}") 125 return "\n".join(lines) 126 127 128 def summarize_conversation( 129 brain: Any, 130 db_wrapper: Any, 131 project_id: int, 132 chat_id: str, 133 ) -> Optional[ProjectMemoryBankEntryDatabase]: 134 """Pull this chat_id's history from OutputDatabase, summarize it via the 135 System LLM, and upsert a 'conversation' memory bank entry. Returns the 136 persisted row, or None when nothing was written (no rows / no LLM / 137 LLM failure).""" 138 sess = db_wrapper.db 139 rows = ( 140 sess.query(OutputDatabase) 141 .filter( 142 OutputDatabase.project_id == project_id, 143 OutputDatabase.chat_id == chat_id, 144 ) 145 .order_by(OutputDatabase.date.asc()) 146 .all() 147 ) 148 if not rows: 149 return None 150 151 transcript = _format_messages_for_summary(rows) 152 if not transcript.strip(): 153 return None 154 155 prompt = f"{_SUMMARY_INSTRUCTIONS}\n\n---\n{transcript}\n---" 156 summary = _system_llm_complete(brain, db_wrapper, prompt) 157 if not summary: 158 return None 159 160 last_at = rows[-1].date or _now() 161 period_key = _day_key(last_at) 162 token_count = tokens_from_string(summary) 163 164 existing = ( 165 sess.query(ProjectMemoryBankEntryDatabase) 166 .filter( 167 ProjectMemoryBankEntryDatabase.project_id == project_id, 168 ProjectMemoryBankEntryDatabase.granularity == "conversation", 169 ProjectMemoryBankEntryDatabase.chat_id == chat_id, 170 ) 171 .first() 172 ) 173 now = _now() 174 if existing is not None: 175 existing.summary = summary 176 existing.token_count = token_count 177 existing.source_message_count = len(rows) 178 existing.last_source_at = last_at 179 existing.period_key = period_key 180 existing.updated_at = now 181 sess.commit() 182 return existing 183 184 row = ProjectMemoryBankEntryDatabase( 185 project_id=project_id, 186 chat_id=chat_id, 187 granularity="conversation", 188 period_key=period_key, 189 summary=summary, 190 token_count=token_count, 191 source_message_count=len(rows), 192 last_source_at=last_at, 193 created_at=now, 194 updated_at=now, 195 ) 196 sess.add(row) 197 sess.commit() 198 return row 199 200 201 # --------------------------------------------------------------------- compress 202 203 204 def _digest_entries( 205 brain: Any, 206 db_wrapper: Any, 207 entries: list[ProjectMemoryBankEntryDatabase], 208 ) -> Optional[str]: 209 """Merge multiple existing summaries into a single coarser digest.""" 210 if not entries: 211 return None 212 parts = [] 213 for e in entries: 214 prefix = f"[{e.granularity}:{e.period_key or ''}]" if e.period_key else f"[{e.granularity}]" 215 parts.append(f"{prefix} {e.summary.strip()}") 216 blob = "\n".join(parts) 217 prompt = f"{_DIGEST_INSTRUCTIONS}\n\n---\n{blob}\n---" 218 return _system_llm_complete(brain, db_wrapper, prompt) 219 220 221 def _rollup( 222 brain: Any, 223 db_wrapper: Any, 224 project_id: int, 225 from_granularity: str, 226 to_granularity: str, 227 age_threshold: timedelta, 228 key_fn, 229 ) -> int: 230 """Group entries of `from_granularity` older than `age_threshold` by 231 `key_fn(last_source_at)` and replace each group with a single digest at 232 `to_granularity`. Returns the number of digests created.""" 233 sess = db_wrapper.db 234 cutoff = _now() - age_threshold 235 rows = ( 236 sess.query(ProjectMemoryBankEntryDatabase) 237 .filter( 238 ProjectMemoryBankEntryDatabase.project_id == project_id, 239 ProjectMemoryBankEntryDatabase.granularity == from_granularity, 240 ProjectMemoryBankEntryDatabase.last_source_at != None, 241 ProjectMemoryBankEntryDatabase.last_source_at < cutoff, 242 ) 243 .all() 244 ) 245 if not rows: 246 return 0 247 248 groups: dict[str, list[ProjectMemoryBankEntryDatabase]] = {} 249 for r in rows: 250 groups.setdefault(key_fn(r.last_source_at), []).append(r) 251 252 created = 0 253 for period_key, group in groups.items(): 254 # If this period already has a digest at the target granularity, 255 # fold the new group into it via a re-digest. 256 existing_digest = ( 257 sess.query(ProjectMemoryBankEntryDatabase) 258 .filter( 259 ProjectMemoryBankEntryDatabase.project_id == project_id, 260 ProjectMemoryBankEntryDatabase.granularity == to_granularity, 261 ProjectMemoryBankEntryDatabase.period_key == period_key, 262 ) 263 .first() 264 ) 265 merge_inputs = list(group) 266 if existing_digest is not None: 267 merge_inputs.append(existing_digest) 268 269 digest = _digest_entries(brain, db_wrapper, merge_inputs) 270 if not digest: 271 # System LLM unavailable / failed — skip this rollup; the cron 272 # will retry next tick. Don't delete sources, don't half-commit. 273 continue 274 275 last_source_at = max((r.last_source_at for r in group), default=_now()) 276 source_message_count = sum(r.source_message_count or 0 for r in group) 277 token_count = tokens_from_string(digest) 278 now = _now() 279 280 if existing_digest is not None: 281 existing_digest.summary = digest 282 existing_digest.token_count = token_count 283 existing_digest.source_message_count += source_message_count 284 existing_digest.last_source_at = max( 285 existing_digest.last_source_at or last_source_at, last_source_at 286 ) 287 existing_digest.updated_at = now 288 else: 289 sess.add(ProjectMemoryBankEntryDatabase( 290 project_id=project_id, 291 chat_id=None, 292 granularity=to_granularity, 293 period_key=period_key, 294 summary=digest, 295 token_count=token_count, 296 source_message_count=source_message_count, 297 last_source_at=last_source_at, 298 created_at=now, 299 updated_at=now, 300 )) 301 for r in group: 302 sess.delete(r) 303 created += 1 304 305 sess.commit() 306 return created 307 308 309 def compress_entries( 310 brain: Any, 311 db_wrapper: Any, 312 project_id: int, 313 max_tokens: int, 314 ) -> None: 315 """Run the rollup ladder until the project's bank fits within budget. 316 317 Order: conversation→day (>1d old), day→week (>7d old), week→month 318 (>30d old). If still over budget after the coarsest rollup, drop the 319 oldest entries until we fit. 320 """ 321 sess = db_wrapper.db 322 323 def total_tokens() -> int: 324 return int( 325 sess.query(func.coalesce(func.sum(ProjectMemoryBankEntryDatabase.token_count), 0)) 326 .filter(ProjectMemoryBankEntryDatabase.project_id == project_id) 327 .scalar() 328 or 0 329 ) 330 331 if total_tokens() <= max_tokens * COMPRESSION_HEADROOM: 332 return 333 334 _rollup(brain, db_wrapper, project_id, "conversation", "day", 335 timedelta(days=1), _day_key) 336 if total_tokens() <= max_tokens: 337 return 338 339 _rollup(brain, db_wrapper, project_id, "day", "week", 340 timedelta(days=7), _week_key) 341 if total_tokens() <= max_tokens: 342 return 343 344 _rollup(brain, db_wrapper, project_id, "week", "month", 345 timedelta(days=30), _month_key) 346 if total_tokens() <= max_tokens: 347 return 348 349 # Last resort: drop oldest entries until we're within budget. Use 350 # coalesce(last_source_at, created_at) so a row with a NULL last_source_at 351 # still has a deterministic ordering position across SQLite/Postgres/MySQL. 352 while total_tokens() > max_tokens: 353 oldest = ( 354 sess.query(ProjectMemoryBankEntryDatabase) 355 .filter(ProjectMemoryBankEntryDatabase.project_id == project_id) 356 .order_by( 357 func.coalesce( 358 ProjectMemoryBankEntryDatabase.last_source_at, 359 ProjectMemoryBankEntryDatabase.created_at, 360 ).asc() 361 ) 362 .first() 363 ) 364 if oldest is None: 365 return 366 sess.delete(oldest) 367 sess.commit() 368 369 370 # --------------------------------------------------------------------- render 371 372 373 _GRANULARITY_ORDER = ("conversation", "day", "week", "month") 374 _GRANULARITY_HEADERS = { 375 "conversation": "Recent conversations", 376 "day": "By day", 377 "week": "By week", 378 "month": "By month", 379 } 380 381 382 def render_for_prompt(db_wrapper: Any, project_id: int, max_tokens: int) -> str: 383 """Produce the memory bank block that gets prepended to the system prompt. 384 385 Returns an empty string when there are no entries (so callers can 386 cheaply check `if block:` before bothering to splice it in). 387 """ 388 sess = db_wrapper.db 389 rows = ( 390 sess.query(ProjectMemoryBankEntryDatabase) 391 .filter(ProjectMemoryBankEntryDatabase.project_id == project_id) 392 .all() 393 ) 394 if not rows: 395 return "" 396 397 by_gran: dict[str, list[ProjectMemoryBankEntryDatabase]] = {} 398 for r in rows: 399 by_gran.setdefault(r.granularity, []).append(r) 400 for entries in by_gran.values(): 401 entries.sort( 402 key=lambda e: e.last_source_at or e.updated_at or _now(), 403 reverse=True, 404 ) 405 406 lines = ["[Project Memory Bank — context aggregated from prior conversations in this project. Use only when directly relevant to the current request.]"] 407 used = tokens_from_string(lines[0]) 408 409 for gran in _GRANULARITY_ORDER: 410 entries = by_gran.get(gran, []) 411 if not entries: 412 continue 413 header = f"\n## {_GRANULARITY_HEADERS[gran]}" 414 header_tokens = tokens_from_string(header) 415 if used + header_tokens > max_tokens: 416 break 417 lines.append(header) 418 used += header_tokens 419 for e in entries: 420 label = e.period_key or (e.chat_id[:8] if e.chat_id else "") 421 body = e.summary.strip() 422 chunk = f"\n- ({label}) {body}" if label else f"\n- {body}" 423 chunk_tokens = tokens_from_string(chunk) 424 if used + chunk_tokens > max_tokens: 425 break 426 lines.append(chunk) 427 used += chunk_tokens 428 429 return "".join(lines).strip() 430 431 432 # --------------------------------------------------------------------- public 433 434 435 def list_enabled_projects(db_wrapper: Any) -> Iterable[ProjectDatabase]: 436 """Yield projects with memory_bank_enabled=True. Done by inspecting the 437 options JSON blob — there's no dedicated column.""" 438 import json 439 440 rows = ( 441 db_wrapper.db.query(ProjectDatabase) 442 .filter(ProjectDatabase.type == "agent") 443 .all() 444 ) 445 for proj in rows: 446 try: 447 opts = json.loads(proj.options) if proj.options else {} 448 except Exception: 449 continue 450 if opts.get("memory_bank_enabled"): 451 yield proj 452 453 454 def chat_ids_needing_refresh( 455 db_wrapper: Any, 456 project_id: int, 457 idle_minutes: int = CONVERSATION_IDLE_MINUTES, 458 ) -> list[str]: 459 """Return chat_ids that have new OutputDatabase rows since the last 460 summarization and are now idle (last activity older than `idle_minutes`). 461 """ 462 sess = db_wrapper.db 463 cutoff = _now() - timedelta(minutes=idle_minutes) 464 465 latest_per_chat = ( 466 sess.query( 467 OutputDatabase.chat_id, 468 func.max(OutputDatabase.date).label("latest"), 469 ) 470 .filter( 471 OutputDatabase.project_id == project_id, 472 OutputDatabase.chat_id.isnot(None), 473 ) 474 .group_by(OutputDatabase.chat_id) 475 .having(func.max(OutputDatabase.date) <= cutoff) 476 .all() 477 ) 478 479 out: list[str] = [] 480 for chat_id, latest in latest_per_chat: 481 if not chat_id: 482 continue 483 existing = ( 484 sess.query(ProjectMemoryBankEntryDatabase) 485 .filter( 486 ProjectMemoryBankEntryDatabase.project_id == project_id, 487 ProjectMemoryBankEntryDatabase.granularity == "conversation", 488 ProjectMemoryBankEntryDatabase.chat_id == chat_id, 489 ) 490 .first() 491 ) 492 if existing is None or (existing.last_source_at or datetime.min) < latest: 493 out.append(chat_id) 494 return out