# test_api.py — tests for the MLflow assistant server API (mlflow.server.assistant.api).
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

from mlflow.assistant.config import AssistantConfig, ProjectConfig
from mlflow.assistant.config import ProviderConfig as AssistantProviderConfig
from mlflow.assistant.providers.base import (
    AssistantProvider,
    CLINotInstalledError,
    NotAuthenticatedError,
    ProviderConfig,
)
from mlflow.assistant.types import Event, Message
from mlflow.server.assistant.api import _require_localhost, assistant_router
from mlflow.server.assistant.session import SESSION_DIR, SessionManager, save_process_pid
from mlflow.utils.os import is_windows


class MockProvider(AssistantProvider):
    """Mock provider for testing."""

    @property
    def name(self) -> str:
        return "mock_provider"

    @property
    def display_name(self) -> str:
        return "Mock Provider"

    @property
    def description(self) -> str:
        return "Mock provider for testing"

    @property
    def config_path(self) -> Path:
        return Path.home() / ".mlflow" / "assistant" / "mock-config.json"

    def is_available(self) -> bool:
        return True

    def load_config(self) -> ProviderConfig:
        return ProviderConfig()

    def check_connection(self, echo=print) -> None:
        pass

    def resolve_skills_path(self, base_directory: Path) -> Path:
        return base_directory / ".mock" / "skills"

    async def astream(
        self,
        prompt: str,
        tracking_uri: str,
        session_id: str | None = None,
        cwd: Path | None = None,
        context: dict[str, Any] | None = None,
        mlflow_session_id: str | None = None,
    ):
        """Yield a canned message event followed by a completion result.

        All arguments are ignored; the stream tests assert on "Hello from mock".
        """
        yield Event.from_message(message=Message(role="user", content="Hello from mock"))
        yield Event.from_result(result="complete", session_id="mock-session-123")


@pytest.fixture(autouse=True)
def isolated_config(tmp_path, monkeypatch):
    """Redirect config to tmp_path to avoid modifying real user config."""
    import mlflow.assistant.config as config_module

    config_home = tmp_path / ".mlflow" / "assistant"
    config_path = config_home / "config.json"

    monkeypatch.setattr(config_module, "MLFLOW_ASSISTANT_HOME", config_home)
    monkeypatch.setattr(config_module, "CONFIG_PATH", config_path)

    return config_home


@pytest.fixture(autouse=True)
def clear_sessions():
    """Clear session storage before and after each test."""
    # NOTE(review): this removes the real SESSION_DIR (not a tmp_path copy); if that
    # directory is shared with a running MLflow server its sessions are wiped too.
    if SESSION_DIR.exists():
        shutil.rmtree(SESSION_DIR)
    yield
    if SESSION_DIR.exists():
        shutil.rmtree(SESSION_DIR)


async def _allow_any_host() -> None:
    """No-op replacement for the ``_require_localhost`` dependency."""


def _create_app() -> FastAPI:
    """Build an app serving ``assistant_router`` with the localhost check bypassed.

    Shared by the ``client`` fixture and tests that need a different provider.
    """
    app = FastAPI()
    app.include_router(assistant_router)
    # Override the localhost dependency so TestClient requests are accepted.
    app.dependency_overrides[_require_localhost] = _allow_any_host
    return app


@pytest.fixture
def client():
    """Create test client with mock provider and bypassed localhost check."""
    with patch("mlflow.server.assistant.api._provider", MockProvider()):
        yield TestClient(_create_app())


def test_message(client):
    response = client.post(
        "/ajax-api/3.0/mlflow/assistant/message",
        json={
            "message": "Hello",
            "context": {"trace_id": "tr-123", "experiment_id": "exp-456"},
        },
    )

    assert response.status_code == 200
    data = response.json()
    session_id = data["session_id"]
    assert session_id is not None
    # NOTE(review): the stream tests below request ``/sessions/{id}/stream`` while the
    # stream_url asserted here is ``/stream/{id}`` — confirm which path the API routes.
    assert data["stream_url"] == f"/ajax-api/3.0/mlflow/assistant/stream/{session_id}"

    # continue the conversation
    response = client.post(
        "/ajax-api/3.0/mlflow/assistant/message",
        json={"message": "Second message", "session_id": session_id},
    )

    assert response.status_code == 200
    assert response.json()["session_id"] == session_id


def test_stream_not_found_for_invalid_session(client):
    response = client.get("/ajax-api/3.0/mlflow/assistant/sessions/invalid-session-id/stream")
    assert response.status_code == 404
    assert "Session not found" in response.json()["detail"]


def test_stream_bad_request_when_no_pending_message(client):
    # Create session and consume the pending message
    r = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = r.json()["session_id"]
    client.get(f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream")

    # Try to stream again without a new message
    response = client.get(f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream")

    assert response.status_code == 400
    assert "No pending message" in response.json()["detail"]


def test_stream_returns_sse_events(client):
    r = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = r.json()["session_id"]

    response = client.get(f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream")

    assert response.status_code == 200
    assert "text/event-stream" in response.headers["content-type"]

    content = response.text
    assert "event: message" in content
    assert "event: done" in content
    assert "Hello from mock" in content


def test_health_check_returns_ok_when_healthy(client):
    response = client.get("/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


def test_health_check_returns_404_for_unknown_provider(client):
    response = client.get("/ajax-api/3.0/mlflow/assistant/providers/unknown_provider/health")
    assert response.status_code == 404
    assert "not found" in response.json()["detail"]


def test_health_check_returns_412_when_cli_not_installed():
    class CLINotInstalledProvider(MockProvider):
        def check_connection(self, echo=print):
            raise CLINotInstalledError("CLI not installed")

    with patch("mlflow.server.assistant.api._provider", CLINotInstalledProvider()):
        client = TestClient(_create_app())
        response = client.get("/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health")
        assert response.status_code == 412
        assert "CLI not installed" in response.json()["detail"]


def test_health_check_returns_401_when_not_authenticated():
    class NotAuthenticatedProvider(MockProvider):
        def check_connection(self, echo=print):
            raise NotAuthenticatedError("Not authenticated")

    with patch("mlflow.server.assistant.api._provider", NotAuthenticatedProvider()):
        client = TestClient(_create_app())
        response = client.get("/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health")
        assert response.status_code == 401
        assert "Not authenticated" in response.json()["detail"]


def test_get_config_returns_empty_config(client):
    response = client.get("/ajax-api/3.0/mlflow/assistant/config")
    assert response.status_code == 200
    data = response.json()
    assert data["providers"] == {}
    assert data["projects"] == {}


def test_get_config_returns_existing_config(client, tmp_path):
    # Set up existing config by saving it first
    project_dir = tmp_path / "project"
    project_dir.mkdir()

    config = AssistantConfig(
        providers={"claude_code": AssistantProviderConfig(model="default", selected=True)},
        projects={"exp-123": ProjectConfig(type="local", location=str(project_dir))},
    )
    config.save()

    response = client.get("/ajax-api/3.0/mlflow/assistant/config")
    assert response.status_code == 200
    data = response.json()
    assert data["providers"]["claude_code"]["model"] == "default"
    assert data["providers"]["claude_code"]["selected"] is True
    assert data["projects"]["exp-123"]["location"] == str(project_dir)


def test_update_config_sets_provider(client):
    response = client.put(
        "/ajax-api/3.0/mlflow/assistant/config",
        json={"providers": {"claude_code": {"model": "opus", "selected": True}}},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["providers"]["claude_code"]["selected"] is True


def test_update_config_sets_project(client, tmp_path):
    project_dir = tmp_path / "my_project"
    project_dir.mkdir()

    response = client.put(
        "/ajax-api/3.0/mlflow/assistant/config",
        json={"projects": {"exp-456": {"type": "local", "location": str(project_dir)}}},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["projects"]["exp-456"]["location"] == str(project_dir)


def test_update_config_expand_user_home(client, tmp_path, monkeypatch):
    # Point the home directory at a fake location under tmp_path so that "~" really
    # expands there, instead of globally mocking Path.expanduser (which would return
    # the same path for every input and never exercise the expansion).
    fake_home = tmp_path / "home" / "user"
    project_dir = fake_home / "my_project"
    project_dir.mkdir(parents=True)
    # Path.expanduser reads HOME on POSIX and USERPROFILE on Windows.
    monkeypatch.setenv("HOME", str(fake_home))
    monkeypatch.setenv("USERPROFILE", str(fake_home))

    response = client.put(
        "/ajax-api/3.0/mlflow/assistant/config",
        json={"projects": {"exp-456": {"type": "local", "location": "~/my_project"}}},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["projects"]["exp-456"]["location"] == str(project_dir)


@pytest.mark.asyncio
async def test_localhost_allows_ipv4():
    mock_request = MagicMock()
    mock_request.client.host = "127.0.0.1"
    await _require_localhost(mock_request)


@pytest.mark.asyncio
async def test_localhost_allows_ipv6():
    mock_request = MagicMock()
    mock_request.client.host = "::1"
    await _require_localhost(mock_request)


@pytest.mark.asyncio
async def test_localhost_blocks_external_ip():
    mock_request = MagicMock()
    mock_request.client.host = "192.168.1.100"

    with pytest.raises(HTTPException, match="same host"):
        await _require_localhost(mock_request)


@pytest.mark.asyncio
async def test_localhost_blocks_external_hostname():
    mock_request = MagicMock()
    mock_request.client.host = "external.example.com"

    with pytest.raises(HTTPException, match="same host"):
        await _require_localhost(mock_request)


@pytest.mark.asyncio
async def test_localhost_blocks_when_no_client():
    mock_request = MagicMock()
    mock_request.client = None

    with pytest.raises(HTTPException, match="same host"):
        await _require_localhost(mock_request)


def test_validate_session_id_accepts_valid_uuid():
    valid_uuid = "f5f28c66-5ec6-46a1-9a2e-ca55fb64bf47"
    SessionManager.validate_session_id(valid_uuid)  # Should not raise


def test_validate_session_id_rejects_invalid_format():
    with pytest.raises(ValueError, match="Invalid session ID format"):
        SessionManager.validate_session_id("invalid-session-id")


def test_validate_session_id_rejects_path_traversal():
    with pytest.raises(ValueError, match="Invalid session ID format"):
        SessionManager.validate_session_id("../../../etc/passwd")


def _is_process_running(pid: int) -> bool:
    """Return True if ``os.kill(pid, 0)`` succeeds, i.e. *pid* names a live process.

    Only use this on POSIX. On Windows, signal 0 is CTRL_C_EVENT, so ``os.kill``
    sends a console control event instead of performing a liveness probe.
    """
    try:
        os.kill(pid, 0)
        return True
    except (OSError, ValueError):  # ValueError is raised on Windows
        return False


def test_patch_session_cancel_with_process(client):
    import sys

    r = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = r.json()["session_id"]

    # Start a real, long-lived subprocess and register it with the session. The current
    # Python interpreter is used instead of ``sleep`` because Windows has no ``sleep``
    # executable, while this test is otherwise written to run there too.
    with subprocess.Popen([sys.executable, "-c", "import time; time.sleep(10)"]) as proc:
        try:
            save_process_pid(session_id, proc.pid)

            # poll() is a portable liveness check (os.kill(pid, 0) is not, on Windows).
            assert proc.poll() is None

            response = client.patch(
                f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}",
                json={"status": "cancelled"},
            )

            assert response.status_code == 200
            data = response.json()
            assert "terminated" in data["message"]

            # Wait for the process to actually terminate
            proc.wait(timeout=5)
            assert proc.returncode is not None
            # On non-Windows, verify the process is no longer running via PID check.
            # Skip on Windows because PIDs are reused more aggressively.
            if not is_windows():
                assert not _is_process_running(proc.pid)
        finally:
            # If an assertion above failed, kill the child so Popen.__exit__ (which
            # waits for it) does not block until the child exits on its own.
            if proc.poll() is None:
                proc.kill()


def test_install_skills_success(client):
    # NOTE(review): Path.exists is not patched here; if /tmp/test-skills already exists
    # on the host, the endpoint may take the "already installed" branch exercised by
    # the next test. Consider a tmp_path-based location.
    with patch(
        "mlflow.server.assistant.api.install_skills", return_value=["skill1", "skill2"]
    ) as mock_install:
        response = client.post(
            "/ajax-api/3.0/mlflow/assistant/skills/install",
            json={"type": "custom", "custom_path": "/tmp/test-skills"},
        )

    assert response.status_code == 200
    data = response.json()
    assert data["installed_skills"] == ["skill1", "skill2"]
    expected_path = os.path.join(os.sep, "tmp", "test-skills")
    assert data["skills_directory"] == expected_path
    mock_install.assert_called_once_with(Path(expected_path))


def test_install_skills_skips_when_already_installed(client):
    with (
        patch("mlflow.server.assistant.api.Path.exists", return_value=True),
        patch(
            "mlflow.server.assistant.api.list_installed_skills",
            return_value=["existing_skill"],
        ) as mock_list,
        patch("mlflow.server.assistant.api.install_skills") as mock_install,
    ):
        response = client.post(
            "/ajax-api/3.0/mlflow/assistant/skills/install",
            json={"type": "custom", "custom_path": "/tmp/test-skills"},
        )

    assert response.status_code == 200
    data = response.json()
    assert data["installed_skills"] == ["existing_skill"]
    mock_install.assert_not_called()
    mock_list.assert_called_once()