# mlflow/server/assistant/api.py
  1  """
  2  Assistant API endpoints for MLflow Server.
  3  
  4  This module provides endpoints for integrating AI assistants with MLflow UI,
enabling AI-powered assistance through a chat interface.
  6  """
  7  
  8  import ipaddress
  9  import uuid
 10  from pathlib import Path
 11  from typing import Any, AsyncGenerator, Literal
 12  
 13  from fastapi import APIRouter, Depends, HTTPException, Request
 14  from fastapi.responses import StreamingResponse
 15  from pydantic import BaseModel, Field
 16  
 17  from mlflow.assistant import clear_project_path_cache, get_project_path
 18  from mlflow.assistant.config import AssistantConfig, PermissionsConfig, ProjectConfig
 19  from mlflow.assistant.providers.base import (
 20      CLINotInstalledError,
 21      NotAuthenticatedError,
 22      clear_config_cache,
 23  )
 24  from mlflow.assistant.providers.claude_code import ClaudeCodeProvider
 25  from mlflow.assistant.skill_installer import install_skills, list_installed_skills
 26  from mlflow.assistant.types import EventType
 27  from mlflow.server.assistant.session import SessionManager, terminate_session_process
 28  
# TODO: Hardcoded provider until supporting multiple providers
# Module-level singleton; every endpoint below routes through this instance.
_provider = ClaudeCodeProvider()
 31  
 32  
 33  # Update the message when we support proxy access
 34  _BLOCK_REMOTE_ACCESS_ERROR_MSG = (
 35      "Assistant API is only accessible from the same host where the mLflow server is running."
 36  )
 37  
 38  
 39  async def _require_localhost(request: Request) -> None:
 40      """
 41      Dependency that restricts access to localhost only.
 42  
 43      Uses ipaddress library for robust loopback detection.
 44  
 45      Raises:
 46          HTTPException: If request is not from localhost
 47      """
 48      client_host = request.client.host if request.client else None
 49  
 50      if not client_host:
 51          raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG)
 52  
 53      try:
 54          ip = ipaddress.ip_address(client_host)
 55      except ValueError:
 56          raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG)
 57  
 58      if not ip.is_loopback:
 59          raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG)
 60  
 61  
# Router for all assistant endpoints. The localhost-only dependency is applied
# to every route registered on it, so no remote client can reach these APIs.
assistant_router = APIRouter(
    prefix="/ajax-api/3.0/mlflow/assistant",
    tags=["assistant"],
    dependencies=[Depends(_require_localhost)],
)
 67  
 68  
class MessageRequest(BaseModel):
    """Request body for POST /message."""

    # User's chat message to forward to the assistant provider.
    message: str
    session_id: str | None = None  # empty for the first message
    # Used to resolve the project working directory for the session.
    experiment_id: str | None = None
    # Extra context merged into the session (shape defined by callers).
    context: dict[str, Any] = Field(default_factory=dict)
 74  
 75  
class MessageResponse(BaseModel):
    """Response body for POST /message."""

    # Session ID to reuse for follow-up messages and for streaming.
    session_id: str
    # Relative URL of the SSE endpoint that streams the assistant's reply.
    stream_url: str
 79  
 80  
 81  # Config-related models
class ConfigResponse(BaseModel):
    """Serialized assistant configuration (providers and project mappings)."""

    # Provider name -> provider config dump (model, permissions, ...).
    providers: dict[str, Any] = Field(default_factory=dict)
    # Experiment ID -> project config dump (type, location).
    projects: dict[str, Any] = Field(default_factory=dict)
 85  
 86  
class ConfigUpdateRequest(BaseModel):
    """Partial configuration update for PUT /config; None fields are left untouched."""

    # Provider name -> settings to merge (keys consumed by update_config).
    providers: dict[str, Any] | None = None
    # Experiment ID -> project settings; a None value removes that mapping.
    projects: dict[str, Any] | None = None
 90  
 91  
class SessionPatchRequest(BaseModel):
    """Request body for PATCH /sessions/{session_id}."""

    # Only "cancelled" is supported; the Literal type rejects anything else.
    status: Literal["cancelled"]
 94  
 95  
class SessionPatchResponse(BaseModel):
    """Response body for PATCH /sessions/{session_id}."""

    # Human-readable result of the status change.
    message: str
 98  
 99  
100  # Skills-related models
class SkillsInstallRequest(BaseModel):
    """Request body for POST /skills/install."""

    # Where to install: user-global, per-project, or an explicit directory.
    type: Literal["global", "project", "custom"] = "global"
    custom_path: str | None = None  # Required if type="custom"
    experiment_id: str | None = None  # Used to get project_path for type="project"
105  
106  
class SkillsInstallResponse(BaseModel):
    """Response body for POST /skills/install."""

    # Names of the skills present in the destination after the call.
    installed_skills: list[str]
    # Directory the skills live in (stringified path).
    skills_directory: str
110  
111  
@assistant_router.post("/message")
async def send_message(request: MessageRequest) -> MessageResponse:
    """
    Send a message to the assistant and get a session for streaming the response.

    Args:
        request: MessageRequest with message, context, and optional session_id
            and experiment_id.

    Returns:
        MessageResponse with session_id and a stream_url pointing at the SSE
        streaming endpoint for this session.
    """
    # Reuse the caller's session ID for follow-ups; mint one for the first message.
    session_id = request.session_id or str(uuid.uuid4())

    # Resolve the project working directory mapped to this experiment, if any.
    project_path = get_project_path(request.experiment_id) if request.experiment_id else None

    # Create a new session on first contact; otherwise refresh its context.
    session = SessionManager.load(session_id)
    if session is None:
        session = SessionManager.create(
            context=request.context, working_dir=Path(project_path) if project_path else None
        )
    elif request.context:
        session.update_context(request.context)

    # Store the pending message (consumed by the stream endpoint) and append it
    # to the persisted conversation history.
    session.set_pending_message(role="user", content=request.message)
    session.add_message(role="user", content=request.message)
    SessionManager.save(session_id, session)

    return MessageResponse(
        session_id=session_id,
        # Must match the route registered as "/sessions/{session_id}/stream";
        # the previous "/stream/{session_id}" URL pointed at a nonexistent route.
        stream_url=f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream",
    )
146  
147  
@assistant_router.get("/sessions/{session_id}/stream")
async def stream_response(request: Request, session_id: str) -> StreamingResponse:
    """
    Stream the assistant's response via Server-Sent Events.

    Args:
        request: The FastAPI request object
        session_id: The session ID returned from /message

    Returns:
        StreamingResponse with SSE events

    Raises:
        HTTPException: 404 if the session does not exist; 400 if the session
            has no pending message queued by a prior POST /message call.
    """
    session = SessionManager.load(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Get and clear the pending message; persist immediately so a retried GET
    # cannot replay the same message to the provider.
    pending_message = session.clear_pending_message()
    if not pending_message:
        raise HTTPException(status_code=400, detail="No pending message to process")
    SessionManager.save(session_id, session)

    # Extract the MLflow server URL from the request for the assistant to use.
    # This assumes the assistant is accessing the same MLflow server that serves this API,
    # which works because the assistant endpoint is localhost-only.
    # TODO: Extend this to support remote/proxy scenarios where the tracking URI may differ.
    tracking_uri = str(request.base_url).rstrip("/")

    async def event_generator() -> AsyncGenerator[str, None]:
        # Declared defensively; only attributes of `session` are mutated below,
        # so the enclosing binding itself is never reassigned.
        nonlocal session
        async for event in _provider.astream(
            prompt=pending_message.content,
            tracking_uri=tracking_uri,
            session_id=session.provider_session_id,
            mlflow_session_id=session_id,
            cwd=session.working_dir,
            context=session.context,
        ):
            # Store provider session ID if returned (for conversation continuity)
            if event.type == EventType.DONE:
                session.provider_session_id = event.data.get("session_id")
                SessionManager.save(session_id, session)

            yield event.to_sse_event()

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            # Disable caching and reverse-proxy buffering so each SSE event
            # reaches the client as soon as the provider emits it.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
202  
203  
@assistant_router.patch("/sessions/{session_id}")
async def patch_session(session_id: str, request: SessionPatchRequest) -> SessionPatchResponse:
    """
    Update session status.

    The only supported transition today is cancelling an active session,
    which also terminates the assistant process backing it.

    Args:
        session_id: The session ID
        request: SessionPatchRequest with status to set

    Returns:
        SessionPatchResponse indicating success

    Raises:
        HTTPException: 404 if the session is unknown.
    """
    if SessionManager.load(session_id) is None:
        raise HTTPException(status_code=404, detail="Session not found")

    if request.status != "cancelled":
        # Unreachable given the Literal type, but keeps the type checker happy.
        raise HTTPException(status_code=400, detail=f"Unknown status: {request.status}")

    was_terminated = terminate_session_process(session_id)
    if was_terminated:
        return SessionPatchResponse(message="Session cancelled and process terminated")
    return SessionPatchResponse(message="Session cancelled")
230  
231  
@assistant_router.get("/providers/{provider}/health")
async def provider_health_check(provider: str) -> dict[str, str]:
    """
    Check if a specific provider is ready (CLI installed and authenticated).

    Args:
        provider: The provider name (e.g., "claude_code").

    Returns:
        200 with { status: "ok" } if ready.

    Raises:
        HTTPException 404: If provider is not found.
        HTTPException 412: If the provider CLI is not installed.
        HTTPException 401: If the provider CLI is installed but not authenticated.
    """
    # TODO: Support multiple providers via registry
    if provider != _provider.name:
        raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found")

    try:
        _provider.check_connection()
    except CLINotInstalledError as e:
        # 412 Precondition Failed: the CLI binary is missing on this host.
        raise HTTPException(status_code=412, detail=str(e)) from e
    except NotAuthenticatedError as e:
        # 401 Unauthorized: the CLI exists but has no valid credentials.
        raise HTTPException(status_code=401, detail=str(e)) from e

    return {"status": "ok"}
259  
260  
@assistant_router.get("/config")
async def get_config() -> ConfigResponse:
    """
    Get the current assistant configuration.

    Returns:
        Current configuration including providers and projects.
    """
    config = AssistantConfig.load()

    provider_dump: dict[str, Any] = {}
    for name, provider in config.providers.items():
        provider_dump[name] = provider.model_dump()

    project_dump: dict[str, Any] = {}
    for exp_id, project in config.projects.items():
        project_dump[exp_id] = project.model_dump()

    return ConfigResponse(providers=provider_dump, projects=project_dump)
274  
275  
@assistant_router.put("/config")
async def update_config(request: ConfigUpdateRequest) -> ConfigResponse:
    """
    Update the assistant configuration.

    Args:
        request: Partial configuration update. Provider entries are merged;
            a project entry of None removes that experiment's mapping.

    Returns:
        Updated configuration.

    Raises:
        HTTPException 400: If a project location is empty or does not exist.
    """
    config = AssistantConfig.load()

    # Update providers
    if request.providers:
        for name, provider_data in request.providers.items():
            model = provider_data.get("model", "default")
            permissions = None
            if "permissions" in provider_data:
                perm_data = provider_data["permissions"]
                permissions = PermissionsConfig(
                    allow_edit_files=perm_data.get("allow_edit_files", True),
                    allow_read_docs=perm_data.get("allow_read_docs", True),
                    full_access=perm_data.get("full_access", False),
                )
            config.set_provider(name, model, permissions)

    # Update projects
    if request.projects:
        for exp_id, project_data in request.projects.items():
            if project_data is None:
                # Remove project mapping
                config.projects.pop(exp_id, None)
            else:
                location = project_data.get("location", "")
                # An empty location would resolve to Path("."), the server's
                # cwd, which always exists; reject it explicitly. (The old
                # `not project_path` guard was dead code: Path objects are
                # always truthy.)
                if not location:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Project path does not exist: {location}",
                    )
                project_path = Path(location).expanduser()
                if not project_path.exists():
                    raise HTTPException(
                        status_code=400,
                        detail=f"Project path does not exist: {location}",
                    )
                config.projects[exp_id] = ProjectConfig(
                    type=project_data.get("type", "local"),
                    location=str(project_path),
                )

    config.save()

    # Clear caches so provider and project path lookups pick up new settings
    clear_config_cache()
    clear_project_path_cache()

    return ConfigResponse(
        providers={name: p.model_dump() for name, p in config.providers.items()},
        projects={exp_id: p.model_dump() for exp_id, p in config.projects.items()},
    )
332  
333  
334  @assistant_router.post("/skills/install")
335  async def install_skills_endpoint(request: SkillsInstallRequest) -> SkillsInstallResponse:
336      """
337      Install skills bundled with MLflow.
338      This endpoint only handles installation. Config updates should be done via PUT /config.
339  
340      Args:
341          request: SkillsInstallRequest with type, custom_path, and experiment_id.
342  
343      Returns:
344          SkillsInstallResponse with installed skill names and directory.
345  
346      Raises:
347          HTTPException 400: If custom type without custom_path or project type without experiment_id.
348      """
349      config = AssistantConfig.load()
350  
351      # Resolve project_path for "project" type
352      project_path: Path | None = None
353      if request.type == "project":
354          if not request.experiment_id:
355              raise HTTPException(status_code=400, detail="experiment_id required for 'project' type")
356          project_location = config.get_project_path(request.experiment_id)
357          if not project_location:
358              raise HTTPException(
359                  status_code=400,
360                  detail=f"No project path configured for experiment {request.experiment_id}",
361              )
362          project_path = Path(project_location)
363  
364      # Get the destination path to install skills to
365      match request.type:
366          case "global":
367              destination = _provider.resolve_skills_path(Path.home())
368          case "project":
369              destination = _provider.resolve_skills_path(project_path)
370          case "custom":
371              destination = Path(request.custom_path).expanduser()
372  
373      # Check if skills already exist - skip re-installation
374      if destination.exists():
375          if current_skills := list_installed_skills(destination):
376              return SkillsInstallResponse(
377                  installed_skills=current_skills, skills_directory=str(destination)
378              )
379  
380      installed = install_skills(destination)
381  
382      return SkillsInstallResponse(installed_skills=installed, skills_directory=str(destination))