# acp_adapter/session.py
1 """ACP session manager — maps ACP sessions to Hermes AIAgent instances. 2 3 Sessions are persisted to the shared SessionDB (``~/.hermes/state.db``) so they 4 survive process restarts and appear in ``session_search``. When the editor 5 reconnects after idle/restart, the ``load_session`` / ``resume_session`` calls 6 find the persisted session in the database and restore the full conversation 7 history. 8 """ 9 from __future__ import annotations 10 11 from hermes_constants import get_hermes_home 12 13 import copy 14 import json 15 import logging 16 import os 17 import re 18 import sys 19 import time 20 import uuid 21 from datetime import datetime, timezone 22 from dataclasses import dataclass, field 23 from threading import Lock 24 from typing import Any, Dict, List, Optional 25 26 logger = logging.getLogger(__name__) 27 28 29 def _win_path_to_wsl(path: str) -> str | None: 30 """Convert a Windows drive path to its WSL /mnt/<drive>/... equivalent.""" 31 match = re.match(r"^([A-Za-z]):[\\/](.*)$", path) 32 if not match: 33 return None 34 drive = match.group(1).lower() 35 tail = match.group(2).replace("\\", "/") 36 return f"/mnt/{drive}/{tail}" 37 38 39 def _translate_acp_cwd(cwd: str) -> str: 40 """Translate Windows ACP cwd values when Hermes itself is running in WSL. 41 42 Windows ACP clients can launch ``hermes acp`` inside WSL while still sending 43 editor workspaces as Windows drive paths such as ``E:\\Projects``. Store 44 and execute against the WSL mount path so agents, tools, and persisted ACP 45 sessions all agree on the usable workspace. Native Linux/macOS keeps the 46 original cwd unchanged. 47 """ 48 from hermes_constants import is_wsl 49 50 if not is_wsl(): 51 return cwd 52 translated = _win_path_to_wsl(str(cwd)) 53 return translated if translated is not None else cwd 54 55 56 def _normalize_cwd_for_compare(cwd: str | None) -> str: 57 raw = str(cwd or ".").strip() 58 if not raw: 59 raw = "." 
60 expanded = os.path.expanduser(raw) 61 62 # Normalize Windows drive paths into the equivalent WSL mount form so 63 # ACP history filters match the same workspace across Windows and WSL. 64 translated = _win_path_to_wsl(expanded) 65 if translated is not None: 66 expanded = translated 67 elif re.match(r"^/mnt/[A-Za-z]/", expanded): 68 expanded = f"/mnt/{expanded[5].lower()}/{expanded[7:]}" 69 70 return os.path.normpath(expanded) 71 72 73 def _build_session_title(title: Any, preview: Any, cwd: str | None) -> str: 74 explicit = str(title or "").strip() 75 if explicit: 76 return explicit 77 preview_text = str(preview or "").strip() 78 if preview_text: 79 return preview_text 80 leaf = os.path.basename(str(cwd or "").rstrip("/\\")) 81 return leaf or "New thread" 82 83 84 def _format_updated_at(value: Any) -> str | None: 85 if value is None: 86 return None 87 if isinstance(value, str) and value.strip(): 88 return value 89 try: 90 return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat() 91 except Exception: 92 return None 93 94 95 def _updated_at_sort_key(value: Any) -> float: 96 if value is None: 97 return float("-inf") 98 if isinstance(value, (int, float)): 99 return float(value) 100 raw = str(value).strip() 101 if not raw: 102 return float("-inf") 103 try: 104 return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp() 105 except Exception: 106 try: 107 return float(raw) 108 except Exception: 109 return float("-inf") 110 111 112 def _acp_stderr_print(*args, **kwargs) -> None: 113 """Best-effort human-readable output sink for ACP stdio sessions. 114 115 ACP reserves stdout for JSON-RPC frames, so any incidental CLI/status output 116 from AIAgent must be redirected away from stdout. Route it to stderr instead. 
117 """ 118 kwargs = dict(kwargs) 119 kwargs.setdefault("file", sys.stderr) 120 print(*args, **kwargs) 121 122 123 def _register_task_cwd(task_id: str, cwd: str) -> None: 124 """Bind a task/session id to the editor's working directory for tools. 125 126 Zed can launch Hermes from a Windows workspace while the ACP process runs 127 inside WSL. In that case ACP sends cwd as e.g. ``E:\\Projects\\POTI``; 128 local tools need the WSL mount equivalent or subprocess creation fails 129 before the command can run. 130 """ 131 if not task_id: 132 return 133 try: 134 from tools.terminal_tool import register_task_env_overrides 135 register_task_env_overrides(task_id, {"cwd": _translate_acp_cwd(cwd)}) 136 except Exception: 137 logger.debug("Failed to register ACP task cwd override", exc_info=True) 138 139 140 def _expand_acp_enabled_toolsets( 141 toolsets: List[str] | None = None, 142 mcp_server_names: List[str] | None = None, 143 ) -> List[str]: 144 """Return ACP toolsets plus explicit MCP server toolsets for this session.""" 145 expanded: List[str] = [] 146 for name in list(toolsets or ["hermes-acp"]): 147 if name and name not in expanded: 148 expanded.append(name) 149 150 for server_name in list(mcp_server_names or []): 151 toolset_name = f"mcp-{server_name}" 152 if server_name and toolset_name not in expanded: 153 expanded.append(toolset_name) 154 155 return expanded 156 157 158 def _clear_task_cwd(task_id: str) -> None: 159 """Remove task-specific cwd overrides for an ACP session.""" 160 if not task_id: 161 return 162 try: 163 from tools.terminal_tool import clear_task_env_overrides 164 clear_task_env_overrides(task_id) 165 except Exception: 166 logger.debug("Failed to clear ACP task cwd override", exc_info=True) 167 168 169 @dataclass 170 class SessionState: 171 """Tracks per-session state for an ACP-managed Hermes agent.""" 172 173 session_id: str 174 agent: Any # AIAgent instance 175 cwd: str = "." 
176 model: str = "" 177 history: List[Dict[str, Any]] = field(default_factory=list) 178 cancel_event: Any = None # threading.Event 179 is_running: bool = False 180 queued_prompts: List[str] = field(default_factory=list) 181 runtime_lock: Any = field(default_factory=Lock) 182 current_prompt_text: str = "" 183 interrupted_prompt_text: str = "" 184 185 186 class SessionManager: 187 """Thread-safe manager for ACP sessions backed by Hermes AIAgent instances. 188 189 Sessions are held in-memory for fast access **and** persisted to the 190 shared SessionDB so they survive process restarts and are searchable 191 via ``session_search``. 192 """ 193 194 def __init__(self, agent_factory=None, db=None): 195 """ 196 Args: 197 agent_factory: Optional callable that creates an AIAgent-like object. 198 Used by tests. When omitted, a real AIAgent is created 199 using the current Hermes runtime provider configuration. 200 db: Optional SessionDB instance. When omitted, the default 201 SessionDB (``~/.hermes/state.db``) is lazily created. 
202 """ 203 self._sessions: Dict[str, SessionState] = {} 204 self._lock = Lock() 205 self._agent_factory = agent_factory 206 self._db_instance = db # None → lazy-init on first use 207 208 # ---- public API --------------------------------------------------------- 209 210 def create_session(self, cwd: str = ".") -> SessionState: 211 """Create a new session with a unique ID and a fresh AIAgent.""" 212 import threading 213 214 cwd = _translate_acp_cwd(cwd) 215 session_id = str(uuid.uuid4()) 216 agent = self._make_agent(session_id=session_id, cwd=cwd) 217 state = SessionState( 218 session_id=session_id, 219 agent=agent, 220 cwd=cwd, 221 model=getattr(agent, "model", "") or "", 222 cancel_event=threading.Event(), 223 ) 224 with self._lock: 225 self._sessions[session_id] = state 226 _register_task_cwd(session_id, cwd) 227 self._persist(state) 228 logger.info("Created ACP session %s (cwd=%s)", session_id, cwd) 229 return state 230 231 def get_session(self, session_id: str) -> Optional[SessionState]: 232 """Return the session for *session_id*, or ``None``. 233 234 If the session is not in memory but exists in the database (e.g. after 235 a process restart), it is transparently restored. 236 """ 237 with self._lock: 238 state = self._sessions.get(session_id) 239 if state is not None: 240 return state 241 # Attempt to restore from database. 242 return self._restore(session_id) 243 244 def remove_session(self, session_id: str) -> bool: 245 """Remove a session from memory and database. 
Returns True if it existed.""" 246 with self._lock: 247 existed = self._sessions.pop(session_id, None) is not None 248 db_existed = self._delete_persisted(session_id) 249 if existed or db_existed: 250 _clear_task_cwd(session_id) 251 return existed or db_existed 252 253 def fork_session(self, session_id: str, cwd: str = ".") -> Optional[SessionState]: 254 """Deep-copy a session's history into a new session.""" 255 import threading 256 257 cwd = _translate_acp_cwd(cwd) 258 original = self.get_session(session_id) # checks DB too 259 if original is None: 260 return None 261 262 new_id = str(uuid.uuid4()) 263 agent = self._make_agent( 264 session_id=new_id, 265 cwd=cwd, 266 model=original.model or None, 267 ) 268 state = SessionState( 269 session_id=new_id, 270 agent=agent, 271 cwd=cwd, 272 model=getattr(agent, "model", original.model) or original.model, 273 history=copy.deepcopy(original.history), 274 cancel_event=threading.Event(), 275 ) 276 with self._lock: 277 self._sessions[new_id] = state 278 _register_task_cwd(new_id, cwd) 279 self._persist(state) 280 logger.info("Forked ACP session %s -> %s", session_id, new_id) 281 return state 282 283 def list_sessions(self, cwd: str | None = None) -> List[Dict[str, Any]]: 284 """Return lightweight info dicts for all sessions (memory + database).""" 285 normalized_cwd = _normalize_cwd_for_compare(cwd) if cwd else None 286 db = self._get_db() 287 persisted_rows: dict[str, dict[str, Any]] = {} 288 289 if db is not None: 290 try: 291 for row in db.list_sessions_rich(source="acp", limit=1000): 292 persisted_rows[str(row["id"])] = dict(row) 293 except Exception: 294 logger.debug("Failed to load ACP sessions from DB", exc_info=True) 295 296 # Collect in-memory sessions first. 
297 with self._lock: 298 seen_ids = set(self._sessions.keys()) 299 results = [] 300 for s in self._sessions.values(): 301 history_len = len(s.history) 302 if history_len <= 0: 303 continue 304 if normalized_cwd and _normalize_cwd_for_compare(s.cwd) != normalized_cwd: 305 continue 306 persisted = persisted_rows.get(s.session_id, {}) 307 preview = next( 308 ( 309 str(msg.get("content") or "").strip() 310 for msg in s.history 311 if msg.get("role") == "user" and str(msg.get("content") or "").strip() 312 ), 313 persisted.get("preview") or "", 314 ) 315 results.append( 316 { 317 "session_id": s.session_id, 318 "cwd": s.cwd, 319 "model": s.model, 320 "history_len": history_len, 321 "title": _build_session_title(persisted.get("title"), preview, s.cwd), 322 "updated_at": _format_updated_at( 323 persisted.get("last_active") or persisted.get("started_at") or time.time() 324 ), 325 } 326 ) 327 328 # Merge any persisted sessions not currently in memory. 329 for sid, row in persisted_rows.items(): 330 if sid in seen_ids: 331 continue 332 message_count = int(row.get("message_count") or 0) 333 if message_count <= 0: 334 continue 335 # Extract cwd from model_config JSON. 336 session_cwd = "." 
337 mc = row.get("model_config") 338 if mc: 339 try: 340 session_cwd = json.loads(mc).get("cwd", ".") 341 except (json.JSONDecodeError, TypeError): 342 pass 343 if normalized_cwd and _normalize_cwd_for_compare(session_cwd) != normalized_cwd: 344 continue 345 results.append({ 346 "session_id": sid, 347 "cwd": session_cwd, 348 "model": row.get("model") or "", 349 "history_len": message_count, 350 "title": _build_session_title(row.get("title"), row.get("preview"), session_cwd), 351 "updated_at": _format_updated_at(row.get("last_active") or row.get("started_at")), 352 }) 353 354 results.sort(key=lambda item: _updated_at_sort_key(item.get("updated_at")), reverse=True) 355 return results 356 357 def update_cwd(self, session_id: str, cwd: str) -> Optional[SessionState]: 358 """Update the working directory for a session and its tool overrides.""" 359 cwd = _translate_acp_cwd(cwd) 360 state = self.get_session(session_id) # checks DB too 361 if state is None: 362 return None 363 state.cwd = cwd 364 _register_task_cwd(session_id, cwd) 365 self._persist(state) 366 return state 367 368 def cleanup(self) -> None: 369 """Remove all sessions (memory and database) and clear task-specific cwd overrides.""" 370 with self._lock: 371 session_ids = list(self._sessions.keys()) 372 self._sessions.clear() 373 for session_id in session_ids: 374 _clear_task_cwd(session_id) 375 self._delete_persisted(session_id) 376 # Also remove any DB-only ACP sessions not currently in memory. 377 db = self._get_db() 378 if db is not None: 379 try: 380 rows = db.search_sessions(source="acp", limit=10000) 381 for row in rows: 382 sid = row["id"] 383 _clear_task_cwd(sid) 384 db.delete_session(sid) 385 except Exception: 386 logger.debug("Failed to cleanup ACP sessions from DB", exc_info=True) 387 388 def save_session(self, session_id: str) -> None: 389 """Persist the current state of a session to the database. 
390 391 Called by the server after prompt completion, slash commands that 392 mutate history, and model switches. 393 """ 394 with self._lock: 395 state = self._sessions.get(session_id) 396 if state is not None: 397 self._persist(state) 398 399 # ---- persistence via SessionDB ------------------------------------------ 400 401 def _get_db(self): 402 """Lazily initialise and return the SessionDB instance. 403 404 Returns ``None`` if the DB is unavailable (e.g. import error in a 405 minimal test environment). 406 407 Note: we resolve ``HERMES_HOME`` dynamically rather than relying on 408 the module-level ``DEFAULT_DB_PATH`` constant, because that constant 409 is evaluated at import time and won't reflect env-var changes made 410 later (e.g. by the test fixture ``_isolate_hermes_home``). 411 """ 412 if self._db_instance is not None: 413 return self._db_instance 414 try: 415 from hermes_state import SessionDB 416 hermes_home = get_hermes_home() 417 self._db_instance = SessionDB(db_path=hermes_home / "state.db") 418 return self._db_instance 419 except Exception: 420 logger.debug("SessionDB unavailable for ACP persistence", exc_info=True) 421 return None 422 423 def _persist(self, state: SessionState) -> None: 424 """Write session state to the database. 425 426 Creates the session record if it doesn't exist, then replaces all 427 stored messages with the current in-memory history. 428 """ 429 db = self._get_db() 430 if db is None: 431 return 432 433 # Ensure model is a plain string (not a MagicMock or other proxy). 
434 model_str = str(state.model) if state.model else None 435 session_meta = {"cwd": state.cwd} 436 provider = getattr(state.agent, "provider", None) 437 base_url = getattr(state.agent, "base_url", None) 438 api_mode = getattr(state.agent, "api_mode", None) 439 if isinstance(provider, str) and provider.strip(): 440 session_meta["provider"] = provider.strip() 441 if isinstance(base_url, str) and base_url.strip(): 442 session_meta["base_url"] = base_url.strip() 443 if isinstance(api_mode, str) and api_mode.strip(): 444 session_meta["api_mode"] = api_mode.strip() 445 cwd_json = json.dumps(session_meta) 446 447 try: 448 # Ensure the session record exists. 449 existing = db.get_session(state.session_id) 450 if existing is None: 451 db.create_session( 452 session_id=state.session_id, 453 source="acp", 454 model=model_str, 455 model_config={"cwd": state.cwd}, 456 ) 457 else: 458 # Update model_config (contains cwd) if changed. 459 try: 460 with db._lock: 461 db._conn.execute( 462 "UPDATE sessions SET model_config = ?, model = COALESCE(?, model) WHERE id = ?", 463 (cwd_json, model_str, state.session_id), 464 ) 465 db._conn.commit() 466 except Exception: 467 logger.debug("Failed to update ACP session metadata", exc_info=True) 468 469 # Replace stored messages with current history. 
470 db.clear_messages(state.session_id) 471 for msg in state.history: 472 db.append_message( 473 session_id=state.session_id, 474 role=msg.get("role", "user"), 475 content=msg.get("content"), 476 tool_name=msg.get("tool_name") or msg.get("name"), 477 tool_calls=msg.get("tool_calls"), 478 tool_call_id=msg.get("tool_call_id"), 479 ) 480 except Exception: 481 logger.warning("Failed to persist ACP session %s", state.session_id, exc_info=True) 482 483 def _restore(self, session_id: str) -> Optional[SessionState]: 484 """Load a session from the database into memory, recreating the AIAgent.""" 485 import threading 486 487 db = self._get_db() 488 if db is None: 489 return None 490 491 try: 492 row = db.get_session(session_id) 493 except Exception: 494 logger.debug("Failed to query DB for ACP session %s", session_id, exc_info=True) 495 return None 496 497 if row is None: 498 return None 499 500 # Only restore ACP sessions. 501 if row.get("source") != "acp": 502 return None 503 504 # Extract cwd from model_config. 505 cwd = "." 506 requested_provider = row.get("billing_provider") 507 restored_base_url = row.get("billing_base_url") 508 restored_api_mode = None 509 mc = row.get("model_config") 510 if mc: 511 try: 512 meta = json.loads(mc) 513 if isinstance(meta, dict): 514 cwd = meta.get("cwd", ".") 515 requested_provider = meta.get("provider") or requested_provider 516 restored_base_url = meta.get("base_url") or restored_base_url 517 restored_api_mode = meta.get("api_mode") or restored_api_mode 518 except (json.JSONDecodeError, TypeError): 519 pass 520 521 model = row.get("model") or None 522 523 # Load conversation history. 
524 try: 525 history = db.get_messages_as_conversation(session_id) 526 except Exception: 527 logger.warning("Failed to load messages for ACP session %s", session_id, exc_info=True) 528 history = [] 529 530 try: 531 agent = self._make_agent( 532 session_id=session_id, 533 cwd=cwd, 534 model=model, 535 requested_provider=requested_provider, 536 base_url=restored_base_url, 537 api_mode=restored_api_mode, 538 ) 539 except Exception: 540 logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True) 541 return None 542 543 state = SessionState( 544 session_id=session_id, 545 agent=agent, 546 cwd=cwd, 547 model=model or getattr(agent, "model", "") or "", 548 history=history, 549 cancel_event=threading.Event(), 550 ) 551 with self._lock: 552 self._sessions[session_id] = state 553 _register_task_cwd(session_id, cwd) 554 logger.info("Restored ACP session %s from DB (%d messages)", session_id, len(history)) 555 return state 556 557 def _delete_persisted(self, session_id: str) -> bool: 558 """Delete a session from the database. 
Returns True if it existed.""" 559 db = self._get_db() 560 if db is None: 561 return False 562 try: 563 return db.delete_session(session_id) 564 except Exception: 565 logger.debug("Failed to delete ACP session %s from DB", session_id, exc_info=True) 566 return False 567 568 # ---- internal ----------------------------------------------------------- 569 570 def _make_agent( 571 self, 572 *, 573 session_id: str, 574 cwd: str, 575 model: str | None = None, 576 requested_provider: str | None = None, 577 base_url: str | None = None, 578 api_mode: str | None = None, 579 ): 580 if self._agent_factory is not None: 581 return self._agent_factory() 582 583 from run_agent import AIAgent 584 from hermes_cli.config import load_config 585 from hermes_cli.runtime_provider import resolve_runtime_provider 586 587 config = load_config() 588 model_cfg = config.get("model") 589 default_model = "" 590 config_provider = None 591 if isinstance(model_cfg, dict): 592 default_model = str(model_cfg.get("default") or default_model) 593 config_provider = model_cfg.get("provider") 594 elif isinstance(model_cfg, str) and model_cfg.strip(): 595 default_model = model_cfg.strip() 596 597 configured_mcp_servers = [ 598 name 599 for name, cfg in (config.get("mcp_servers") or {}).items() 600 if not isinstance(cfg, dict) or cfg.get("enabled", True) is not False 601 ] 602 603 kwargs = { 604 "platform": "acp", 605 "enabled_toolsets": _expand_acp_enabled_toolsets( 606 ["hermes-acp"], 607 mcp_server_names=configured_mcp_servers, 608 ), 609 "quiet_mode": True, 610 "session_id": session_id, 611 "model": model or default_model, 612 } 613 614 try: 615 runtime = resolve_runtime_provider(requested=requested_provider or config_provider) 616 kwargs.update( 617 { 618 "provider": runtime.get("provider"), 619 "api_mode": api_mode or runtime.get("api_mode"), 620 "base_url": base_url or runtime.get("base_url"), 621 "api_key": runtime.get("api_key"), 622 "command": runtime.get("command"), 623 "args": 
list(runtime.get("args") or []), 624 } 625 ) 626 except Exception: 627 logger.debug("ACP session falling back to default provider resolution", exc_info=True) 628 629 _register_task_cwd(session_id, cwd) 630 agent = AIAgent(**kwargs) 631 # ACP stdio transport requires stdout to remain protocol-only JSON-RPC. 632 # Route any incidental human-readable agent output to stderr instead. 633 agent._print_fn = _acp_stderr_print 634 return agent