/ tests / store / model_registry / test_sqlalchemy_workspace_store.py
test_sqlalchemy_workspace_store.py
  1  import shutil
  2  import uuid
  3  
  4  import pytest
  5  
  6  from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
  7  from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent
  8  from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
  9  from mlflow.exceptions import MlflowException
 10  from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
 11  from mlflow.store.model_registry.sqlalchemy_workspace_store import WorkspaceAwareSqlAlchemyStore
 12  from mlflow.utils.workspace_context import WorkspaceContext, clear_server_request_workspace
 13  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
 14  
 15  
 16  @pytest.fixture
 17  def workspace_registry_store(db_uri, monkeypatch):
 18      monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
 19      store = WorkspaceAwareSqlAlchemyStore(db_uri)
 20      try:
 21          yield store
 22      finally:
 23          store.engine.dispose()
 24  
 25  
 26  def _names_from_search(results):
 27      return {rm.name for rm in results}
 28  
 29  
 30  def test_registered_model_operations_are_workspace_scoped(workspace_registry_store):
 31      with WorkspaceContext("team-a"):
 32          workspace_registry_store.create_registered_model("alpha")
 33          workspace_registry_store.set_registered_model_tag(
 34              "alpha", RegisteredModelTag("owner", "team-a")
 35          )
 36          rm = workspace_registry_store.get_registered_model("alpha")
 37          assert rm.tags == {"owner": "team-a"}
 38  
 39      with WorkspaceContext("team-b"):
 40          workspace_registry_store.create_registered_model("beta")
 41          with pytest.raises(
 42              MlflowException, match="Registered Model with name=alpha not found"
 43          ) as excinfo:
 44              workspace_registry_store.set_registered_model_tag(
 45                  "alpha", RegisteredModelTag("owner", "team-b")
 46              )
 47          assert excinfo.value.error_code == "RESOURCE_DOES_NOT_EXIST"
 48          with pytest.raises(MlflowException, match="Registered Model with name=alpha not found"):
 49              workspace_registry_store.rename_registered_model("alpha", "alpha-b")
 50          with pytest.raises(MlflowException, match="Registered Model with name=alpha not found"):
 51              workspace_registry_store.delete_registered_model("alpha")
 52  
 53      with WorkspaceContext("team-b"):
 54          names = _names_from_search(workspace_registry_store.search_registered_models())
 55          assert names == {"beta"}
 56  
 57      with WorkspaceContext("team-b"):
 58          with pytest.raises(
 59              MlflowException, match="Registered Model with name=alpha-renamed not found"
 60          ) as excinfo:
 61              workspace_registry_store.set_registered_model_tag(
 62                  "alpha-renamed", RegisteredModelTag("owner", "team-b")
 63              )
 64          assert excinfo.value.error_code == "RESOURCE_DOES_NOT_EXIST"
 65          with pytest.raises(
 66              MlflowException, match="Registered Model with name=alpha-renamed not found"
 67          ):
 68              workspace_registry_store.get_registered_model("alpha-renamed")
 69          # Ensure team-b model remains accessible
 70          beta = workspace_registry_store.get_registered_model("beta")
 71          assert beta.name == "beta"
 72  
 73      with WorkspaceContext("team-a"):
 74          workspace_registry_store.rename_registered_model("alpha", "alpha-renamed")
 75          renamed = workspace_registry_store.get_registered_model("alpha-renamed")
 76          assert renamed.name == "alpha-renamed"
 77          assert renamed.tags == {"owner": "team-a"}
 78  
 79      with WorkspaceContext("team-b"):
 80          with pytest.raises(
 81              MlflowException, match="Registered Model with name=alpha-renamed not found"
 82          ) as excinfo:
 83              workspace_registry_store.set_registered_model_tag(
 84                  "alpha-renamed", RegisteredModelTag("owner", "team-b")
 85              )
 86          assert excinfo.value.error_code == "RESOURCE_DOES_NOT_EXIST"
 87          with pytest.raises(
 88              MlflowException, match="Registered Model with name=alpha-renamed not found"
 89          ):
 90              workspace_registry_store.get_registered_model("alpha-renamed")
 91  
 92      with WorkspaceContext("team-a"):
 93          workspace_registry_store.delete_registered_model("alpha-renamed")
 94          with pytest.raises(
 95              MlflowException, match="Registered Model with name=alpha-renamed not found"
 96          ) as excinfo:
 97              workspace_registry_store.get_registered_model("alpha-renamed")
 98          assert excinfo.value.error_code == "RESOURCE_DOES_NOT_EXIST"
 99  
100  
def test_model_version_operations_are_workspace_scoped(workspace_registry_store):
    """Model-version and alias mutations on team-a's model fail from team-b's workspace."""
    store = workspace_registry_store
    not_found = "RESOURCE_DOES_NOT_EXIST"
    version_missing = r"Model Version \(name=alpha, version=1\) not found"

    # team-a registers "alpha" v1, then tags, stages and aliases it.
    with WorkspaceContext("team-a"):
        store.create_registered_model("alpha")
        created = store.create_model_version("alpha", "s3://team-a/model", run_id=uuid.uuid4().hex)
        assert created.version == 1
        version = str(created.version)
        store.set_model_version_tag("alpha", version, ModelVersionTag("env", "prod"))
        store.transition_model_version_stage(
            "alpha", version, "Production", archive_existing_versions=False
        )
        store.set_registered_model_alias("alpha", "production", version)
        detail = store.get_model_version("alpha", "1")
        assert detail.current_stage == "Production"
        assert detail.tags == {"env": "prod"}
        alias_map = store.get_registered_model("alpha").aliases
        assert alias_map == {"production": 1}
        uri = store.get_model_version_download_uri("alpha", "1")
        assert uri == "s3://team-a/model"

    with WorkspaceContext("team-b"):
        store.create_registered_model("beta")
        with pytest.raises(
            MlflowException, match="Registered Model with name=alpha not found"
        ) as create_error:
            store.create_model_version("alpha", "s3://team-b/model", run_id=uuid.uuid4().hex)
        assert create_error.value.error_code == not_found

        # Each of these targets team-a's version and must report it as missing.
        version_operations = (
            lambda: store.transition_model_version_stage(
                "alpha", "1", "Archived", archive_existing_versions=False
            ),
            lambda: store.set_model_version_tag("alpha", "1", ModelVersionTag("env", "stage")),
            lambda: store.delete_model_version_tag("alpha", "1", "env"),
            lambda: store.delete_model_version("alpha", "1"),
            lambda: store.get_model_version_download_uri("alpha", "1"),
        )
        for operation in version_operations:
            with pytest.raises(MlflowException, match=version_missing) as version_error:
                operation()
            assert version_error.value.error_code == not_found

        # Alias operations may fail on either the version or the model lookup.
        alias_missing = (
            r"(Model Version \(name=alpha, version=1\) not found|"
            r"Registered Model with name=alpha not found)"
        )
        alias_operations = (
            lambda: store.set_registered_model_alias("alpha", "shadow", "1"),
            lambda: store.delete_registered_model_alias("alpha", "production"),
        )
        for operation in alias_operations:
            with pytest.raises(MlflowException, match=alias_missing) as alias_error:
                operation()
            assert alias_error.value.error_code == not_found

    # team-a can still clean up its tag, alias and version.
    with WorkspaceContext("team-a"):
        store.delete_model_version_tag("alpha", "1", "env")
        detail = store.get_model_version("alpha", "1")
        assert detail.tags == {}
        store.delete_registered_model_alias("alpha", "production")
        assert store.get_registered_model("alpha").aliases == {}
        store.delete_model_version("alpha", "1")
        with pytest.raises(MlflowException, match=version_missing) as deleted_error:
            store.get_model_version("alpha", "1")
        assert deleted_error.value.error_code == not_found
180  
181  
def test_model_version_read_helpers_are_workspace_scoped(workspace_registry_store):
    """Search, latest-versions and get for a model version only work in its own workspace."""
    store = workspace_registry_store
    not_found = "RESOURCE_DOES_NOT_EXIST"

    # All three read helpers see version 1 from inside team-a.
    with WorkspaceContext("team-a"):
        store.create_registered_model("alpha")
        store.create_model_version("alpha", "s3://team-a/model", run_id=uuid.uuid4().hex)
        searched = store.search_model_versions("name='alpha'")
        assert [found.version for found in searched] == [1]
        latest = store.get_latest_versions("alpha")
        assert [found.version for found in latest] == [1]
        single = store.get_model_version("alpha", "1")
        assert single.version == 1

    # From team-b the search is empty and the direct lookups raise.
    with WorkspaceContext("team-b"):
        assert store.search_model_versions("name='alpha'") == []
        with pytest.raises(
            MlflowException, match=r"Model Version \(name=alpha, version=1\) not found"
        ) as version_error:
            store.get_model_version("alpha", "1")
        assert version_error.value.error_code == not_found
        with pytest.raises(
            MlflowException, match="Registered Model with name=alpha not found"
        ) as model_error:
            store.get_latest_versions("alpha")
        assert model_error.value.error_code == not_found
207  
208  
def test_same_model_name_allowed_in_different_workspaces(workspace_registry_store):
    """Two workspaces can each register a model called "shared-name"."""
    store = workspace_registry_store
    expected = {"shared-name"}

    with WorkspaceContext("team-a"):
        store.create_registered_model("shared-name")

    # Creating the same name in team-b must not raise, and team-b sees one model.
    with WorkspaceContext("team-b"):
        store.create_registered_model("shared-name")
        team_b_names = _names_from_search(store.search_registered_models())
        assert team_b_names == expected

    with WorkspaceContext("team-a"):
        team_a_names = _names_from_search(store.search_registered_models())
        assert team_a_names == expected
220  
221  
def test_update_and_delete_registered_model_metadata_are_workspace_scoped(
    workspace_registry_store,
):
    """Description updates and tag deletion on a model are rejected from another workspace."""
    store = workspace_registry_store
    alpha_missing = "Registered Model with name=alpha not found"
    not_found = "RESOURCE_DOES_NOT_EXIST"

    # The owning workspace can update the description and remove its tag.
    with WorkspaceContext("team-a"):
        store.create_registered_model("alpha")
        store.set_registered_model_tag("alpha", RegisteredModelTag("owner", "team-a"))
        after_update = store.update_registered_model("alpha", "updated desc")
        assert after_update.description == "updated desc"
        store.delete_registered_model_tag("alpha", "owner")
        assert store.get_registered_model("alpha").tags == {}

    # The same calls from team-b must report the model as missing.
    with WorkspaceContext("team-b"):
        with pytest.raises(MlflowException, match=alpha_missing) as update_error:
            store.update_registered_model("alpha", "hijacked")
        assert update_error.value.error_code == not_found
        with pytest.raises(MlflowException, match=alpha_missing) as tag_error:
            store.delete_registered_model_tag("alpha", "owner")
        assert tag_error.value.error_code == not_found
246  
247  
def test_model_version_allows_workspace_scoped_proxied_artifacts(
    workspace_registry_store, monkeypatch
):
    """A model version with a workspace-prefixed ``mlflow-artifacts:`` source can be created."""
    monkeypatch.setenv("_MLFLOW_SERVER_SERVE_ARTIFACTS", "true")
    source = "mlflow-artifacts:/workspaces/team-a/models/model-a"
    with WorkspaceContext("team-a"):
        workspace_registry_store.create_registered_model("alpha")
        created = workspace_registry_store.create_model_version(
            "alpha", source, run_id=uuid.uuid4().hex
        )
        assert created.version == 1
260  
261  
def test_webhook_operations_are_workspace_scoped(workspace_registry_store):
    """Webhooks created in one workspace cannot be listed, read or changed from another."""
    store = workspace_registry_store
    not_found = "RESOURCE_DOES_NOT_EXIST"
    created_event = WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED)

    # team-a creates a hook and sees exactly that hook when listing.
    with WorkspaceContext("team-a"):
        hook = store.create_webhook(
            name="team-a-hook",
            url="https://example.com/hook",
            events=[created_event],
            description="Team A hook",
        )
        assert hook.workspace == "team-a"
        listed = store.list_webhooks()
        assert len(listed) == 1
        assert listed[0].webhook_id == hook.webhook_id
        assert listed[0].workspace == "team-a"

    # team-b sees no hooks and cannot get, update or delete team-a's hook by ID.
    with WorkspaceContext("team-b"):
        hook_missing = f"Webhook with ID {hook.webhook_id} not found"
        assert len(store.list_webhooks()) == 0
        by_event = store.list_webhooks_by_event(created_event, max_results=10, page_token=None)
        assert len(by_event) == 0
        with pytest.raises(MlflowException, match=hook_missing) as get_error:
            store.get_webhook(hook.webhook_id)
        assert get_error.value.error_code == not_found
        with pytest.raises(MlflowException, match=hook_missing) as update_error:
            store.update_webhook(hook.webhook_id, name="should-fail")
        assert update_error.value.error_code == not_found
        with pytest.raises(MlflowException, match=hook_missing) as delete_error:
            store.delete_webhook(hook.webhook_id)
        assert delete_error.value.error_code == not_found

    # The hook is untouched for team-a, which can then delete it.
    with WorkspaceContext("team-a"):
        reloaded = store.get_webhook(hook.webhook_id)
        assert reloaded.webhook_id == hook.webhook_id
        assert reloaded.workspace == "team-a"
        store.delete_webhook(hook.webhook_id)
        with pytest.raises(
            MlflowException, match=f"Webhook with ID {hook.webhook_id} not found"
        ) as gone_error:
            store.get_webhook(hook.webhook_id)
        assert gone_error.value.error_code == not_found
314  
315  
def test_default_workspace_behavior_when_workspaces_disabled(db_uri, monkeypatch):
    """With the workspace flag unset, the plain ``SqlAlchemyStore`` creates and reads a model."""
    monkeypatch.delenv(MLFLOW_ENABLE_WORKSPACES.name, raising=False)
    clear_server_request_workspace()
    legacy_store = SqlAlchemyStore(db_uri)
    try:
        created = legacy_store.create_registered_model("legacy-model")
        assert created.name == "legacy-model"
        reloaded = legacy_store.get_registered_model("legacy-model")
        assert reloaded.name == "legacy-model"
    finally:
        # Dispose the engine whether or not the assertions passed.
        legacy_store.engine.dispose()
327  
328  
def test_default_workspace_context_allows_operations(workspace_registry_store):
    """A model can be created and fetched inside the ``DEFAULT_WORKSPACE_NAME`` context."""
    model_name = "default-model"
    with WorkspaceContext(DEFAULT_WORKSPACE_NAME):
        workspace_registry_store.create_registered_model(model_name)
        reloaded = workspace_registry_store.get_registered_model(model_name)
        assert reloaded.name == "default-model"
334  
335  
def test_single_tenant_registry_startup_rejects_non_default_workspace_models(
    tmp_path, db_uri, cached_db, monkeypatch
):
    """Building a plain ``SqlAlchemyStore`` fails while non-default workspace data exists.

    Two scenarios are covered, each against its own database:

    1. ``db_uri`` is seeded with a registered model in the "team-startup" workspace.
    2. A copy of the ``cached_db`` file under ``tmp_path`` is seeded with a webhook in
       the "team-webhook" workspace.

    In both cases the workspace flag is then switched to "false" and constructing
    ``SqlAlchemyStore`` must raise ``MlflowException`` with error code ``INVALID_STATE``.
    """
    # Scenario 1: a registered model outside the default workspace.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    workspace_store = WorkspaceAwareSqlAlchemyStore(db_uri)
    try:
        with WorkspaceContext("team-startup"):
            workspace_store.create_registered_model("team-model")
    finally:
        # Dispose even when seeding fails, so the engine is not leaked.
        workspace_store.engine.dispose()

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    with pytest.raises(
        MlflowException,
        match="Cannot disable workspaces because registered models exist outside the default "
        + "workspace",
    ) as excinfo:
        SqlAlchemyStore(db_uri)

    assert excinfo.value.error_code == "INVALID_STATE"

    # Scenario 2: a webhook outside the default workspace, in a separate database
    # file so the registered model seeded above is not present.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    webhook_db_path = tmp_path / "registry_webhook.db"
    shutil.copy2(cached_db, webhook_db_path)
    webhook_db_uri = f"sqlite:///{webhook_db_path}"
    webhook_store = WorkspaceAwareSqlAlchemyStore(webhook_db_uri)
    try:
        webhook_event = WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)
        with WorkspaceContext("team-webhook"):
            webhook_store.create_webhook(
                name="team-webhook",
                url="https://example.com/webhook",
                events=[webhook_event],
                description="non-default webhook",
            )
    finally:
        # Same guarantee as above for the second engine.
        webhook_store.engine.dispose()

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    with pytest.raises(
        MlflowException,
        match="Cannot disable workspaces because webhooks exist outside the default workspace",
    ) as excinfo:
        SqlAlchemyStore(webhook_db_uri)

    assert excinfo.value.error_code == "INVALID_STATE"