/ tests / server / assistant / test_api.py
test_api.py
  1  import os
  2  import shutil
  3  import subprocess
  4  from pathlib import Path
  5  from typing import Any
  6  from unittest.mock import MagicMock, patch
  7  
  8  import pytest
  9  from fastapi import FastAPI, HTTPException
 10  from fastapi.testclient import TestClient
 11  
 12  from mlflow.assistant.config import AssistantConfig, ProjectConfig
 13  from mlflow.assistant.config import ProviderConfig as AssistantProviderConfig
 14  from mlflow.assistant.providers.base import (
 15      AssistantProvider,
 16      CLINotInstalledError,
 17      NotAuthenticatedError,
 18      ProviderConfig,
 19  )
 20  from mlflow.assistant.types import Event, Message
 21  from mlflow.server.assistant.api import _require_localhost, assistant_router
 22  from mlflow.server.assistant.session import SESSION_DIR, SessionManager, save_process_pid
 23  from mlflow.utils.os import is_windows
 24  
 25  
 26  class MockProvider(AssistantProvider):
 27      """Mock provider for testing."""
 28  
 29      @property
 30      def name(self) -> str:
 31          return "mock_provider"
 32  
 33      @property
 34      def display_name(self) -> str:
 35          return "Mock Provider"
 36  
 37      @property
 38      def description(self) -> str:
 39          return "Mock provider for testing"
 40  
 41      @property
 42      def config_path(self) -> Path:
 43          return Path.home() / ".mlflow" / "assistant" / "mock-config.json"
 44  
 45      def is_available(self) -> bool:
 46          return True
 47  
 48      def load_config(self) -> ProviderConfig:
 49          return ProviderConfig()
 50  
 51      def check_connection(self, echo=print) -> None:
 52          pass
 53  
 54      def resolve_skills_path(self, base_directory: Path) -> Path:
 55          return base_directory / ".mock" / "skills"
 56  
 57      async def astream(
 58          self,
 59          prompt: str,
 60          tracking_uri: str,
 61          session_id: str | None = None,
 62          cwd: Path | None = None,
 63          context: dict[str, Any] | None = None,
 64          mlflow_session_id: str | None = None,
 65      ):
 66          yield Event.from_message(message=Message(role="user", content="Hello from mock"))
 67          yield Event.from_result(result="complete", session_id="mock-session-123")
 68  
 69  
 70  @pytest.fixture(autouse=True)
 71  def isolated_config(tmp_path, monkeypatch):
 72      """Redirect config to tmp_path to avoid modifying real user config."""
 73      import mlflow.assistant.config as config_module
 74  
 75      config_home = tmp_path / ".mlflow" / "assistant"
 76      config_path = config_home / "config.json"
 77  
 78      monkeypatch.setattr(config_module, "MLFLOW_ASSISTANT_HOME", config_home)
 79      monkeypatch.setattr(config_module, "CONFIG_PATH", config_path)
 80  
 81      return config_home
 82  
 83  
 84  @pytest.fixture(autouse=True)
 85  def clear_sessions():
 86      """Clear session storage before each test."""
 87      if SESSION_DIR.exists():
 88          shutil.rmtree(SESSION_DIR)
 89      yield
 90      if SESSION_DIR.exists():
 91          shutil.rmtree(SESSION_DIR)
 92  
 93  
 94  @pytest.fixture
 95  def client():
 96      """Create test client with mock provider and bypassed localhost check."""
 97      app = FastAPI()
 98      app.include_router(assistant_router)
 99  
100      # Override localhost dependency to allow TestClient requests
101      async def mock_require_localhost():
102          pass
103  
104      app.dependency_overrides[_require_localhost] = mock_require_localhost
105  
106      with patch("mlflow.server.assistant.api._provider", MockProvider()):
107          yield TestClient(app)
108  
109  
def test_message(client):
    """Posting a message creates a session; posting again with its id reuses it."""
    endpoint = "/ajax-api/3.0/mlflow/assistant/message"

    first_payload = {
        "message": "Hello",
        "context": {"trace_id": "tr-123", "experiment_id": "exp-456"},
    }
    first = client.post(endpoint, json=first_payload)
    assert first.status_code == 200

    body = first.json()
    session_id = body["session_id"]
    assert session_id is not None
    # NOTE(review): this URL shape differs from the `/sessions/{id}/stream` route that the
    # stream tests in this file request -- confirm which one the API means to advertise.
    assert body["stream_url"] == f"/ajax-api/3.0/mlflow/assistant/stream/{session_id}"

    # Follow-up message in the same conversation
    follow_up = client.post(
        endpoint,
        json={"message": "Second message", "session_id": session_id},
    )
    assert follow_up.status_code == 200
    assert follow_up.json()["session_id"] == session_id
133  
134  
def test_stream_not_found_for_invalid_session(client):
    """Streaming an unknown session id yields 404 with a 'Session not found' detail."""
    stream_url = "/ajax-api/3.0/mlflow/assistant/sessions/invalid-session-id/stream"
    result = client.get(stream_url)

    assert result.status_code == 404
    detail = result.json()["detail"]
    assert "Session not found" in detail
139  
140  
def test_stream_bad_request_when_no_pending_message(client):
    """A second stream request without a new message is rejected with 400."""
    created = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = created.json()["session_id"]
    stream_url = f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream"

    # The first stream request consumes the pending message...
    client.get(stream_url)

    # ...so the next one has nothing left to stream.
    second = client.get(stream_url)
    assert second.status_code == 400
    assert "No pending message" in second.json()["detail"]
152  
153  
def test_stream_returns_sse_events(client):
    """Streaming a pending message returns an SSE body carrying the mock events."""
    created = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = created.json()["session_id"]

    streamed = client.get(f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}/stream")
    assert streamed.status_code == 200
    assert "text/event-stream" in streamed.headers["content-type"]

    body = streamed.text
    for expected_fragment in ("event: message", "event: done", "Hello from mock"):
        assert expected_fragment in body
167  
168  
def test_health_check_returns_ok_when_healthy(client):
    """The health endpoint answers 200 / ``{"status": "ok"}`` for the mock provider."""
    health_url = "/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health"
    result = client.get(health_url)

    assert result.status_code == 200
    assert result.json() == {"status": "ok"}
173  
174  
def test_health_check_returns_404_for_unknown_provider(client):
    """Asking for the health of an unregistered provider name yields 404."""
    health_url = "/ajax-api/3.0/mlflow/assistant/providers/unknown_provider/health"
    result = client.get(health_url)

    assert result.status_code == 404
    detail = result.json()["detail"]
    assert "not found" in detail
179  
180  
def test_health_check_returns_412_when_cli_not_installed():
    """A provider raising ``CLINotInstalledError`` maps to HTTP 412 on the health route."""

    class CLINotInstalledProvider(MockProvider):
        def check_connection(self, echo=None):
            raise CLINotInstalledError("CLI not installed")

    async def _skip_localhost_check():
        return None

    app = FastAPI()
    app.include_router(assistant_router)
    app.dependency_overrides[_require_localhost] = _skip_localhost_check

    with patch("mlflow.server.assistant.api._provider", CLINotInstalledProvider()):
        http = TestClient(app)
        result = http.get("/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health")

    assert result.status_code == 412
    assert "CLI not installed" in result.json()["detail"]
199  
200  
def test_health_check_returns_401_when_not_authenticated():
    """A provider raising ``NotAuthenticatedError`` maps to HTTP 401 on the health route."""

    class NotAuthenticatedProvider(MockProvider):
        def check_connection(self, echo=None):
            raise NotAuthenticatedError("Not authenticated")

    async def _skip_localhost_check():
        return None

    app = FastAPI()
    app.include_router(assistant_router)
    app.dependency_overrides[_require_localhost] = _skip_localhost_check

    with patch("mlflow.server.assistant.api._provider", NotAuthenticatedProvider()):
        http = TestClient(app)
        result = http.get("/ajax-api/3.0/mlflow/assistant/providers/mock_provider/health")

    assert result.status_code == 401
    assert "Not authenticated" in result.json()["detail"]
219  
220  
def test_get_config_returns_empty_config(client):
    """Without a saved config, both ``providers`` and ``projects`` come back empty."""
    result = client.get("/ajax-api/3.0/mlflow/assistant/config")
    assert result.status_code == 200

    body = result.json()
    for section in ("providers", "projects"):
        assert body[section] == {}
227  
228  
def test_get_config_returns_existing_config(client, tmp_path):
    """A config saved beforehand is returned unchanged by the GET endpoint."""
    project_dir = tmp_path / "project"
    project_dir.mkdir()

    # Persist a config first so the endpoint has something to load
    saved_providers = {"claude_code": AssistantProviderConfig(model="default", selected=True)}
    saved_projects = {"exp-123": ProjectConfig(type="local", location=str(project_dir))}
    AssistantConfig(providers=saved_providers, projects=saved_projects).save()

    result = client.get("/ajax-api/3.0/mlflow/assistant/config")
    assert result.status_code == 200

    body = result.json()
    provider_entry = body["providers"]["claude_code"]
    assert provider_entry["model"] == "default"
    assert provider_entry["selected"] is True
    assert body["projects"]["exp-123"]["location"] == str(project_dir)
246  
247  
def test_update_config_sets_provider(client):
    """PUTting a provider entry echoes it back as selected."""
    update = {"providers": {"claude_code": {"model": "opus", "selected": True}}}
    result = client.put("/ajax-api/3.0/mlflow/assistant/config", json=update)

    assert result.status_code == 200
    provider_entry = result.json()["providers"]["claude_code"]
    assert provider_entry["selected"] is True
256  
257  
def test_update_config_sets_project(client, tmp_path):
    """PUTting a local project entry echoes back the same location."""
    project_dir = tmp_path / "my_project"
    project_dir.mkdir()
    location = str(project_dir)

    update = {"projects": {"exp-456": {"type": "local", "location": location}}}
    result = client.put("/ajax-api/3.0/mlflow/assistant/config", json=update)

    assert result.status_code == 200
    assert result.json()["projects"]["exp-456"]["location"] == location
269  
270  
def test_update_config_expand_user_home(client, tmp_path):
    """A ``~``-prefixed location is stored as the path returned by ``expanduser``."""
    # Build a stand-in "home" tree under tmp_path for the expansion to point at
    project_dir = tmp_path.joinpath("home", "user", "my_project")
    project_dir.mkdir(parents=True)

    update = {"projects": {"exp-456": {"type": "local", "location": "~/my_project"}}}

    # Force expanduser (as seen from the API module) to resolve to the tmp directory
    with patch("mlflow.server.assistant.api.Path.expanduser", return_value=project_dir):
        result = client.put("/ajax-api/3.0/mlflow/assistant/config", json=update)

    assert result.status_code == 200
    stored_location = result.json()["projects"]["exp-456"]["location"]
    assert stored_location == str(project_dir)
288  
289  
@pytest.mark.asyncio
async def test_localhost_allows_ipv4():
    """A request whose client host is ``127.0.0.1`` passes the check without raising."""
    fake_request = MagicMock()
    fake_request.client = MagicMock(host="127.0.0.1")

    await _require_localhost(fake_request)
295  
296  
@pytest.mark.asyncio
async def test_localhost_allows_ipv6():
    """A request whose client host is ``::1`` passes the check without raising."""
    fake_request = MagicMock()
    fake_request.client = MagicMock(host="::1")

    await _require_localhost(fake_request)
302  
303  
@pytest.mark.asyncio
async def test_localhost_blocks_external_ip():
    """A non-loopback IP address is rejected with a 'same host' ``HTTPException``."""
    fake_request = MagicMock()
    fake_request.client = MagicMock(host="192.168.1.100")

    rejection = pytest.raises(HTTPException, match="same host")
    with rejection:
        await _require_localhost(fake_request)
311  
312  
@pytest.mark.asyncio
async def test_localhost_blocks_external_hostname():
    """A client host given as an external hostname is rejected with 'same host'."""
    fake_request = MagicMock()
    fake_request.client = MagicMock(host="external.example.com")

    rejection = pytest.raises(HTTPException, match="same host")
    with rejection:
        await _require_localhost(fake_request)
320  
321  
@pytest.mark.asyncio
async def test_localhost_blocks_when_no_client():
    """A request with ``client`` set to ``None`` is rejected with 'same host'."""
    fake_request = MagicMock(client=None)

    rejection = pytest.raises(HTTPException, match="same host")
    with rejection:
        await _require_localhost(fake_request)
329  
330  
def test_validate_session_id_accepts_valid_uuid():
    """A well-formed UUID string is accepted; the call must simply not raise."""
    SessionManager.validate_session_id("f5f28c66-5ec6-46a1-9a2e-ca55fb64bf47")
334  
335  
def test_validate_session_id_rejects_invalid_format():
    """A non-UUID string is refused with an 'Invalid session ID format' ``ValueError``."""
    malformed_id = "invalid-session-id"

    expected_failure = pytest.raises(ValueError, match="Invalid session ID format")
    with expected_failure:
        SessionManager.validate_session_id(malformed_id)
339  
340  
def test_validate_session_id_rejects_path_traversal():
    """A path-traversal string is refused with an 'Invalid session ID format' ``ValueError``."""
    traversal_attempt = "../../../etc/passwd"

    expected_failure = pytest.raises(ValueError, match="Invalid session ID format")
    with expected_failure:
        SessionManager.validate_session_id(traversal_attempt)
344  
345  
def _is_process_running(pid: int) -> bool:
    """Report whether a process with the given PID currently exists.

    Probes the PID with ``os.kill(pid, 0)``.

    Args:
        pid: Process ID to probe.

    Returns:
        ``True`` if the probe succeeds or is refused for lack of permission
        (the process exists but is owned by someone else); ``False`` for any
        other ``OSError`` or a ``ValueError``.
    """
    try:
        os.kill(pid, 0)
    except PermissionError:
        # Permission denied means the target exists; we just may not signal it.
        return True
    except (OSError, ValueError):  # ValueError is raised on Windows
        return False
    return True
352  
353  
def test_patch_session_cancel_with_process(client):
    """Cancelling a session terminates the process registered for it.

    Creates a session, registers a real long-running child process under it via
    ``save_process_pid``, PATCHes the session to ``cancelled``, and verifies the
    response mentions termination and that the child actually exits.
    """
    # Local import: only needed here, to locate the interpreter for the child process.
    import sys

    r = client.post("/ajax-api/3.0/mlflow/assistant/message", json={"message": "Hi"})
    session_id = r.json()["session_id"]

    # Spawn the current interpreter instead of the `sleep` binary so the child can be
    # started on any platform the test suite runs on.
    child_cmd = [sys.executable, "-c", "import time; time.sleep(10)"]

    # Start a real subprocess and register it with the session
    with subprocess.Popen(child_cmd) as proc:
        try:
            save_process_pid(session_id, proc.pid)

            assert _is_process_running(proc.pid)

            response = client.patch(
                f"/ajax-api/3.0/mlflow/assistant/sessions/{session_id}",
                json={"status": "cancelled"},
            )

            assert response.status_code == 200
            data = response.json()
            assert "terminated" in data["message"]

            # Wait for the process to actually terminate
            proc.wait(timeout=5)
            assert proc.returncode is not None
            # On non-Windows, verify the process is no longer running via PID check.
            # Skip on Windows because PIDs are reused more aggressively.
            if not is_windows():
                assert not _is_process_running(proc.pid)
        finally:
            # If anything above failed while the child is still alive, kill it so that
            # leaving the `with` block does not wait out the rest of the sleep.
            if proc.poll() is None:
                proc.kill()
380  
381  
def test_install_skills_success(client):
    """Installing into a custom path calls ``install_skills`` once and reports its result."""
    expected_path = os.path.join(os.sep, "tmp", "test-skills")
    request_body = {"type": "custom", "custom_path": "/tmp/test-skills"}

    installer_patch = patch(
        "mlflow.server.assistant.api.install_skills", return_value=["skill1", "skill2"]
    )
    with installer_patch as mock_install:
        result = client.post("/ajax-api/3.0/mlflow/assistant/skills/install", json=request_body)

    assert result.status_code == 200
    body = result.json()
    assert body["installed_skills"] == ["skill1", "skill2"]
    assert body["skills_directory"] == expected_path
    mock_install.assert_called_once_with(Path(expected_path))
397  
398  
def test_install_skills_skips_when_already_installed(client):
    """When the target path exists, existing skills are listed and nothing is installed."""
    request_body = {"type": "custom", "custom_path": "/tmp/test-skills"}

    # Patches are entered in the same order as a single combined `with` would enter them.
    with patch("mlflow.server.assistant.api.Path.exists", return_value=True):
        with patch(
            "mlflow.server.assistant.api.list_installed_skills",
            return_value=["existing_skill"],
        ) as mock_list:
            with patch("mlflow.server.assistant.api.install_skills") as mock_install:
                result = client.post(
                    "/ajax-api/3.0/mlflow/assistant/skills/install",
                    json=request_body,
                )

                assert result.status_code == 200
                body = result.json()
                assert body["installed_skills"] == ["existing_skill"]
                mock_install.assert_not_called()
                mock_list.assert_called_once()