  1  """Project-wide memory bank: shared, compressed conversation context.
  2  
  3  When a project has `memory_bank_enabled=True` in its options, every chat
  4  conversation contributes a short LLM-generated summary to the project's
  5  shared memory bank. The bank is then injected into the system prompt of
  6  every subsequent chat in that project, giving the agent context across
  7  users and across sessions.
  8  
  9  Design notes:
 10  
 11  - **Source of truth** is `OutputDatabase` (one row per inference). It is
 12    authoritative across multi-worker deployments, survives Redis TTLs, and
 13    stays bound to a `project_id`. Chat sessions in Redis are *not* used
 14    here — they aren't enumerable by project and would require a SCAN.
 15  - **Summaries are produced by the System LLM** (the global setting also
 16    used by Smart Search / Prompt AI). When no System LLM is configured the
 17    cron is a no-op.
 18  - **Compression ladder**: conversation → day → week → month. The cron
 19    rolls older entries up to coarser granularities until the rendered
 20    block fits within `memory_bank_max_tokens`. Entries that don't fit
 21    even after the coarsest rollup are dropped (oldest first).
 22  - **Privacy**: every project member sees summaries derived from every
 23    other member's conversations — the project edit form surfaces a
 24    disclaimer for this reason.
 25  """
from __future__ import annotations

import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Iterable, Optional

from sqlalchemy import func

from restai.models.databasemodels import (
    OutputDatabase,
    ProjectDatabase,
    ProjectMemoryBankEntryDatabase,
)
from restai.tools import tokens_from_string

logger = logging.getLogger(__name__)


# Conversations idle for at least this long are considered "settled" and
# eligible for summarization. Avoids re-summarizing an active chat between
# every turn.
CONVERSATION_IDLE_MINUTES = 10

# Per-summary cap. Keeps any single LLM call deterministic in cost regardless
# of how chatty a conversation got.
MAX_TURNS_PER_SUMMARY = 40

# Token budget overrun headroom before the cron triggers compression. We
# don't compress on every cron tick — only when the bank is meaningfully
# over-budget, to avoid burning System LLM tokens on rolling up entries
# that fit within a small overshoot.
COMPRESSION_HEADROOM = 1.25
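# For example (illustrative numbers): with memory_bank_max_tokens=2000,
# compression only kicks in once the bank's entries total more than
# 2000 * 1.25 = 2500 tokens; each ladder step in compress_entries() then
# targets the plain 2000-token budget.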


# --------------------------------------------------------------------- helpers


def _now() -> datetime:
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _day_key(dt: datetime) -> str:
    return dt.strftime("%Y-%m-%d")


def _week_key(dt: datetime) -> str:
    iso_year, iso_week, _ = dt.isocalendar()
    return f"{iso_year:04d}-W{iso_week:02d}"


def _month_key(dt: datetime) -> str:
    return dt.strftime("%Y-%m")
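

# Period-key examples (illustrative):
#   _day_key(datetime(2024, 1, 1))   -> "2024-01-01"
#   _week_key(datetime(2024, 1, 1))  -> "2024-W01"
#   _month_key(datetime(2024, 1, 1)) -> "2024-01"
# Note that the ISO year in _week_key can differ from the calendar year near
# year boundaries, e.g. _week_key(datetime(2025, 12, 29)) -> "2026-W01".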


def _system_llm_complete(brain: Any, db_wrapper: Any, prompt: str) -> Optional[str]:
    """Run a one-shot completion via the System LLM. Returns None on any
    failure so callers can degrade gracefully (skip this entry, not crash)."""
    llm = brain.get_system_llm(db_wrapper)
    if llm is None:
        return None
    try:
        result = llm.llm.complete(prompt)
        text = result.text if hasattr(result, "text") else str(result)
        return (text or "").strip() or None
    except Exception as e:
        logger.warning("memory_bank: System LLM completion failed: %s", e)
        return None
# --------------------------------------------------------------------- summarize


_SUMMARY_INSTRUCTIONS = (
    "Summarize the conversation below into 1-3 short bullet points. "
    "Capture: (a) the topic the user was working on, (b) any concrete "
    "facts, names, IDs, or decisions that emerged, (c) any unresolved "
    "questions or follow-ups. Skip pleasantries, system messages, and "
    "tool reasoning. Output bullets only — no preamble, no headings, "
    "no quoting the user verbatim. Keep it under 80 words."
)

_DIGEST_INSTRUCTIONS = (
    "You are merging multiple conversation summaries into a single short "
    "digest. Preserve the most important facts, names, IDs, decisions, "
    "and outstanding questions across all summaries. Output 2-4 short "
    "bullets — no preamble, no headings. Keep it under 120 words."
)

def _format_messages_for_summary(rows: list[OutputDatabase]) -> str:
    """Render OutputDatabase rows as a chat transcript for the summarizer."""
    lines = []
    for r in rows[:MAX_TURNS_PER_SUMMARY]:
        q = (r.question or "").strip()
        a = (r.answer or "").strip()
        if q:
            lines.append(f"User: {q}")
        if a:
            lines.append(f"Assistant: {a}")
    return "\n".join(lines)
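

# The summarizer therefore sees plain alternating lines, truncated at
# MAX_TURNS_PER_SUMMARY rows (turn text below is hypothetical):
#     User: how do I rotate the project's API key?
#     Assistant: Open the project settings, revoke the old key, then ...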


def summarize_conversation(
    brain: Any,
    db_wrapper: Any,
    project_id: int,
    chat_id: str,
) -> Optional[ProjectMemoryBankEntryDatabase]:
    """Pull this chat_id's history from OutputDatabase, summarize it via the
    System LLM, and upsert a 'conversation' memory bank entry. Returns the
    persisted row, or None when nothing was written (no rows / no LLM /
    LLM failure)."""
    sess = db_wrapper.db
    rows = (
        sess.query(OutputDatabase)
        .filter(
            OutputDatabase.project_id == project_id,
            OutputDatabase.chat_id == chat_id,
        )
        .order_by(OutputDatabase.date.asc())
        .all()
    )
    if not rows:
        return None

    transcript = _format_messages_for_summary(rows)
    if not transcript.strip():
        return None

    prompt = f"{_SUMMARY_INSTRUCTIONS}\n\n---\n{transcript}\n---"
    summary = _system_llm_complete(brain, db_wrapper, prompt)
    if not summary:
        return None

    last_at = rows[-1].date or _now()
    period_key = _day_key(last_at)
    token_count = tokens_from_string(summary)

    existing = (
        sess.query(ProjectMemoryBankEntryDatabase)
        .filter(
            ProjectMemoryBankEntryDatabase.project_id == project_id,
            ProjectMemoryBankEntryDatabase.granularity == "conversation",
            ProjectMemoryBankEntryDatabase.chat_id == chat_id,
        )
        .first()
    )
    now = _now()
    if existing is not None:
        existing.summary = summary
        existing.token_count = token_count
        existing.source_message_count = len(rows)
        existing.last_source_at = last_at
        existing.period_key = period_key
        existing.updated_at = now
        sess.commit()
        return existing

    row = ProjectMemoryBankEntryDatabase(
        project_id=project_id,
        chat_id=chat_id,
        granularity="conversation",
        period_key=period_key,
        summary=summary,
        token_count=token_count,
        source_message_count=len(rows),
        last_source_at=last_at,
        created_at=now,
        updated_at=now,
    )
    sess.add(row)
    sess.commit()
    return row

# --------------------------------------------------------------------- compress


def _digest_entries(
    brain: Any,
    db_wrapper: Any,
    entries: list[ProjectMemoryBankEntryDatabase],
) -> Optional[str]:
    """Merge multiple existing summaries into a single coarser digest."""
    if not entries:
        return None
    parts = []
    for e in entries:
        prefix = (
            f"[{e.granularity}:{e.period_key}]"
            if e.period_key
            else f"[{e.granularity}]"
        )
        parts.append(f"{prefix} {e.summary.strip()}")
    blob = "\n".join(parts)
    prompt = f"{_DIGEST_INSTRUCTIONS}\n\n---\n{blob}\n---"
    return _system_llm_complete(brain, db_wrapper, prompt)


def _rollup(
    brain: Any,
    db_wrapper: Any,
    project_id: int,
    from_granularity: str,
    to_granularity: str,
    age_threshold: timedelta,
    key_fn,
) -> int:
    """Group entries of `from_granularity` older than `age_threshold` by
    `key_fn(last_source_at)` and replace each group with a single digest at
    `to_granularity`. Returns the number of digests created."""
    sess = db_wrapper.db
    cutoff = _now() - age_threshold
    rows = (
        sess.query(ProjectMemoryBankEntryDatabase)
        .filter(
            ProjectMemoryBankEntryDatabase.project_id == project_id,
            ProjectMemoryBankEntryDatabase.granularity == from_granularity,
            ProjectMemoryBankEntryDatabase.last_source_at.isnot(None),
            ProjectMemoryBankEntryDatabase.last_source_at < cutoff,
        )
        .all()
    )
    if not rows:
        return 0

    groups: dict[str, list[ProjectMemoryBankEntryDatabase]] = {}
    for r in rows:
        groups.setdefault(key_fn(r.last_source_at), []).append(r)

    created = 0
    for period_key, group in groups.items():
        # If this period already has a digest at the target granularity,
        # fold the new group into it via a re-digest.
        existing_digest = (
            sess.query(ProjectMemoryBankEntryDatabase)
            .filter(
                ProjectMemoryBankEntryDatabase.project_id == project_id,
                ProjectMemoryBankEntryDatabase.granularity == to_granularity,
                ProjectMemoryBankEntryDatabase.period_key == period_key,
            )
            .first()
        )
        merge_inputs = list(group)
        if existing_digest is not None:
            merge_inputs.append(existing_digest)

        digest = _digest_entries(brain, db_wrapper, merge_inputs)
        if not digest:
            # System LLM unavailable / failed — skip this rollup; the cron
            # will retry next tick. Don't delete sources, don't half-commit.
            continue

        last_source_at = max((r.last_source_at for r in group), default=_now())
        source_message_count = sum(r.source_message_count or 0 for r in group)
        token_count = tokens_from_string(digest)
        now = _now()

        if existing_digest is not None:
            existing_digest.summary = digest
            existing_digest.token_count = token_count
            existing_digest.source_message_count = (
                existing_digest.source_message_count or 0
            ) + source_message_count
            existing_digest.last_source_at = max(
                existing_digest.last_source_at or last_source_at, last_source_at
            )
            existing_digest.updated_at = now
        else:
            sess.add(ProjectMemoryBankEntryDatabase(
                project_id=project_id,
                chat_id=None,
                granularity=to_granularity,
                period_key=period_key,
                summary=digest,
                token_count=token_count,
                source_message_count=source_message_count,
                last_source_at=last_source_at,
                created_at=now,
                updated_at=now,
            ))
        for r in group:
            sess.delete(r)
        created += 1

    sess.commit()
    return created
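

# Illustrative rollup (hypothetical data): three "conversation" entries whose
# last_source_at all fall on 2024-03-12 and are older than the age threshold
# are digested into a single "day" entry with period_key "2024-03-12"; the
# three source rows are deleted in the same commit. If a "2024-03-12" day
# digest already exists, it is re-digested together with the new group rather
# than duplicated.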


def compress_entries(
    brain: Any,
    db_wrapper: Any,
    project_id: int,
    max_tokens: int,
) -> None:
    """Run the rollup ladder until the project's bank fits within budget.

    Order: conversation→day (>1d old), day→week (>7d old), week→month
    (>30d old). If still over budget after the coarsest rollup, drop the
    oldest entries until we fit.
    """
    sess = db_wrapper.db

    def total_tokens() -> int:
        return int(
            sess.query(func.coalesce(func.sum(ProjectMemoryBankEntryDatabase.token_count), 0))
            .filter(ProjectMemoryBankEntryDatabase.project_id == project_id)
            .scalar()
            or 0
        )

    if total_tokens() <= max_tokens * COMPRESSION_HEADROOM:
        return

    _rollup(brain, db_wrapper, project_id, "conversation", "day",
            timedelta(days=1), _day_key)
    if total_tokens() <= max_tokens:
        return

    _rollup(brain, db_wrapper, project_id, "day", "week",
            timedelta(days=7), _week_key)
    if total_tokens() <= max_tokens:
        return

    _rollup(brain, db_wrapper, project_id, "week", "month",
            timedelta(days=30), _month_key)
    if total_tokens() <= max_tokens:
        return

    # Last resort: drop oldest entries until we're within budget. Use
    # coalesce(last_source_at, created_at) so a row with a NULL last_source_at
    # still has a deterministic ordering position across SQLite/Postgres/MySQL.
    while total_tokens() > max_tokens:
        oldest = (
            sess.query(ProjectMemoryBankEntryDatabase)
            .filter(ProjectMemoryBankEntryDatabase.project_id == project_id)
            .order_by(
                func.coalesce(
                    ProjectMemoryBankEntryDatabase.last_source_at,
                    ProjectMemoryBankEntryDatabase.created_at,
                ).asc()
            )
            .first()
        )
        if oldest is None:
            return
        sess.delete(oldest)
        sess.commit()


# --------------------------------------------------------------------- render


_GRANULARITY_ORDER = ("conversation", "day", "week", "month")
_GRANULARITY_HEADERS = {
    "conversation": "Recent conversations",
    "day": "By day",
    "week": "By week",
    "month": "By month",
}


def render_for_prompt(db_wrapper: Any, project_id: int, max_tokens: int) -> str:
    """Produce the memory bank block that gets prepended to the system prompt.

    Returns an empty string when there are no entries (so callers can
    cheaply check `if block:` before bothering to splice it in).
    """
    sess = db_wrapper.db
    rows = (
        sess.query(ProjectMemoryBankEntryDatabase)
        .filter(ProjectMemoryBankEntryDatabase.project_id == project_id)
        .all()
    )
    if not rows:
        return ""

    by_gran: dict[str, list[ProjectMemoryBankEntryDatabase]] = {}
    for r in rows:
        by_gran.setdefault(r.granularity, []).append(r)
    for entries in by_gran.values():
        entries.sort(
            key=lambda e: e.last_source_at or e.updated_at or _now(),
            reverse=True,
        )

    lines = [
        "[Project Memory Bank — context aggregated from prior conversations "
        "in this project. Use only when directly relevant to the current request.]"
    ]
    used = tokens_from_string(lines[0])

    for gran in _GRANULARITY_ORDER:
        entries = by_gran.get(gran, [])
        if not entries:
            continue
        header = f"\n## {_GRANULARITY_HEADERS[gran]}"
        header_tokens = tokens_from_string(header)
        if used + header_tokens > max_tokens:
            break
        lines.append(header)
        used += header_tokens
        for e in entries:
            label = e.period_key or (e.chat_id[:8] if e.chat_id else "")
            body = e.summary.strip()
            chunk = f"\n- ({label}) {body}" if label else f"\n- {body}"
            chunk_tokens = tokens_from_string(chunk)
            if used + chunk_tokens > max_tokens:
                break
            lines.append(chunk)
            used += chunk_tokens

    return "".join(lines).strip()
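

# Rendered shape, newest entries first within each granularity (summary text
# below is hypothetical; the structure follows the code above):
#
#     [Project Memory Bank — context aggregated from prior conversations ...]
#     ## Recent conversations
#     - (2024-03-12) User set up the billing webhook; endpoint URL confirmed.
#     ## By week
#     - (2024-W10) Mostly schema-migration questions; rollback plan agreed.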


# --------------------------------------------------------------------- public


def list_enabled_projects(db_wrapper: Any) -> Iterable[ProjectDatabase]:
    """Yield agent-type projects with memory_bank_enabled=True. Done by
    inspecting the options JSON blob — there's no dedicated column."""
    import json

    rows = (
        db_wrapper.db.query(ProjectDatabase)
        .filter(ProjectDatabase.type == "agent")
        .all()
    )
    for proj in rows:
        try:
            opts = json.loads(proj.options) if proj.options else {}
        except Exception:
            continue
        if opts.get("memory_bank_enabled"):
            yield proj


def chat_ids_needing_refresh(
    db_wrapper: Any,
    project_id: int,
    idle_minutes: int = CONVERSATION_IDLE_MINUTES,
) -> list[str]:
    """Return chat_ids that have new OutputDatabase rows since the last
    summarization and are now idle (last activity older than `idle_minutes`).
    """
    sess = db_wrapper.db
    cutoff = _now() - timedelta(minutes=idle_minutes)

    latest_per_chat = (
        sess.query(
            OutputDatabase.chat_id,
            func.max(OutputDatabase.date).label("latest"),
        )
        .filter(
            OutputDatabase.project_id == project_id,
            OutputDatabase.chat_id.isnot(None),
        )
        .group_by(OutputDatabase.chat_id)
        .having(func.max(OutputDatabase.date) <= cutoff)
        .all()
    )

    out: list[str] = []
    for chat_id, latest in latest_per_chat:
        if not chat_id:
            continue
        existing = (
            sess.query(ProjectMemoryBankEntryDatabase)
            .filter(
                ProjectMemoryBankEntryDatabase.project_id == project_id,
                ProjectMemoryBankEntryDatabase.granularity == "conversation",
                ProjectMemoryBankEntryDatabase.chat_id == chat_id,
            )
            .first()
        )
        if existing is None or (existing.last_source_at or datetime.min) < latest:
            out.append(chat_id)
    return out
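

# Typical cron flow (illustrative sketch only; the real scheduler wiring and
# the `brain` / `db_wrapper` objects live outside this module, and the
# `project.id` attribute and 2048-token fallback shown here are assumptions,
# not confirmed defaults):
#
#     import json
#
#     for project in list_enabled_projects(db_wrapper):
#         opts = json.loads(project.options or "{}")
#         max_tokens = opts.get("memory_bank_max_tokens", 2048)
#         for chat_id in chat_ids_needing_refresh(db_wrapper, project.id):
#             summarize_conversation(brain, db_wrapper, project.id, chat_id)
#         compress_entries(brain, db_wrapper, project.id, max_tokens)
#
#     # At chat time, the block is spliced in only when non-empty:
#     block = render_for_prompt(db_wrapper, project.id, max_tokens)
#     if block:
#         system_prompt = f"{block}\n\n{system_prompt}"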