/ tests / server / test_workspace_endpoints.py
test_workspace_endpoints.py
  1  from __future__ import annotations
  2  
  3  import json
  4  from unittest import mock
  5  
  6  import pytest
  7  from flask import Flask
  8  
  9  from mlflow.entities.workspace import Workspace, WorkspaceDeletionMode
 10  from mlflow.server.handlers import get_endpoints
 11  
 12  
 13  @pytest.fixture(autouse=True)
 14  def enable_workspaces(monkeypatch):
 15      monkeypatch.setenv("MLFLOW_ENABLE_WORKSPACES", "true")
 16  
 17  
 18  @pytest.fixture
 19  def app(monkeypatch):
 20      flask_app = Flask(__name__)
 21      for rule, view_func, methods in get_endpoints():
 22          flask_app.add_url_rule(rule, view_func=view_func, methods=methods)
 23      return flask_app
 24  
 25  
 26  @pytest.fixture
 27  def mock_workspace_store(monkeypatch):
 28      store = mock.Mock()
 29      monkeypatch.setattr(
 30          "mlflow.server.handlers._get_workspace_store",
 31          lambda *_, **__: store,
 32      )
 33      return store
 34  
 35  
 36  @pytest.fixture
 37  def mock_tracking_store(monkeypatch):
 38      store = mock.Mock()
 39      store.artifact_root_uri = "/default/artifact/root"
 40      monkeypatch.setattr(
 41          "mlflow.server.handlers._get_tracking_store",
 42          lambda *_, **__: store,
 43      )
 44      return store
 45  
 46  
 47  def _workspace_to_json(payload):
 48      return json.loads(payload)
 49  
 50  
 51  def test_list_workspaces_endpoint(app, mock_workspace_store):
 52      mock_workspace_store.list_workspaces.return_value = [
 53          Workspace(name="default", description="Default"),
 54          Workspace(name="team-a", description=None),
 55      ]
 56      with app.test_client() as client:
 57          response = client.get("/api/3.0/mlflow/workspaces")
 58  
 59      assert response.status_code == 200
 60      payload = _workspace_to_json(response.get_data(True))
 61      assert payload["workspaces"][0] == {"name": "default", "description": "Default"}
 62      assert payload["workspaces"][1] == {"name": "team-a"}
 63      mock_workspace_store.list_workspaces.assert_called_once_with()
 64  
 65  
 66  def test_create_workspace_endpoint(app, mock_workspace_store, mock_tracking_store):
 67      created = Workspace(name="team-b", description="Team B")
 68      mock_workspace_store.create_workspace.return_value = created
 69      with app.test_client() as client:
 70          response = client.post(
 71              "/api/3.0/mlflow/workspaces",
 72              json={"name": "team-b", "description": "Team B"},
 73          )
 74  
 75      assert response.status_code == 201
 76      payload = _workspace_to_json(response.get_data(True))
 77      assert payload == {"workspace": {"name": "team-b", "description": "Team B"}}
 78      mock_workspace_store.create_workspace.assert_called_once()
 79      mock_tracking_store.get_experiment_by_name.assert_not_called()
 80      mock_tracking_store.create_experiment.assert_not_called()
 81  
 82  
 83  def test_get_workspace_endpoint(app, mock_workspace_store):
 84      mock_workspace_store.get_workspace.return_value = Workspace(name="team-c", description="Team C")
 85      with app.test_client() as client:
 86          response = client.get("/api/3.0/mlflow/workspaces/team-c")
 87  
 88      assert response.status_code == 200
 89      payload = _workspace_to_json(response.get_data(True))
 90      assert payload == {"workspace": {"name": "team-c", "description": "Team C"}}
 91      mock_workspace_store.get_workspace.assert_called_once_with("team-c")
 92  
 93  
 94  def test_update_workspace_endpoint(app, mock_workspace_store):
 95      updated = Workspace(name="team-d", description="Updated")
 96      mock_workspace_store.update_workspace.return_value = updated
 97      with app.test_client() as client:
 98          response = client.patch(
 99              "/api/3.0/mlflow/workspaces/team-d",
100              json={"description": "Updated"},
101          )
102  
103      assert response.status_code == 200
104      payload = _workspace_to_json(response.get_data(True))
105      assert payload == {"workspace": {"name": "team-d", "description": "Updated"}}
106      mock_workspace_store.update_workspace.assert_called_once()
107  
108  
def test_update_default_workspace_allows_reserved_name(app, mock_workspace_store):
    """PATCH on the workspace named ``default`` is accepted and reaches the store."""
    mock_workspace_store.update_workspace.return_value = Workspace(
        name="default", default_artifact_root="s3://bucket/root"
    )
    changes = {"default_artifact_root": "s3://bucket/root"}

    with app.test_client() as http:
        resp = http.patch("/api/3.0/mlflow/workspaces/default", json=changes)

    assert resp.status_code == 200
    body = _workspace_to_json(resp.get_data(as_text=True))
    assert body == {"workspace": {"name": "default", "default_artifact_root": "s3://bucket/root"}}

    # The first positional argument handed to the store carries the requested update.
    forwarded = mock_workspace_store.update_workspace.call_args.args[0]
    assert forwarded.name == "default"
    assert forwarded.default_artifact_root == "s3://bucket/root"
127  
128  
def test_update_workspace_can_clear_default_artifact_root(
    app, mock_workspace_store, mock_tracking_store
):
    """PATCH with a whitespace-only artifact root is forwarded to the store as ``""``."""
    mock_workspace_store.update_workspace.return_value = Workspace(
        name="team-clear", description=None, default_artifact_root=None
    )
    changes = {"default_artifact_root": " "}

    with app.test_client() as http:
        resp = http.patch("/api/3.0/mlflow/workspaces/team-clear", json=changes)

    assert resp.status_code == 200
    body = _workspace_to_json(resp.get_data(as_text=True))
    assert body == {"workspace": {"name": "team-clear"}}

    forwarded = mock_workspace_store.update_workspace.call_args.args[0]
    assert isinstance(forwarded, Workspace)
    assert forwarded.name == "team-clear"
    # Handler passes "" to indicate "clear"; the store converts "" to None
    assert forwarded.default_artifact_root == ""
148  
149  
def test_delete_workspace_endpoint(app, mock_workspace_store):
    """DELETE answers 204 and asks the store to delete the workspace in RESTRICT mode."""
    with app.test_client() as http:
        resp = http.delete("/api/3.0/mlflow/workspaces/team-e")

    assert resp.status_code == 204
    delete_call = mock_workspace_store.delete_workspace
    delete_call.assert_called_once_with("team-e", mode=WorkspaceDeletionMode.RESTRICT)
158  
159  
def test_delete_default_workspace_rejected_by_validation(app, mock_workspace_store):
    """DELETE on the ``default`` workspace is a 400 and never reaches the store."""
    with app.test_client() as http:
        resp = http.delete("/api/3.0/mlflow/workspaces/default")

    assert resp.status_code == 400
    error_message = _workspace_to_json(resp.get_data(as_text=True))["message"]
    assert "cannot be deleted" in error_message
    mock_workspace_store.delete_workspace.assert_not_called()
168  
169  
def test_create_workspace_fails_without_artifact_root(app, mock_workspace_store, monkeypatch):
    """POST without an artifact root is a 400 when the tracking store has none either."""
    # Tracking store stand-in whose server-level artifact root is unset.
    rootless_store = mock.Mock(artifact_root_uri=None)
    monkeypatch.setattr(
        "mlflow.server.handlers._get_tracking_store",
        lambda *_args, **_kwargs: rootless_store,
    )

    with app.test_client() as http:
        resp = http.post("/api/3.0/mlflow/workspaces", json={"name": "team-no-root"})

    assert resp.status_code == 400
    error_message = _workspace_to_json(resp.get_data(as_text=True))["message"]
    assert "artifact root" in error_message.lower()
186  
187  
def test_create_workspace_with_artifact_root_succeeds_without_server_default(
    app, mock_workspace_store, monkeypatch
):
    """POST with an explicit artifact root succeeds when the server has no default.

    The tracking store is replaced by a mock whose ``artifact_root_uri`` is ``None``,
    so the only artifact root available is the one supplied in the request body.
    """
    tracking_store = mock.Mock()
    tracking_store.artifact_root_uri = None
    monkeypatch.setattr(
        "mlflow.server.handlers._get_tracking_store",
        lambda *_, **__: tracking_store,
    )
    created = Workspace(name="team-with-root", default_artifact_root="s3://bucket/path")
    mock_workspace_store.create_workspace.return_value = created
    with app.test_client() as client:
        response = client.post(
            "/api/3.0/mlflow/workspaces",
            json={"name": "team-with-root", "default_artifact_root": "s3://bucket/path"},
        )

    assert response.status_code == 201
    # A bare 201 would also pass if the handler skipped the store or dropped the
    # artifact root from the response, so pin the body and the store call too.
    payload = _workspace_to_json(response.get_data(True))
    assert payload == {
        "workspace": {"name": "team-with-root", "default_artifact_root": "s3://bucket/path"}
    }
    mock_workspace_store.create_workspace.assert_called_once()
206  
207  
def test_create_default_workspace_rejected(app, mock_workspace_store, mock_tracking_store):
    """POST with the name ``default`` is a 400 and the store is never asked to create."""
    with app.test_client() as http:
        resp = http.post("/api/3.0/mlflow/workspaces", json={"name": "default"})

    assert resp.status_code == 400
    error_message = _workspace_to_json(resp.get_data(as_text=True))["message"]
    assert "reserved" in error_message
    mock_workspace_store.create_workspace.assert_not_called()
219  
220  
def test_update_workspace_clear_artifact_root_fails_without_server_default(
    app, mock_workspace_store, monkeypatch
):
    """PATCH clearing the artifact root is a 400 when the tracking store has no root."""
    # Tracking store stand-in whose server-level artifact root is unset.
    rootless_store = mock.Mock(artifact_root_uri=None)
    monkeypatch.setattr(
        "mlflow.server.handlers._get_tracking_store",
        lambda *_args, **_kwargs: rootless_store,
    )
    changes = {"default_artifact_root": ""}

    with app.test_client() as http:
        resp = http.patch("/api/3.0/mlflow/workspaces/team-clear", json=changes)

    assert resp.status_code == 400
    error_message = _workspace_to_json(resp.get_data(as_text=True))["message"]
    assert "artifact root" in error_message.lower()