test_workspace_endpoints.py
1 from __future__ import annotations 2 3 import json 4 from unittest import mock 5 6 import pytest 7 from flask import Flask 8 9 from mlflow.entities.workspace import Workspace, WorkspaceDeletionMode 10 from mlflow.server.handlers import get_endpoints 11 12 13 @pytest.fixture(autouse=True) 14 def enable_workspaces(monkeypatch): 15 monkeypatch.setenv("MLFLOW_ENABLE_WORKSPACES", "true") 16 17 18 @pytest.fixture 19 def app(monkeypatch): 20 flask_app = Flask(__name__) 21 for rule, view_func, methods in get_endpoints(): 22 flask_app.add_url_rule(rule, view_func=view_func, methods=methods) 23 return flask_app 24 25 26 @pytest.fixture 27 def mock_workspace_store(monkeypatch): 28 store = mock.Mock() 29 monkeypatch.setattr( 30 "mlflow.server.handlers._get_workspace_store", 31 lambda *_, **__: store, 32 ) 33 return store 34 35 36 @pytest.fixture 37 def mock_tracking_store(monkeypatch): 38 store = mock.Mock() 39 store.artifact_root_uri = "/default/artifact/root" 40 monkeypatch.setattr( 41 "mlflow.server.handlers._get_tracking_store", 42 lambda *_, **__: store, 43 ) 44 return store 45 46 47 def _workspace_to_json(payload): 48 return json.loads(payload) 49 50 51 def test_list_workspaces_endpoint(app, mock_workspace_store): 52 mock_workspace_store.list_workspaces.return_value = [ 53 Workspace(name="default", description="Default"), 54 Workspace(name="team-a", description=None), 55 ] 56 with app.test_client() as client: 57 response = client.get("/api/3.0/mlflow/workspaces") 58 59 assert response.status_code == 200 60 payload = _workspace_to_json(response.get_data(True)) 61 assert payload["workspaces"][0] == {"name": "default", "description": "Default"} 62 assert payload["workspaces"][1] == {"name": "team-a"} 63 mock_workspace_store.list_workspaces.assert_called_once_with() 64 65 66 def test_create_workspace_endpoint(app, mock_workspace_store, mock_tracking_store): 67 created = Workspace(name="team-b", description="Team B") 68 mock_workspace_store.create_workspace.return_value = created 69 with app.test_client() as client: 70 response = client.post( 71 "/api/3.0/mlflow/workspaces", 72 json={"name": "team-b", "description": "Team B"}, 73 ) 74 75 assert response.status_code == 201 76 payload = _workspace_to_json(response.get_data(True)) 77 assert payload == {"workspace": {"name": "team-b", "description": "Team B"}} 78 mock_workspace_store.create_workspace.assert_called_once() 79 mock_tracking_store.get_experiment_by_name.assert_not_called() 80 mock_tracking_store.create_experiment.assert_not_called() 81 82 83 def test_get_workspace_endpoint(app, mock_workspace_store): 84 mock_workspace_store.get_workspace.return_value = Workspace(name="team-c", description="Team C") 85 with app.test_client() as client: 86 response = client.get("/api/3.0/mlflow/workspaces/team-c") 87 88 assert response.status_code == 200 89 payload = _workspace_to_json(response.get_data(True)) 90 assert payload == {"workspace": {"name": "team-c", "description": "Team C"}} 91 mock_workspace_store.get_workspace.assert_called_once_with("team-c") 92 93 94 def test_update_workspace_endpoint(app, mock_workspace_store): 95 updated = Workspace(name="team-d", description="Updated") 96 mock_workspace_store.update_workspace.return_value = updated 97 with app.test_client() as client: 98 response = client.patch( 99 "/api/3.0/mlflow/workspaces/team-d", 100 json={"description": "Updated"}, 101 ) 102 103 assert response.status_code == 200 104 payload = _workspace_to_json(response.get_data(True)) 105 assert payload == {"workspace": {"name": "team-d", "description": "Updated"}} 106 mock_workspace_store.update_workspace.assert_called_once() 107 108 109 def test_update_default_workspace_allows_reserved_name(app, mock_workspace_store): 110 updated = Workspace(name="default", default_artifact_root="s3://bucket/root") 111 mock_workspace_store.update_workspace.return_value = updated 112 113 with app.test_client() as client: 114 response = client.patch( 115 "/api/3.0/mlflow/workspaces/default", 116 json={"default_artifact_root": "s3://bucket/root"}, 117 ) 118 119 assert response.status_code == 200 120 payload = _workspace_to_json(response.get_data(True)) 121 assert payload == { 122 "workspace": {"name": "default", "default_artifact_root": "s3://bucket/root"} 123 } 124 args, _ = mock_workspace_store.update_workspace.call_args 125 assert args[0].name == "default" 126 assert args[0].default_artifact_root == "s3://bucket/root" 127 128 129 def test_update_workspace_can_clear_default_artifact_root( 130 app, mock_workspace_store, mock_tracking_store 131 ): 132 cleared = Workspace(name="team-clear", description=None, default_artifact_root=None) 133 mock_workspace_store.update_workspace.return_value = cleared 134 with app.test_client() as client: 135 response = client.patch( 136 "/api/3.0/mlflow/workspaces/team-clear", 137 json={"default_artifact_root": " "}, 138 ) 139 140 assert response.status_code == 200 141 payload = _workspace_to_json(response.get_data(True)) 142 assert payload == {"workspace": {"name": "team-clear"}} 143 args, _ = mock_workspace_store.update_workspace.call_args 144 assert isinstance(args[0], Workspace) 145 assert args[0].name == "team-clear" 146 # Handler passes "" to indicate "clear"; the store converts "" to None 147 assert args[0].default_artifact_root == "" 148 149 150 def test_delete_workspace_endpoint(app, mock_workspace_store): 151 with app.test_client() as client: 152 response = client.delete("/api/3.0/mlflow/workspaces/team-e") 153 154 assert response.status_code == 204 155 mock_workspace_store.delete_workspace.assert_called_once_with( 156 "team-e", mode=WorkspaceDeletionMode.RESTRICT 157 ) 158 159 160 def test_delete_default_workspace_rejected_by_validation(app, mock_workspace_store): 161 with app.test_client() as client: 162 response = client.delete("/api/3.0/mlflow/workspaces/default") 163 164 assert response.status_code == 400 165 payload = _workspace_to_json(response.get_data(True)) 166 assert "cannot be deleted" in payload["message"] 167 mock_workspace_store.delete_workspace.assert_not_called() 168 169 170 def test_create_workspace_fails_without_artifact_root(app, mock_workspace_store, monkeypatch): 171 tracking_store = mock.Mock() 172 tracking_store.artifact_root_uri = None 173 monkeypatch.setattr( 174 "mlflow.server.handlers._get_tracking_store", 175 lambda *_, **__: tracking_store, 176 ) 177 with app.test_client() as client: 178 response = client.post( 179 "/api/3.0/mlflow/workspaces", 180 json={"name": "team-no-root"}, 181 ) 182 183 assert response.status_code == 400 184 payload = _workspace_to_json(response.get_data(True)) 185 assert "artifact root" in payload["message"].lower() 186 187 188 def test_create_workspace_with_artifact_root_succeeds_without_server_default( 189 app, mock_workspace_store, monkeypatch 190 ): 191 tracking_store = mock.Mock() 192 tracking_store.artifact_root_uri = None 193 monkeypatch.setattr( 194 "mlflow.server.handlers._get_tracking_store", 195 lambda *_, **__: tracking_store, 196 ) 197 created = Workspace(name="team-with-root", default_artifact_root="s3://bucket/path") 198 mock_workspace_store.create_workspace.return_value = created 199 with app.test_client() as client: 200 response = client.post( 201 "/api/3.0/mlflow/workspaces", 202 json={"name": "team-with-root", "default_artifact_root": "s3://bucket/path"}, 203 ) 204 205 assert response.status_code == 201 206 207 208 def test_create_default_workspace_rejected(app, mock_workspace_store, mock_tracking_store): 209 with app.test_client() as client: 210 response = client.post( 211 "/api/3.0/mlflow/workspaces", 212 json={"name": "default"}, 213 ) 214 215 assert response.status_code == 400 216 payload = _workspace_to_json(response.get_data(True)) 217 assert "reserved" in payload["message"] 218 mock_workspace_store.create_workspace.assert_not_called() 219 220 221 def test_update_workspace_clear_artifact_root_fails_without_server_default( 222 app, mock_workspace_store, monkeypatch 223 ): 224 tracking_store = mock.Mock() 225 tracking_store.artifact_root_uri = None 226 monkeypatch.setattr( 227 "mlflow.server.handlers._get_tracking_store", 228 lambda *_, **__: tracking_store, 229 ) 230 with app.test_client() as client: 231 response = client.patch( 232 "/api/3.0/mlflow/workspaces/team-clear", 233 json={"default_artifact_root": ""}, 234 ) 235 236 assert response.status_code == 400 237 payload = _workspace_to_json(response.get_data(True)) 238 assert "artifact root" in payload["message"].lower()