/ acp_adapter / session.py
session.py
  1  """ACP session manager — maps ACP sessions to Hermes AIAgent instances.
  2  
  3  Sessions are persisted to the shared SessionDB (``~/.hermes/state.db``) so they
  4  survive process restarts and appear in ``session_search``.  When the editor
  5  reconnects after idle/restart, the ``load_session`` / ``resume_session`` calls
  6  find the persisted session in the database and restore the full conversation
  7  history.
  8  """
  9  from __future__ import annotations
 10  
 11  from hermes_constants import get_hermes_home
 12  
 13  import copy
 14  import json
 15  import logging
 16  import os
 17  import re
 18  import sys
 19  import time
 20  import uuid
 21  from datetime import datetime, timezone
 22  from dataclasses import dataclass, field
 23  from threading import Lock
 24  from typing import Any, Dict, List, Optional
 25  
 26  logger = logging.getLogger(__name__)
 27  
 28  
 29  def _win_path_to_wsl(path: str) -> str | None:
 30      """Convert a Windows drive path to its WSL /mnt/<drive>/... equivalent."""
 31      match = re.match(r"^([A-Za-z]):[\\/](.*)$", path)
 32      if not match:
 33          return None
 34      drive = match.group(1).lower()
 35      tail = match.group(2).replace("\\", "/")
 36      return f"/mnt/{drive}/{tail}"
 37  
 38  
 39  def _translate_acp_cwd(cwd: str) -> str:
 40      """Translate Windows ACP cwd values when Hermes itself is running in WSL.
 41  
 42      Windows ACP clients can launch ``hermes acp`` inside WSL while still sending
 43      editor workspaces as Windows drive paths such as ``E:\\Projects``. Store
 44      and execute against the WSL mount path so agents, tools, and persisted ACP
 45      sessions all agree on the usable workspace. Native Linux/macOS keeps the
 46      original cwd unchanged.
 47      """
 48      from hermes_constants import is_wsl
 49  
 50      if not is_wsl():
 51          return cwd
 52      translated = _win_path_to_wsl(str(cwd))
 53      return translated if translated is not None else cwd
 54  
 55  
 56  def _normalize_cwd_for_compare(cwd: str | None) -> str:
 57      raw = str(cwd or ".").strip()
 58      if not raw:
 59          raw = "."
 60      expanded = os.path.expanduser(raw)
 61  
 62      # Normalize Windows drive paths into the equivalent WSL mount form so
 63      # ACP history filters match the same workspace across Windows and WSL.
 64      translated = _win_path_to_wsl(expanded)
 65      if translated is not None:
 66          expanded = translated
 67      elif re.match(r"^/mnt/[A-Za-z]/", expanded):
 68          expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}"
 69  
 70      return os.path.normpath(expanded)
 71  
 72  
 73  def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str:
 74      explicit = str(title or "").strip()
 75      if explicit:
 76          return explicit
 77      preview_text = str(preview or "").strip()
 78      if preview_text:
 79          return preview_text
 80      leaf = os.path.basename(str(cwd or "").rstrip("/\\"))
 81      return leaf or "New thread"
 82  
 83  
 84  def _format_updated_at(value: Any) -> str | None:
 85      if value is None:
 86          return None
 87      if isinstance(value, str) and value.strip():
 88          return value
 89      try:
 90          return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat()
 91      except Exception:
 92          return None
 93  
 94  
 95  def _updated_at_sort_key(value: Any) -> float:
 96      if value is None:
 97          return float("-inf")
 98      if isinstance(value, (int, float)):
 99          return float(value)
100      raw = str(value).strip()
101      if not raw:
102          return float("-inf")
103      try:
104          return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
105      except Exception:
106          try:
107              return float(raw)
108          except Exception:
109              return float("-inf")
110  
111  
112  def _acp_stderr_print(*args, **kwargs) -> None:
113      """Best-effort human-readable output sink for ACP stdio sessions.
114  
115      ACP reserves stdout for JSON-RPC frames, so any incidental CLI/status output
116      from AIAgent must be redirected away from stdout. Route it to stderr instead.
117      """
118      kwargs = dict(kwargs)
119      kwargs.setdefault("file", sys.stderr)
120      print(*args, **kwargs)
121  
122  
123  def _register_task_cwd(task_id: str, cwd: str) -> None:
124      """Bind a task/session id to the editor's working directory for tools.
125  
126      Zed can launch Hermes from a Windows workspace while the ACP process runs
127      inside WSL. In that case ACP sends cwd as e.g. ``E:\\Projects\\POTI``;
128      local tools need the WSL mount equivalent or subprocess creation fails
129      before the command can run.
130      """
131      if not task_id:
132          return
133      try:
134          from tools.terminal_tool import register_task_env_overrides
135          register_task_env_overrides(task_id, {"cwd": _translate_acp_cwd(cwd)})
136      except Exception:
137          logger.debug("Failed to register ACP task cwd override", exc_info=True)
138  
139  
140  def _expand_acp_enabled_toolsets(
141      toolsets: List[str] | None = None,
142      mcp_server_names: List[str] | None = None,
143  ) -> List[str]:
144      """Return ACP toolsets plus explicit MCP server toolsets for this session."""
145      expanded: List[str] = []
146      for name in list(toolsets or ["hermes-acp"]):
147          if name and name not in expanded:
148              expanded.append(name)
149  
150      for server_name in list(mcp_server_names or []):
151          toolset_name = f"mcp-{server_name}"
152          if server_name and toolset_name not in expanded:
153              expanded.append(toolset_name)
154  
155      return expanded
156  
157  
158  def _clear_task_cwd(task_id: str) -> None:
159      """Remove task-specific cwd overrides for an ACP session."""
160      if not task_id:
161          return
162      try:
163          from tools.terminal_tool import clear_task_env_overrides
164          clear_task_env_overrides(task_id)
165      except Exception:
166          logger.debug("Failed to clear ACP task cwd override", exc_info=True)
167  
168  
@dataclass
class SessionState:
    """Tracks per-session state for an ACP-managed Hermes agent.

    ``SessionManager`` builds one of these per session (on create, fork, and
    restore-from-database) and keeps it in its in-memory session map.
    """

    # Session identifier; SessionManager generates a uuid4 string for new sessions.
    session_id: str
    agent: Any  # AIAgent instance
    # Working directory associated with the session.
    cwd: str = "."
    # Model name; SessionManager fills it from the agent's ``model`` attribute
    # or from the persisted session row.
    model: str = ""
    # Conversation messages as dicts; SessionManager reads ``role``, ``content``,
    # ``tool_name``/``name``, ``tool_calls`` and ``tool_call_id`` when persisting.
    history: List[Dict[str, Any]] = field(default_factory=list)
    cancel_event: Any = None  # threading.Event
    # NOTE(review): the fields below are not touched in this file; presumably the
    # ACP server uses them to track an in-flight prompt — confirm against callers.
    is_running: bool = False
    queued_prompts: List[str] = field(default_factory=list)
    # A fresh ``threading.Lock`` is created for every instance.
    runtime_lock: Any = field(default_factory=Lock)
    current_prompt_text: str = ""
    interrupted_prompt_text: str = ""
184  
185  
class SessionManager:
    """Thread-safe manager for ACP sessions backed by Hermes AIAgent instances.

    Sessions are held in-memory for fast access **and** persisted to the
    shared SessionDB so they survive process restarts and are searchable
    via ``session_search``.

    The in-memory session map is guarded by an internal lock.  Database
    failures are logged and never propagate to callers of the public API.
    """

    def __init__(self, agent_factory=None, db=None):
        """
        Args:
            agent_factory: Optional callable that creates an AIAgent-like object.
                           Used by tests. When omitted, a real AIAgent is created
                           using the current Hermes runtime provider configuration.
            db:            Optional SessionDB instance. When omitted, the default
                           SessionDB (``~/.hermes/state.db``) is lazily created.
        """
        self._sessions: Dict[str, SessionState] = {}
        self._lock = Lock()
        self._agent_factory = agent_factory
        self._db_instance = db  # None → lazy-init on first use

    # ---- public API ---------------------------------------------------------

    def create_session(self, cwd: str = ".") -> SessionState:
        """Create a new session with a unique ID and a fresh AIAgent.

        Args:
            cwd: Editor working directory. Windows drive paths are translated
                 to their WSL mount form when Hermes runs inside WSL.

        Returns:
            The new ``SessionState``, already registered in memory and
            persisted to the database.
        """
        import threading

        cwd = _translate_acp_cwd(cwd)
        session_id = str(uuid.uuid4())
        agent = self._make_agent(session_id=session_id, cwd=cwd)
        state = SessionState(
            session_id=session_id,
            agent=agent,
            cwd=cwd,
            model=getattr(agent, "model", "") or "",
            cancel_event=threading.Event(),
        )
        with self._lock:
            self._sessions[session_id] = state
        _register_task_cwd(session_id, cwd)
        self._persist(state)
        logger.info("Created ACP session %s (cwd=%s)", session_id, cwd)
        return state

    def get_session(self, session_id: str) -> Optional[SessionState]:
        """Return the session for *session_id*, or ``None``.

        If the session is not in memory but exists in the database (e.g. after
        a process restart), it is transparently restored.
        """
        with self._lock:
            state = self._sessions.get(session_id)
        if state is not None:
            return state
        # Attempt to restore from database.
        return self._restore(session_id)

    def remove_session(self, session_id: str) -> bool:
        """Remove a session from memory and database. Returns True if it existed."""
        with self._lock:
            existed = self._sessions.pop(session_id, None) is not None
        db_existed = self._delete_persisted(session_id)
        if existed or db_existed:
            _clear_task_cwd(session_id)
        return existed or db_existed

    def fork_session(self, session_id: str, cwd: str = ".") -> Optional[SessionState]:
        """Deep-copy a session's history into a new session.

        Args:
            session_id: ID of the session to fork (looked up in memory, then DB).
            cwd:        Working directory for the fork; translated like
                        ``create_session``.

        Returns:
            The new ``SessionState``, or ``None`` when the original session
            cannot be found.
        """
        import threading

        cwd = _translate_acp_cwd(cwd)
        original = self.get_session(session_id)  # checks DB too
        if original is None:
            return None

        new_id = str(uuid.uuid4())
        agent = self._make_agent(
            session_id=new_id,
            cwd=cwd,
            model=original.model or None,
        )
        state = SessionState(
            session_id=new_id,
            agent=agent,
            cwd=cwd,
            model=getattr(agent, "model", original.model) or original.model,
            history=copy.deepcopy(original.history),
            cancel_event=threading.Event(),
        )
        with self._lock:
            self._sessions[new_id] = state
        _register_task_cwd(new_id, cwd)
        self._persist(state)
        logger.info("Forked ACP session %s -> %s", session_id, new_id)
        return state

    def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]:
        """Return lightweight info dicts for all sessions (memory + database).

        Sessions without any messages are omitted.  In-memory sessions take
        precedence over their persisted rows.

        Args:
            cwd: When given, only sessions whose normalized working directory
                 equals the normalized *cwd* are returned.

        Returns:
            Dicts with ``session_id``, ``cwd``, ``model``, ``history_len``,
            ``title`` and ``updated_at`` keys, most recently updated first.
        """
        normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None
        db = self._get_db()
        persisted_rows: dict[str, dict[str, Any]] = {}

        if db is not None:
            try:
                for row in db.list_sessions_rich(source="acp", limit=1000):
                    persisted_rows[str(row["id"])] = dict(row)
            except Exception:
                logger.debug("Failed to load ACP sessions from DB", exc_info=True)

        # Collect in-memory sessions first.
        with self._lock:
            seen_ids = set(self._sessions.keys())
            results = []
            for s in self._sessions.values():
                history_len = len(s.history)
                if history_len <= 0:
                    continue
                if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd:
                    continue
                persisted = persisted_rows.get(s.session_id, {})
                # Preview is the first non-blank user message, falling back
                # to whatever preview the database row carries.
                preview = next(
                    (
                        str(msg.get("content") or "").strip()
                        for msg in s.history
                        if msg.get("role") == "user" and str(msg.get("content") or "").strip()
                    ),
                    persisted.get("preview") or "",
                )
                results.append(
                    {
                        "session_id": s.session_id,
                        "cwd": s.cwd,
                        "model": s.model,
                        "history_len": history_len,
                        "title": _build_session_title(persisted.get("title"), preview, s.cwd),
                        "updated_at": _format_updated_at(
                            persisted.get("last_active") or persisted.get("started_at") or time.time()
                        ),
                    }
                )

        # Merge any persisted sessions not currently in memory.
        for sid, row in persisted_rows.items():
            if sid in seen_ids:
                continue
            message_count = int(row.get("message_count") or 0)
            if message_count <= 0:
                continue
            # Extract cwd from model_config JSON; malformed or non-object
            # payloads fall back to "." instead of aborting the listing.
            session_cwd = self._load_session_meta(row.get("model_config")).get("cwd", ".")
            if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd:
                continue
            results.append({
                "session_id": sid,
                "cwd": session_cwd,
                "model": row.get("model") or "",
                "history_len": message_count,
                "title": _build_session_title(row.get("title"), row.get("preview"), session_cwd),
                "updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")),
            })

        results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True)
        return results

    def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]:
        """Update the working directory for a session and its tool overrides.

        Returns the updated ``SessionState``, or ``None`` when the session is
        unknown.
        """
        cwd = _translate_acp_cwd(cwd)
        state = self.get_session(session_id)  # checks DB too
        if state is None:
            return None
        state.cwd = cwd
        _register_task_cwd(session_id, cwd)
        self._persist(state)
        return state

    def cleanup(self) -> None:
        """Remove all sessions (memory and database) and clear task-specific cwd overrides."""
        with self._lock:
            session_ids = list(self._sessions.keys())
            self._sessions.clear()
        for session_id in session_ids:
            _clear_task_cwd(session_id)
            self._delete_persisted(session_id)
        # Also remove any DB-only ACP sessions not currently in memory.
        db = self._get_db()
        if db is not None:
            try:
                rows = db.search_sessions(source="acp", limit=10000)
                for row in rows:
                    sid = row["id"]
                    _clear_task_cwd(sid)
                    db.delete_session(sid)
            except Exception:
                logger.debug("Failed to cleanup ACP sessions from DB", exc_info=True)

    def save_session(self, session_id: str) -> None:
        """Persist the current state of a session to the database.

        Called by the server after prompt completion, slash commands that
        mutate history, and model switches.  Sessions that are not in memory
        are ignored.
        """
        with self._lock:
            state = self._sessions.get(session_id)
        if state is not None:
            self._persist(state)

    # ---- persistence via SessionDB ------------------------------------------

    @staticmethod
    def _load_session_meta(raw: Any) -> Dict[str, Any]:
        """Decode a stored ``model_config`` value into a dict.

        Returns an empty dict when *raw* is empty, is not valid JSON, or
        decodes to something other than a JSON object.
        """
        if not raw:
            return {}
        try:
            meta = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}
        return meta if isinstance(meta, dict) else {}

    def _get_db(self):
        """Lazily initialise and return the SessionDB instance.

        Returns ``None`` if the DB is unavailable (e.g. import error in a
        minimal test environment).

        Note: we resolve ``HERMES_HOME`` dynamically rather than relying on
        the module-level ``DEFAULT_DB_PATH`` constant, because that constant
        is evaluated at import time and won't reflect env-var changes made
        later (e.g. by the test fixture ``_isolate_hermes_home``).
        """
        if self._db_instance is not None:
            return self._db_instance
        try:
            from hermes_state import SessionDB
            hermes_home = get_hermes_home()
            self._db_instance = SessionDB(db_path=hermes_home / "state.db")
            return self._db_instance
        except Exception:
            logger.debug("SessionDB unavailable for ACP persistence", exc_info=True)
            return None

    def _persist(self, state: SessionState) -> None:
        """Write session state to the database.

        Creates the session record if it doesn't exist, then replaces all
        stored messages with the current in-memory history.  The session
        metadata (cwd plus the agent's provider, base URL and API mode when
        they are non-blank strings) is stored in ``model_config`` on both the
        create and the update path, so a restore sees the same routing
        information regardless of how many times the session was saved.

        Database errors are logged and swallowed.
        """
        db = self._get_db()
        if db is None:
            return

        # Ensure model is a plain string (not a MagicMock or other proxy).
        model_str = str(state.model) if state.model else None
        session_meta = {"cwd": state.cwd}
        for key in ("provider", "base_url", "api_mode"):
            value = getattr(state.agent, key, None)
            if isinstance(value, str) and value.strip():
                session_meta[key] = value.strip()
        meta_json = json.dumps(session_meta)

        try:
            # Ensure the session record exists.
            existing = db.get_session(state.session_id)
            if existing is None:
                db.create_session(
                    session_id=state.session_id,
                    source="acp",
                    model=model_str,
                    # Store the full metadata, not just cwd, so provider
                    # details survive a restart before the next save.
                    model_config=session_meta,
                )
            else:
                # Update model_config (contains cwd) if changed.  This goes
                # through SessionDB's private connection and lock.
                try:
                    with db._lock:
                        db._conn.execute(
                            "UPDATE sessions SET model_config = ?, model = COALESCE(?, model) WHERE id = ?",
                            (meta_json, model_str, state.session_id),
                        )
                        db._conn.commit()
                except Exception:
                    logger.debug("Failed to update ACP session metadata", exc_info=True)

            # Replace stored messages with current history.
            db.clear_messages(state.session_id)
            for msg in state.history:
                db.append_message(
                    session_id=state.session_id,
                    role=msg.get("role", "user"),
                    content=msg.get("content"),
                    tool_name=msg.get("tool_name") or msg.get("name"),
                    tool_calls=msg.get("tool_calls"),
                    tool_call_id=msg.get("tool_call_id"),
                )
        except Exception:
            logger.warning("Failed to persist ACP session %s", state.session_id, exc_info=True)

    def _restore(self, session_id: str) -> Optional[SessionState]:
        """Load a session from the database into memory, recreating the AIAgent.

        Returns ``None`` when the database is unavailable, the row is missing
        or is not an ACP session, or the agent cannot be recreated.  A failure
        to load messages is logged and results in an empty history.
        """
        import threading

        db = self._get_db()
        if db is None:
            return None

        try:
            row = db.get_session(session_id)
        except Exception:
            logger.debug("Failed to query DB for ACP session %s", session_id, exc_info=True)
            return None

        if row is None:
            return None

        # Only restore ACP sessions.
        if row.get("source") != "acp":
            return None

        # Metadata stored in model_config wins over the billing_* columns.
        meta = self._load_session_meta(row.get("model_config"))
        cwd = meta.get("cwd", ".")
        requested_provider = meta.get("provider") or row.get("billing_provider")
        restored_base_url = meta.get("base_url") or row.get("billing_base_url")
        restored_api_mode = meta.get("api_mode") or None

        model = row.get("model") or None

        # Load conversation history.
        try:
            history = db.get_messages_as_conversation(session_id)
        except Exception:
            logger.warning("Failed to load messages for ACP session %s", session_id, exc_info=True)
            history = []

        try:
            agent = self._make_agent(
                session_id=session_id,
                cwd=cwd,
                model=model,
                requested_provider=requested_provider,
                base_url=restored_base_url,
                api_mode=restored_api_mode,
            )
        except Exception:
            logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True)
            return None

        state = SessionState(
            session_id=session_id,
            agent=agent,
            cwd=cwd,
            model=model or getattr(agent, "model", "") or "",
            history=history,
            cancel_event=threading.Event(),
        )
        with self._lock:
            self._sessions[session_id] = state
        _register_task_cwd(session_id, cwd)
        logger.info("Restored ACP session %s from DB (%d messages)", session_id, len(history))
        return state

    def _delete_persisted(self, session_id: str) -> bool:
        """Delete a session from the database. Returns True if it existed."""
        db = self._get_db()
        if db is None:
            return False
        try:
            return db.delete_session(session_id)
        except Exception:
            logger.debug("Failed to delete ACP session %s from DB", session_id, exc_info=True)
            return False

    # ---- internal -----------------------------------------------------------

    def _make_agent(
        self,
        *,
        session_id: str,
        cwd: str,
        model: str | None = None,
        requested_provider: str | None = None,
        base_url: str | None = None,
        api_mode: str | None = None,
    ):
        """Build the agent for a session.

        When an ``agent_factory`` was supplied it is called with no arguments
        and every parameter here is ignored.  Otherwise a real ``AIAgent`` is
        constructed from the Hermes config.

        Args:
            session_id:         Session ID handed to the agent and used to
                                register the cwd override.
            cwd:                Working directory registered for the session's tools.
            model:              Model name; defaults to the configured default model.
            requested_provider: Provider to resolve; defaults to the configured provider.
            base_url:           Overrides the resolved provider's base URL when set.
            api_mode:           Overrides the resolved provider's API mode when set.

        Returns:
            The agent instance, with its output routed to stderr.
        """
        if self._agent_factory is not None:
            return self._agent_factory()

        from run_agent import AIAgent
        from hermes_cli.config import load_config
        from hermes_cli.runtime_provider import resolve_runtime_provider

        config = load_config()
        model_cfg = config.get("model")
        default_model = ""
        config_provider = None
        if isinstance(model_cfg, dict):
            default_model = str(model_cfg.get("default") or default_model)
            config_provider = model_cfg.get("provider")
        elif isinstance(model_cfg, str) and model_cfg.strip():
            default_model = model_cfg.strip()

        # A server counts as enabled unless its config dict says enabled: False.
        configured_mcp_servers = [
            name
            for name, cfg in (config.get("mcp_servers") or {}).items()
            if not isinstance(cfg, dict) or cfg.get("enabled", True) is not False
        ]

        kwargs = {
            "platform": "acp",
            "enabled_toolsets": _expand_acp_enabled_toolsets(
                ["hermes-acp"],
                mcp_server_names=configured_mcp_servers,
            ),
            "quiet_mode": True,
            "session_id": session_id,
            "model": model or default_model,
        }

        try:
            runtime = resolve_runtime_provider(requested=requested_provider or config_provider)
            kwargs.update(
                {
                    "provider": runtime.get("provider"),
                    "api_mode": api_mode or runtime.get("api_mode"),
                    "base_url": base_url or runtime.get("base_url"),
                    "api_key": runtime.get("api_key"),
                    "command": runtime.get("command"),
                    "args": list(runtime.get("args") or []),
                }
            )
        except Exception:
            logger.debug("ACP session falling back to default provider resolution", exc_info=True)

        _register_task_cwd(session_id, cwd)
        agent = AIAgent(**kwargs)
        # ACP stdio transport requires stdout to remain protocol-only JSON-RPC.
        # Route any incidental human-readable agent output to stderr instead.
        agent._print_fn = _acp_stderr_print
        return agent