session.py
import json
import os
import signal
import tempfile
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

from mlflow.assistant.types import Message

# Directory (under the system temp dir) that holds one "<session_id>.json" file per
# session plus an optional "<session_id>.process.json" PID file.
SESSION_DIR = Path(tempfile.gettempdir()) / "mlflow-assistant-sessions"


def _atomic_write_json(target: Path, payload: Any) -> None:
    """Write ``payload`` as JSON to ``target`` without exposing partial content.

    The data is written to a temporary file in the same directory and then
    renamed over ``target`` (the rename is atomic on POSIX), so a concurrent
    reader sees either the previous file or the complete new one.

    Args:
        target: Destination file path; its parent directory is created if missing
        payload: JSON-serializable data to write
    """
    directory = target.parent
    directory.mkdir(parents=True, exist_ok=True)
    # The temp file must live in the destination directory so that the final
    # rename never crosses a filesystem boundary.
    fd, temp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
        os.replace(temp_path, target)
    except BaseException:
        # Also clean up on KeyboardInterrupt/SystemExit. missing_ok keeps a
        # vanished temp file from masking the original error.
        Path(temp_path).unlink(missing_ok=True)
        raise


@dataclass
class Session:
    """Session state for assistant conversations.

    Attributes:
        context: Arbitrary key/value context data for the session
        messages: Message history, in the order messages were added
        pending_message: Message waiting to be processed, or None
        provider_session_id: Session identifier stored as an opaque string
            (presumably assigned by the assistant provider; confirm with callers)
        working_dir: Working directory for the session (e.g. project path)
    """

    context: dict[str, Any] = field(default_factory=dict)
    messages: list[Message] = field(default_factory=list)
    pending_message: Message | None = None
    provider_session_id: str | None = None
    working_dir: Path | None = None  # Working directory for the session (e.g. project path)

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the session history.

        Args:
            role: Role of the message sender (user, assistant, system)
            content: Text content of the message
        """
        self.messages.append(Message(role=role, content=content))

    def set_pending_message(self, role: str, content: str) -> None:
        """Set the pending message to be processed.

        Any previously pending message is replaced.

        Args:
            role: Role of the message sender
            content: Text content of the message
        """
        self.pending_message = Message(role=role, content=content)

    def clear_pending_message(self) -> Message | None:
        """Clear and return the pending message.

        Returns:
            The pending message, or None if no message was pending
        """
        msg = self.pending_message
        self.pending_message = None
        return msg

    def update_context(self, context: dict[str, Any]) -> None:
        """Update session context.

        Args:
            context: Context data to merge into session context; existing keys
                are overwritten
        """
        self.context.update(context)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization.

        Messages are converted with ``Message.model_dump()`` and the working
        directory is stored as a POSIX-style path string.

        Returns:
            Dictionary representation of session
        """
        return {
            "context": self.context,
            "messages": [msg.model_dump() for msg in self.messages],
            "pending_message": self.pending_message.model_dump() if self.pending_message else None,
            "provider_session_id": self.provider_session_id,
            "working_dir": self.working_dir.as_posix() if self.working_dir else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Session":
        """Load from dictionary.

        Missing keys fall back to the same defaults as a freshly created session.

        Args:
            data: Dictionary representation of session, as produced by ``to_dict``

        Returns:
            Session instance
        """
        messages = [Message.model_validate(m) for m in data.get("messages", [])]
        pending = data.get("pending_message")
        pending_msg = Message.model_validate(pending) if pending else None
        working_dir = data.get("working_dir")

        return cls(
            context=data.get("context", {}),
            messages=messages,
            pending_message=pending_msg,
            provider_session_id=data.get("provider_session_id"),
            working_dir=Path(working_dir) if working_dir else None,
        )


class SessionManager:
    """Manages session storage and retrieval.

    Provides static methods for session operations, keeping
    Session as a simple data container.
    """

    @staticmethod
    def validate_session_id(session_id: str) -> None:
        """Validate that session_id is a valid UUID to prevent path traversal.

        Args:
            session_id: Session ID to validate

        Raises:
            ValueError: If session ID is not a valid UUID
        """
        try:
            uuid.UUID(session_id)
        except (ValueError, TypeError) as e:
            raise ValueError("Invalid session ID format") from e

    @staticmethod
    def get_session_file(session_id: str) -> Path:
        """Get the file path for a session.

        Args:
            session_id: Session ID

        Returns:
            Path to session file

        Raises:
            ValueError: If session ID is invalid
        """
        SessionManager.validate_session_id(session_id)
        return SESSION_DIR / f"{session_id}.json"

    @staticmethod
    def save(session_id: str, session: Session) -> None:
        """Save session to disk atomically.

        Args:
            session_id: Session ID
            session: Session to save

        Raises:
            ValueError: If session ID is invalid
        """
        # get_session_file validates the ID before anything touches the disk.
        session_file = SessionManager.get_session_file(session_id)
        _atomic_write_json(session_file, session.to_dict())

    @staticmethod
    def load(session_id: str) -> Session | None:
        """Load session from disk.

        Args:
            session_id: Session ID

        Returns:
            The stored Session, or None if the ID is invalid or no session
            file exists. A session file that exists but cannot be parsed still
            raises, so corrupted session data is not silently discarded.
        """
        try:
            session_file = SessionManager.get_session_file(session_id)
        except ValueError:
            return None
        # Read directly instead of checking exists() first: the file may be
        # removed between the check and the read.
        try:
            raw = session_file.read_text()
        except (FileNotFoundError, NotADirectoryError):
            return None
        return Session.from_dict(json.loads(raw))

    @staticmethod
    def create(context: dict[str, Any] | None = None, working_dir: Path | None = None) -> Session:
        """Create a new session.

        The session is only constructed in memory; it is not written to disk.

        Args:
            context: Initial context data, or None
            working_dir: Working directory for the session

        Returns:
            New Session instance
        """
        return Session(context=context or {}, working_dir=working_dir)


def get_process_file(session_id: str) -> Path:
    """Get the file path for storing process PID.

    Args:
        session_id: Session ID

    Returns:
        Path to the PID file for the session

    Raises:
        ValueError: If session ID is invalid
    """
    SessionManager.validate_session_id(session_id)
    return SESSION_DIR / f"{session_id}.process.json"


def save_process_pid(session_id: str, pid: int) -> None:
    """Save process PID to file for cancellation support.

    The file is written atomically so that a concurrent ``get_process_pid``
    never observes a partially written file.

    Args:
        session_id: Session ID
        pid: Process ID to record

    Raises:
        ValueError: If session ID is invalid
    """
    process_file = get_process_file(session_id)
    _atomic_write_json(process_file, {"pid": pid})


def get_process_pid(session_id: str) -> int | None:
    """Read the process PID recorded for a session.

    Args:
        session_id: Session ID

    Returns:
        The recorded PID, or None if the session ID is invalid, no PID file
        exists, or the file does not contain a positive integer PID.
    """
    try:
        process_file = get_process_file(session_id)
    except ValueError:
        return None
    try:
        data = json.loads(process_file.read_text())
    except (FileNotFoundError, NotADirectoryError):
        return None
    except (json.JSONDecodeError, UnicodeDecodeError):
        # An unreadable PID file means there is no known process to cancel.
        return None

    pid = data.get("pid") if isinstance(data, dict) else None
    # Only a positive, non-bool integer is a usable PID. On POSIX, passing
    # pid <= 0 to os.kill would signal process groups instead of one process.
    if isinstance(pid, bool) or not isinstance(pid, int) or pid <= 0:
        return None
    return pid


def clear_process_pid(session_id: str) -> None:
    """Remove the PID file for a session, if there is one.

    Invalid session IDs and missing files are ignored.

    Args:
        session_id: Session ID
    """
    try:
        process_file = get_process_file(session_id)
    except ValueError:
        return
    try:
        process_file.unlink()
    except (FileNotFoundError, NotADirectoryError):
        # Already gone (possibly removed concurrently); nothing left to clear.
        pass


def terminate_session_process(session_id: str) -> bool:
    """Terminate the process associated with a session.

    Sends SIGTERM to the recorded PID and removes the PID file. If the
    process no longer exists, the stale PID file is removed as well.

    Args:
        session_id: Session ID

    Returns:
        True if process was terminated, False otherwise

    Raises:
        OSError: Errors from ``os.kill`` other than ProcessLookupError
            (e.g. PermissionError) are propagated; the PID file is then kept.
    """
    pid = get_process_pid(session_id)
    if pid is None:
        return False
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        clear_process_pid(session_id)
        return False
    clear_process_pid(session_id)
    return True