# test_rest_store.py
1 from __future__ import annotations 2 3 import json 4 from types import SimpleNamespace 5 from unittest import mock 6 7 import pytest 8 9 from mlflow.entities.workspace import Workspace, WorkspaceDeletionMode 10 from mlflow.exceptions import MlflowException, RestException 11 from mlflow.protos.service_pb2 import ( 12 CreateWorkspace, 13 DeleteWorkspace, 14 GetWorkspace, 15 ListWorkspaces, 16 UpdateWorkspace, 17 ) 18 from mlflow.store.workspace.rest_store import WORKSPACES_ENDPOINT, RestWorkspaceStore 19 20 21 @pytest.fixture 22 def host_creds(): 23 return SimpleNamespace() 24 25 26 @pytest.fixture 27 def store(host_creds): 28 return RestWorkspaceStore(lambda: host_creds) 29 30 31 def test_list_workspaces_parses_response(store, host_creds): 32 response = ListWorkspaces.Response() 33 response.workspaces.add(name="default", description="Default workspace") 34 response.workspaces.add(name="team-a", description="Team A") 35 with mock.patch( 36 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 37 ) as call_endpoint: 38 workspaces = store.list_workspaces() 39 40 assert [ws.name for ws in workspaces] == ["default", "team-a"] 41 call_endpoint.assert_called_once() 42 kwargs = call_endpoint.call_args.kwargs 43 assert kwargs["host_creds"] is host_creds 44 assert kwargs["endpoint"] == WORKSPACES_ENDPOINT 45 assert kwargs["method"] == "GET" 46 assert kwargs["json_body"] is None 47 assert kwargs.get("expected_status", 200) == 200 48 49 50 def test_get_workspace_returns_entity(store, host_creds): 51 response = GetWorkspace.Response() 52 response.workspace.name = "team-b" 53 response.workspace.description = "Team B" 54 with mock.patch( 55 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 56 ) as call_endpoint: 57 workspace = store.get_workspace("team-b") 58 59 assert workspace.name == "team-b" 60 assert workspace.description == "Team B" 61 call_endpoint.assert_called_once() 62 kwargs = call_endpoint.call_args.kwargs 63 assert 
kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-b" 64 assert kwargs["method"] == "GET" 65 66 67 def test_create_workspace_sends_payload(store, host_creds): 68 response = CreateWorkspace.Response() 69 response.workspace.name = "team-c" 70 response.workspace.description = "Team C" 71 with mock.patch( 72 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 73 ) as call_endpoint: 74 workspace = store.create_workspace(Workspace(name="team-c", description="Team C")) 75 76 assert workspace.name == "team-c" 77 assert workspace.description == "Team C" 78 call_endpoint.assert_called_once() 79 kwargs = call_endpoint.call_args.kwargs 80 assert kwargs["endpoint"] == WORKSPACES_ENDPOINT 81 assert kwargs["method"] == "POST" 82 assert kwargs["expected_status"] == 201 83 assert json.loads(kwargs["json_body"]) == {"name": "team-c", "description": "Team C"} 84 85 86 def test_create_workspace_conflict_raises_resource_exists(store, monkeypatch): 87 exc = RestException({"error_code": "RESOURCE_ALREADY_EXISTS", "message": "already exists"}) 88 monkeypatch.setattr( 89 "mlflow.store.workspace.rest_store.call_endpoint", 90 mock.Mock(side_effect=exc), 91 ) 92 93 with pytest.raises( 94 MlflowException, 95 match="already exists", 96 ) as exc_info: 97 store.create_workspace(Workspace(name="team-a")) 98 99 assert exc_info.value.error_code == "RESOURCE_ALREADY_EXISTS" 100 assert "already exists" in exc_info.value.message 101 102 103 def test_update_workspace_returns_new_description(store, host_creds): 104 response = UpdateWorkspace.Response() 105 response.workspace.name = "team-e" 106 response.workspace.description = "updated" 107 with mock.patch( 108 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 109 ) as call_endpoint: 110 workspace = store.update_workspace(Workspace(name="team-e", description="updated")) 111 112 assert workspace.description == "updated" 113 call_endpoint.assert_called_once() 114 kwargs = call_endpoint.call_args.kwargs 115 assert 
kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-e" 116 assert kwargs["method"] == "PATCH" 117 assert json.loads(kwargs["json_body"]) == {"description": "updated"} 118 119 120 def test_delete_workspace_returns_on_success(store, host_creds): 121 response = DeleteWorkspace.Response() 122 with mock.patch( 123 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 124 ) as call_endpoint: 125 store.delete_workspace("team-f") 126 127 call_endpoint.assert_called_once() 128 kwargs = call_endpoint.call_args.kwargs 129 assert kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-f" 130 assert kwargs["method"] == "DELETE" 131 assert kwargs["expected_status"] == 204 132 assert kwargs["json_body"] is None 133 134 135 @pytest.mark.parametrize( 136 ("mode", "expected_suffix"), 137 [ 138 (WorkspaceDeletionMode.RESTRICT, ""), 139 (WorkspaceDeletionMode.CASCADE, "?mode=CASCADE"), 140 (WorkspaceDeletionMode.SET_DEFAULT, "?mode=SET_DEFAULT"), 141 ], 142 ) 143 def test_delete_workspace_sends_mode_query_param(store, host_creds, mode, expected_suffix): 144 response = DeleteWorkspace.Response() 145 with mock.patch( 146 "mlflow.store.workspace.rest_store.call_endpoint", return_value=response 147 ) as call_endpoint: 148 store.delete_workspace("team-f", mode=mode) 149 150 kwargs = call_endpoint.call_args.kwargs 151 assert kwargs["endpoint"] == f"{WORKSPACES_ENDPOINT}/team-f{expected_suffix}" 152 153 154 def test_get_default_workspace_not_supported(store): 155 with pytest.raises( 156 NotImplementedError, 157 match="REST workspace provider does not expose a default workspace", 158 ): 159 store.get_default_workspace() 160 161 162 def test_rest_store_validates_workspace_names_before_http(monkeypatch, store): 163 mock_call = mock.Mock() 164 monkeypatch.setattr("mlflow.store.workspace.rest_store.call_endpoint", mock_call) 165 166 with pytest.raises(MlflowException, match="must match the pattern"): 167 store.get_workspace("Invalid") 168 169 mock_call.assert_not_called()