api.py
1 """ 2 Assistant API endpoints for MLflow Server. 3 4 This module provides endpoints for integrating AI assistants with MLflow UI, 5 enabling AI-powered helper through a chat interface. 6 """ 7 8 import ipaddress 9 import uuid 10 from pathlib import Path 11 from typing import Any, AsyncGenerator, Literal 12 13 from fastapi import APIRouter, Depends, HTTPException, Request 14 from fastapi.responses import StreamingResponse 15 from pydantic import BaseModel, Field 16 17 from mlflow.assistant import clear_project_path_cache, get_project_path 18 from mlflow.assistant.config import AssistantConfig, PermissionsConfig, ProjectConfig 19 from mlflow.assistant.providers.base import ( 20 CLINotInstalledError, 21 NotAuthenticatedError, 22 clear_config_cache, 23 ) 24 from mlflow.assistant.providers.claude_code import ClaudeCodeProvider 25 from mlflow.assistant.skill_installer import install_skills, list_installed_skills 26 from mlflow.assistant.types import EventType 27 from mlflow.server.assistant.session import SessionManager, terminate_session_process 28 29 # TODO: Hardcoded provider until supporting multiple providers 30 _provider = ClaudeCodeProvider() 31 32 33 # Update the message when we support proxy access 34 _BLOCK_REMOTE_ACCESS_ERROR_MSG = ( 35 "Assistant API is only accessible from the same host where the mLflow server is running." 36 ) 37 38 39 async def _require_localhost(request: Request) -> None: 40 """ 41 Dependency that restricts access to localhost only. 42 43 Uses ipaddress library for robust loopback detection. 44 45 Raises: 46 HTTPException: If request is not from localhost 47 """ 48 client_host = request.client.host if request.client else None 49 50 if not client_host: 51 raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG) 52 53 try: 54 ip = ipaddress.ip_address(client_host) 55 except ValueError: 56 raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG) 57 58 if not ip.is_loopback: 59 raise HTTPException(status_code=403, detail=_BLOCK_REMOTE_ACCESS_ERROR_MSG) 60 61 62 assistant_router = APIRouter( 63 prefix="/ajax-api/3.0/mlflow/assistant", 64 tags=["assistant"], 65 dependencies=[Depends(_require_localhost)], 66 ) 67 68 69 class MessageRequest(BaseModel): 70 message: str 71 session_id: str | None = None # empty for the first message 72 experiment_id: str | None = None 73 context: dict[str, Any] = Field(default_factory=dict) 74 75 76 class MessageResponse(BaseModel): 77 session_id: str 78 stream_url: str 79 80 81 # Config-related models 82 class ConfigResponse(BaseModel): 83 providers: dict[str, Any] = Field(default_factory=dict) 84 projects: dict[str, Any] = Field(default_factory=dict) 85 86 87 class ConfigUpdateRequest(BaseModel): 88 providers: dict[str, Any] | None = None 89 projects: dict[str, Any] | None = None 90 91 92 class SessionPatchRequest(BaseModel): 93 status: Literal["cancelled"] 94 95 96 class SessionPatchResponse(BaseModel): 97 message: str 98 99 100 # Skills-related models 101 class SkillsInstallRequest(BaseModel): 102 type: Literal["global", "project", "custom"] = "global" 103 custom_path: str | None = None # Required if type="custom" 104 experiment_id: str | None = None # Used to get project_path for type="project" 105 106 107 class SkillsInstallResponse(BaseModel): 108 installed_skills: list[str] 109 skills_directory: str 110 111 112 @assistant_router.post("/message") 113 async def send_message(request: MessageRequest) -> MessageResponse: 114 """ 115 Send a message to the assistant and get a session for streaming the response. 116 117 Args: 118 request: MessageRequest with message, context, and optional session_id 119 120 Returns: 121 MessageResponse with session_id and stream_url 122 """ 123 # Generate or use existing session ID 124 session_id = request.session_id or str(uuid.uuid4()) 125 126 project_path = get_project_path(request.experiment_id) if request.experiment_id else None 127 128 # Create or update session 129 session = SessionManager.load(session_id) 130 if session is None: 131 session = SessionManager.create( 132 context=request.context, working_dir=Path(project_path) if project_path else None 133 ) 134 elif request.context: 135 session.update_context(request.context) 136 137 # Store the pending message with role 138 session.set_pending_message(role="user", content=request.message) 139 session.add_message(role="user", content=request.message) 140 SessionManager.save(session_id, session) 141 142 return MessageResponse( 143 session_id=session_id, 144 stream_url=f"/ajax-api/3.0/mlflow/assistant/stream/{session_id}", 145 ) 146 147 148 @assistant_router.get("/sessions/{session_id}/stream") 149 async def stream_response(request: Request, session_id: str) -> StreamingResponse: 150 """ 151 Stream the assistant's response via Server-Sent Events. 152 153 Args: 154 request: The FastAPI request object 155 session_id: The session ID returned from /message 156 157 Returns: 158 StreamingResponse with SSE events 159 """ 160 session = SessionManager.load(session_id) 161 if session is None: 162 raise HTTPException(status_code=404, detail="Session not found") 163 164 # Get and clear the pending message 165 pending_message = session.clear_pending_message() 166 if not pending_message: 167 raise HTTPException(status_code=400, detail="No pending message to process") 168 SessionManager.save(session_id, session) 169 170 # Extract the MLflow server URL from the request for the assistant to use. 171 # This assumes the assistant is accessing the same MLflow server that serves this API, 172 # which works because the assistant endpoint is localhost-only. 173 # TODO: Extend this to support remote/proxy scenarios where the tracking URI may differ. 174 tracking_uri = str(request.base_url).rstrip("/") 175 176 async def event_generator() -> AsyncGenerator[str, None]: 177 nonlocal session 178 async for event in _provider.astream( 179 prompt=pending_message.content, 180 tracking_uri=tracking_uri, 181 session_id=session.provider_session_id, 182 mlflow_session_id=session_id, 183 cwd=session.working_dir, 184 context=session.context, 185 ): 186 # Store provider session ID if returned (for conversation continuity) 187 if event.type == EventType.DONE: 188 session.provider_session_id = event.data.get("session_id") 189 SessionManager.save(session_id, session) 190 191 yield event.to_sse_event() 192 193 return StreamingResponse( 194 event_generator(), 195 media_type="text/event-stream", 196 headers={ 197 "Cache-Control": "no-cache", 198 "Connection": "keep-alive", 199 "X-Accel-Buffering": "no", 200 }, 201 ) 202 203 204 @assistant_router.patch("/sessions/{session_id}") 205 async def patch_session(session_id: str, request: SessionPatchRequest) -> SessionPatchResponse: 206 """ 207 Update session status. 208 209 Currently supports cancelling an active session, which terminates 210 the running assistant process. 211 212 Args: 213 session_id: The session ID 214 request: SessionPatchRequest with status to set 215 216 Returns: 217 SessionPatchResponse indicating success 218 """ 219 session = SessionManager.load(session_id) 220 if session is None: 221 raise HTTPException(status_code=404, detail="Session not found") 222 223 if request.status == "cancelled": 224 terminated = terminate_session_process(session_id) 225 msg = "Session cancelled and process terminated" if terminated else "Session cancelled" 226 return SessionPatchResponse(message=msg) 227 228 # This branch is unreachable due to Literal type, but satisfies type checker 229 raise HTTPException(status_code=400, detail=f"Unknown status: {request.status}") 230 231 232 @assistant_router.get("/providers/{provider}/health") 233 async def provider_health_check(provider: str) -> dict[str, str]: 234 """ 235 Check if a specific provider is ready (CLI installed and authenticated). 236 237 Args: 238 provider: The provider name (e.g., "claude_code"). 239 240 Returns: 241 200 with { status: "ok" } if ready. 242 243 Raises: 244 HTTPException 404: If provider is not found. 245 HTTPException 412: If preconditions not met (CLI not installed or not authenticated). 246 """ 247 # TODO: Support multiple providers via registry 248 if provider != _provider.name: 249 raise HTTPException(status_code=404, detail=f"Provider '{provider}' not found") 250 251 try: 252 _provider.check_connection() 253 except CLINotInstalledError as e: 254 raise HTTPException(status_code=412, detail=str(e)) 255 except NotAuthenticatedError as e: 256 raise HTTPException(status_code=401, detail=str(e)) 257 258 return {"status": "ok"} 259 260 261 @assistant_router.get("/config") 262 async def get_config() -> ConfigResponse: 263 """ 264 Get the current assistant configuration. 265 266 Returns: 267 Current configuration including providers and projects. 268 """ 269 config = AssistantConfig.load() 270 return ConfigResponse( 271 providers={name: p.model_dump() for name, p in config.providers.items()}, 272 projects={exp_id: p.model_dump() for exp_id, p in config.projects.items()}, 273 ) 274 275 276 @assistant_router.put("/config") 277 async def update_config(request: ConfigUpdateRequest) -> ConfigResponse: 278 """ 279 Update the assistant configuration. 280 281 Args: 282 request: Partial configuration update. 283 284 Returns: 285 Updated configuration. 286 """ 287 config = AssistantConfig.load() 288 289 # Update providers 290 if request.providers: 291 for name, provider_data in request.providers.items(): 292 model = provider_data.get("model", "default") 293 permissions = None 294 if "permissions" in provider_data: 295 perm_data = provider_data["permissions"] 296 permissions = PermissionsConfig( 297 allow_edit_files=perm_data.get("allow_edit_files", True), 298 allow_read_docs=perm_data.get("allow_read_docs", True), 299 full_access=perm_data.get("full_access", False), 300 ) 301 config.set_provider(name, model, permissions) 302 303 # Update projects 304 if request.projects: 305 for exp_id, project_data in request.projects.items(): 306 if project_data is None: 307 # Remove project mapping 308 config.projects.pop(exp_id, None) 309 else: 310 location = project_data.get("location", "") 311 project_path = Path(location).expanduser() 312 if not project_path or not project_path.exists(): 313 raise HTTPException( 314 status_code=400, 315 detail=f"Project path does not exist: {location}", 316 ) 317 config.projects[exp_id] = ProjectConfig( 318 type=project_data.get("type", "local"), 319 location=str(project_path), 320 ) 321 322 config.save() 323 324 # Clear caches so provider and project path lookups pick up new settings 325 clear_config_cache() 326 clear_project_path_cache() 327 328 return ConfigResponse( 329 providers={name: p.model_dump() for name, p in config.providers.items()}, 330 projects={exp_id: p.model_dump() for exp_id, p in config.projects.items()}, 331 ) 332 333 334 @assistant_router.post("/skills/install") 335 async def install_skills_endpoint(request: SkillsInstallRequest) -> SkillsInstallResponse: 336 """ 337 Install skills bundled with MLflow. 338 This endpoint only handles installation. Config updates should be done via PUT /config. 339 340 Args: 341 request: SkillsInstallRequest with type, custom_path, and experiment_id. 342 343 Returns: 344 SkillsInstallResponse with installed skill names and directory. 345 346 Raises: 347 HTTPException 400: If custom type without custom_path or project type without experiment_id. 348 """ 349 config = AssistantConfig.load() 350 351 # Resolve project_path for "project" type 352 project_path: Path | None = None 353 if request.type == "project": 354 if not request.experiment_id: 355 raise HTTPException(status_code=400, detail="experiment_id required for 'project' type") 356 project_location = config.get_project_path(request.experiment_id) 357 if not project_location: 358 raise HTTPException( 359 status_code=400, 360 detail=f"No project path configured for experiment {request.experiment_id}", 361 ) 362 project_path = Path(project_location) 363 364 # Get the destination path to install skills to 365 match request.type: 366 case "global": 367 destination = _provider.resolve_skills_path(Path.home()) 368 case "project": 369 destination = _provider.resolve_skills_path(project_path) 370 case "custom": 371 destination = Path(request.custom_path).expanduser() 372 373 # Check if skills already exist - skip re-installation 374 if destination.exists(): 375 if current_skills := list_installed_skills(destination): 376 return SkillsInstallResponse( 377 installed_skills=current_skills, skills_directory=str(destination) 378 ) 379 380 installed = install_skills(destination) 381 382 return SkillsInstallResponse(installed_skills=installed, skills_directory=str(destination))