/ mlflow / server / assistant / session.py
session.py
  1  import json
  2  import os
  3  import signal
  4  import tempfile
  5  import uuid
  6  from dataclasses import dataclass, field
  7  from pathlib import Path
  8  from typing import Any
  9  
 10  from mlflow.assistant.types import Message
 11  
 12  SESSION_DIR = Path(tempfile.gettempdir()) / "mlflow-assistant-sessions"
 13  
 14  
 15  @dataclass
 16  class Session:
 17      """Session state for assistant conversations."""
 18  
 19      context: dict[str, Any] = field(default_factory=dict)
 20      messages: list[Message] = field(default_factory=list)
 21      pending_message: Message | None = None
 22      provider_session_id: str | None = None
 23      working_dir: Path | None = None  # Working directory for the session (e.g. project path)
 24  
 25      def add_message(self, role: str, content: str) -> None:
 26          """Add a message to the session history.
 27  
 28          Args:
 29              role: Role of the message sender (user, assistant, system)
 30              content: Text content of the message
 31          """
 32          self.messages.append(Message(role=role, content=content))
 33  
 34      def set_pending_message(self, role: str, content: str) -> None:
 35          """Set the pending message to be processed.
 36  
 37          Args:
 38              role: Role of the message sender
 39              content: Text content of the message
 40          """
 41          self.pending_message = Message(role=role, content=content)
 42  
 43      def clear_pending_message(self) -> Message | None:
 44          """Clear and return the pending message.
 45  
 46          Returns:
 47              The pending message, or None if no message was pending
 48          """
 49          msg = self.pending_message
 50          self.pending_message = None
 51          return msg
 52  
 53      def update_context(self, context: dict[str, Any]) -> None:
 54          """Update session context.
 55  
 56          Args:
 57              context: Context data to merge into session context
 58          """
 59          self.context.update(context)
 60  
 61      def to_dict(self) -> dict[str, Any]:
 62          """Convert to dictionary for serialization.
 63  
 64          Returns:
 65              Dictionary representation of session
 66          """
 67          return {
 68              "context": self.context,
 69              "messages": [msg.model_dump() for msg in self.messages],
 70              "pending_message": self.pending_message.model_dump() if self.pending_message else None,
 71              "provider_session_id": self.provider_session_id,
 72              "working_dir": self.working_dir.as_posix() if self.working_dir else None,
 73          }
 74  
 75      @classmethod
 76      def from_dict(cls, data: dict[str, Any]) -> "Session":
 77          """Load from dictionary.
 78  
 79          Args:
 80              data: Dictionary representation of session
 81  
 82          Returns:
 83              Session instance
 84          """
 85          messages = [Message.model_validate(m) for m in data.get("messages", [])]
 86          pending = data.get("pending_message")
 87          pending_msg = Message.model_validate(pending) if pending else None
 88  
 89          return cls(
 90              context=data.get("context", {}),
 91              messages=messages,
 92              pending_message=pending_msg,
 93              provider_session_id=data.get("provider_session_id"),
 94              working_dir=Path(data.get("working_dir")) if data.get("working_dir") else None,
 95          )
 96  
 97  
 98  class SessionManager:
 99      """Manages session storage and retrieval.
100  
101      Provides static methods for session operations, keeping
102      Session as a simple data container.
103      """
104  
105      @staticmethod
106      def validate_session_id(session_id: str) -> None:
107          """Validate that session_id is a valid UUID to prevent path traversal.
108  
109          Args:
110              session_id: Session ID to validate
111  
112          Raises:
113              ValueError: If session ID is not a valid UUID
114          """
115          try:
116              uuid.UUID(session_id)
117          except (ValueError, TypeError) as e:
118              raise ValueError("Invalid session ID format") from e
119  
120      @staticmethod
121      def get_session_file(session_id: str) -> Path:
122          """Get the file path for a session.
123  
124          Args:
125              session_id: Session ID
126  
127          Returns:
128              Path to session file
129  
130          Raises:
131              ValueError: If session ID is invalid
132          """
133          SessionManager.validate_session_id(session_id)
134          return SESSION_DIR / f"{session_id}.json"
135  
136      @staticmethod
137      def save(session_id: str, session: Session) -> None:
138          """Save session to disk atomically.
139  
140          Args:
141              session_id: Session ID
142              session: Session to save
143  
144          Raises:
145              ValueError: If session ID is invalid
146          """
147          SessionManager.validate_session_id(session_id)
148          SESSION_DIR.mkdir(parents=True, exist_ok=True)
149          session_file = SessionManager.get_session_file(session_id)
150  
151          # Write to temp file, then rename (atomic on POSIX)
152          fd, temp_path = tempfile.mkstemp(dir=SESSION_DIR, suffix=".tmp")
153          try:
154              with os.fdopen(fd, "w") as f:
155                  json.dump(session.to_dict(), f)
156              os.replace(temp_path, session_file)
157          except Exception:
158              os.unlink(temp_path)
159              raise
160  
161      @staticmethod
162      def load(session_id: str) -> Session | None:
163          """Load session from disk. Returns a Session instance, or None if not found"""
164          try:
165              session_file = SessionManager.get_session_file(session_id)
166          except ValueError:
167              return None
168          if not session_file.exists():
169              return None
170          data = json.loads(session_file.read_text())
171          return Session.from_dict(data)
172  
173      @staticmethod
174      def create(context: dict[str, Any] | None = None, working_dir: Path | None = None) -> Session:
175          """Create a new session.
176  
177          Args:
178              context: Initial context data, or None
179              working_dir: Working directory for the session
180  
181          Returns:
182              New Session instance
183          """
184          return Session(context=context or {}, working_dir=working_dir)
185  
186  
def get_process_file(session_id: str) -> Path:
    """Return the path of the file that stores a session's process PID.

    Args:
        session_id: Session ID

    Returns:
        Path to ``<session_id>.process.json`` inside the session directory

    Raises:
        ValueError: If session ID is invalid
    """
    SessionManager.validate_session_id(session_id)
    filename = f"{session_id}.process.json"
    return SESSION_DIR / filename
191  
192  
def save_process_pid(session_id: str, pid: int) -> None:
    """Save process PID to file for cancellation support.

    The record is written to a temporary file and then moved into place, so
    a concurrent reader never sees an empty or partially written file.

    Args:
        session_id: Session ID
        pid: Process ID to record

    Raises:
        ValueError: If session ID is invalid
    """
    # Resolve (and thereby validate) the target first so an invalid ID does
    # not create the session directory as a side effect.
    process_file = get_process_file(session_id)
    SESSION_DIR.mkdir(parents=True, exist_ok=True)

    # Write to temp file, then rename (atomic on POSIX)
    fd, temp_path = tempfile.mkstemp(dir=SESSION_DIR, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump({"pid": pid}, f)
        os.replace(temp_path, process_file)
    except BaseException:
        # Best-effort cleanup; do not let it hide the original error.
        try:
            os.unlink(temp_path)
        except OSError:
            pass
        raise
198  
199  
def get_process_pid(session_id: str) -> int | None:
    """Read the process PID recorded for a session.

    Args:
        session_id: Session ID

    Returns:
        The recorded PID as a positive integer, or None if the session ID is
        invalid, no record exists, or the record is unreadable or does not
        contain a positive integer PID
    """
    try:
        process_file = get_process_file(session_id)
    except ValueError:
        return None

    # Read directly rather than checking exists() first, so a record removed
    # in between cannot raise. ValueError covers malformed JSON and undecodable
    # bytes (both are ValueError subclasses).
    try:
        data = json.loads(process_file.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError, ValueError):
        return None

    if not isinstance(data, dict):
        return None
    pid = data.get("pid")
    # Only hand back something that is safe to treat as a single process ID;
    # bool is excluded explicitly because it is a subclass of int.
    if isinstance(pid, int) and not isinstance(pid, bool) and pid > 0:
        return pid
    return None
209  
210  
def clear_process_pid(session_id: str) -> None:
    """Remove the PID record of a session, if there is one.

    Does nothing when the session ID is invalid or no record exists.

    Args:
        session_id: Session ID
    """
    try:
        process_file = get_process_file(session_id)
    except ValueError:
        return
    # Unlink directly instead of checking exists() first: the file may be
    # removed by someone else between the check and the unlink.
    try:
        process_file.unlink()
    except (FileNotFoundError, NotADirectoryError):
        # Already gone, which is the state we want.
        pass
218  
219  
def terminate_session_process(session_id: str) -> bool:
    """Terminate the process associated with a session.

    Sends SIGTERM to the PID recorded for the session and removes the PID
    record afterwards. If the recorded process no longer exists, the stale
    record is removed as well.

    Args:
        session_id: Session ID

    Returns:
        True if process was terminated, False otherwise
    """
    pid = get_process_pid(session_id)
    # Only a positive integer names a single process. On POSIX, os.kill() with
    # pid 0 or a negative pid signals whole process groups (or every process
    # the caller may signal), so such values must never reach it.
    if not isinstance(pid, int) or isinstance(pid, bool) or pid <= 0:
        return False

    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        # The process is already gone; drop the stale record.
        clear_process_pid(session_id)
        return False

    clear_process_pid(session_id)
    return True
235              clear_process_pid(session_id)
236      return False