/ tests / store / workspace / test_rest_store.py
test_rest_store.py
  1  from __future__ import annotations
  2  
  3  import json
  4  from types import SimpleNamespace
  5  from unittest import mock
  6  
  7  import pytest
  8  
  9  from mlflow.entities.workspace import Workspace, WorkspaceDeletionMode
 10  from mlflow.exceptions import MlflowException, RestException
 11  from mlflow.protos.service_pb2 import (
 12      CreateWorkspace,
 13      DeleteWorkspace,
 14      GetWorkspace,
 15      ListWorkspaces,
 16      UpdateWorkspace,
 17  )
 18  from mlflow.store.workspace.rest_store import WORKSPACES_ENDPOINT, RestWorkspaceStore
 19  
 20  
 21  @pytest.fixture
 22  def host_creds():
 23      return SimpleNamespace()
 24  
 25  
 26  @pytest.fixture
 27  def store(host_creds):
 28      return RestWorkspaceStore(lambda: host_creds)
 29  
 30  
 31  def test_list_workspaces_parses_response(store, host_creds):
 32      response = ListWorkspaces.Response()
 33      response.workspaces.add(name="default", description="Default workspace")
 34      response.workspaces.add(name="team-a", description="Team A")
 35      with mock.patch(
 36          "mlflow.store.workspace.rest_store.call_endpoint", return_value=response
 37      ) as call_endpoint:
 38          workspaces = store.list_workspaces()
 39  
 40      assert [ws.name for ws in workspaces] == ["default", "team-a"]
 41      call_endpoint.assert_called_once()
 42      kwargs = call_endpoint.call_args.kwargs
 43      assert kwargs["host_creds"] is host_creds
 44      assert kwargs["endpoint"] == WORKSPACES_ENDPOINT
 45      assert kwargs["method"] == "GET"
 46      assert kwargs["json_body"] is None
 47      assert kwargs.get("expected_status", 200) == 200
 48  
 49  
 50  def test_get_workspace_returns_entity(store, host_creds):
 51      response = GetWorkspace.Response()
 52      response.workspace.name = "team-b"
 53      response.workspace.description = "Team B"
 54      with mock.patch(
 55          "mlflow.store.workspace.rest_store.call_endpoint", return_value=response
 56      ) as call_endpoint:
 57          workspace = store.get_workspace("team-b")
 58  
 59      assert workspace.name == "team-b"
 60      assert workspace.description == "Team B"
 61      call_endpoint.assert_called_once()
 62      kwargs = call_endpoint.call_args.kwargs
 63      assert kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-b"
 64      assert kwargs["method"] == "GET"
 65  
 66  
 67  def test_create_workspace_sends_payload(store, host_creds):
 68      response = CreateWorkspace.Response()
 69      response.workspace.name = "team-c"
 70      response.workspace.description = "Team C"
 71      with mock.patch(
 72          "mlflow.store.workspace.rest_store.call_endpoint", return_value=response
 73      ) as call_endpoint:
 74          workspace = store.create_workspace(Workspace(name="team-c", description="Team C"))
 75  
 76      assert workspace.name == "team-c"
 77      assert workspace.description == "Team C"
 78      call_endpoint.assert_called_once()
 79      kwargs = call_endpoint.call_args.kwargs
 80      assert kwargs["endpoint"] == WORKSPACES_ENDPOINT
 81      assert kwargs["method"] == "POST"
 82      assert kwargs["expected_status"] == 201
 83      assert json.loads(kwargs["json_body"]) == {"name": "team-c", "description": "Team C"}
 84  
 85  
 86  def test_create_workspace_conflict_raises_resource_exists(store, monkeypatch):
 87      exc = RestException({"error_code": "RESOURCE_ALREADY_EXISTS", "message": "already exists"})
 88      monkeypatch.setattr(
 89          "mlflow.store.workspace.rest_store.call_endpoint",
 90          mock.Mock(side_effect=exc),
 91      )
 92  
 93      with pytest.raises(
 94          MlflowException,
 95          match="already exists",
 96      ) as exc_info:
 97          store.create_workspace(Workspace(name="team-a"))
 98  
 99      assert exc_info.value.error_code == "RESOURCE_ALREADY_EXISTS"
100      assert "already exists" in exc_info.value.message
101  
102  
def test_update_workspace_returns_new_description(store, host_creds):
    """``update_workspace`` PATCHes only the description and returns the server's copy."""
    patched = UpdateWorkspace.Response()
    patched.workspace.name = "team-e"
    patched.workspace.description = "updated"
    with mock.patch(
        "mlflow.store.workspace.rest_store.call_endpoint", return_value=patched
    ) as fake_call:
        updated = store.update_workspace(Workspace(name="team-e", description="updated"))

    assert updated.description == "updated"
    fake_call.assert_called_once()
    sent = fake_call.call_args.kwargs
    assert sent["endpoint"] == WORKSPACES_ENDPOINT + "/team-e"
    assert sent["method"] == "PATCH"
    assert json.loads(sent["json_body"]) == dict(description="updated")
118  
119  
def test_delete_workspace_returns_on_success(store, host_creds):
    """Without an explicit mode, ``delete_workspace`` sends a body-less DELETE expecting 204."""
    endpoint_target = "mlflow.store.workspace.rest_store.call_endpoint"
    with mock.patch(endpoint_target, return_value=DeleteWorkspace.Response()) as fake_call:
        store.delete_workspace("team-f")

    fake_call.assert_called_once()
    sent = fake_call.call_args.kwargs
    expected = {
        "endpoint": f"{WORKSPACES_ENDPOINT}/team-f",
        "method": "DELETE",
        "expected_status": 204,
    }
    for key, value in expected.items():
        assert sent[key] == value, key
    assert sent["json_body"] is None
133  
134  
@pytest.mark.parametrize(
    ("mode", "expected_suffix"),
    [
        (WorkspaceDeletionMode.RESTRICT, ""),
        (WorkspaceDeletionMode.CASCADE, "?mode=CASCADE"),
        (WorkspaceDeletionMode.SET_DEFAULT, "?mode=SET_DEFAULT"),
    ],
)
def test_delete_workspace_sends_mode_query_param(store, host_creds, mode, expected_suffix):
    """The deletion mode is encoded as a ``mode`` query parameter; RESTRICT adds none."""
    response = DeleteWorkspace.Response()
    with mock.patch(
        "mlflow.store.workspace.rest_store.call_endpoint", return_value=response
    ) as call_endpoint:
        store.delete_workspace("team-f", mode=mode)

    # Like the default-mode delete test: exactly one DELETE request per call,
    # whatever the mode. Only the endpoint suffix should vary.
    call_endpoint.assert_called_once()
    kwargs = call_endpoint.call_args.kwargs
    assert kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-f{expected_suffix}"
    assert kwargs["method"] == "DELETE"
152  
153  
def test_get_default_workspace_not_supported(store):
    """The REST store cannot resolve a default workspace and says so explicitly."""
    expected_message = "REST workspace provider does not expose a default workspace"
    with pytest.raises(NotImplementedError, match=expected_message):
        store.get_default_workspace()
160  
161  
def test_rest_store_validates_workspace_names_before_http(monkeypatch, store):
    """An invalid workspace name is rejected client-side, so no HTTP request is attempted."""
    http_spy = mock.Mock()
    monkeypatch.setattr("mlflow.store.workspace.rest_store.call_endpoint", http_spy)

    with pytest.raises(MlflowException, match="must match the pattern"):
        store.get_workspace("Invalid")

    assert not http_spy.called