/ tests / tracking / test_workspace_registry.py
test_workspace_registry.py
 1  from __future__ import annotations
 2  
 3  import pytest
 4  
 5  from mlflow.store.workspace.rest_store import RestWorkspaceStore
 6  from mlflow.store.workspace.sqlalchemy_store import SqlAlchemyStore
 7  from mlflow.tracking._workspace.registry import (
 8      UnsupportedWorkspaceStoreURIException,
 9      _get_workspace_store_registry,
10      get_workspace_store,
11  )
12  
13  
@pytest.fixture(autouse=True)
def _clear_workspace_store_cache():
    """Reset the registry's resolved-URI store cache before and after every test.

    ``_get_store_with_resolved_uri`` exposes ``cache_clear``, i.e. it memoizes the
    stores it builds. Clearing it on both sides of each test keeps one test's
    cached store from being handed to another test that uses the same URI.
    """
    store_registry = _get_workspace_store_registry()

    def _reset():
        # Look the cached callable up at call time so the teardown clears
        # whatever is bound on the registry when the test finishes.
        store_registry._get_store_with_resolved_uri.cache_clear()

    _reset()
    yield
    _reset()
20  
21  
def test_get_workspace_store_resolves_sqlalchemy(tmp_path):
    """A ``sqlite:///`` workspace URI resolves to the SQLAlchemy-backed store.

    Also checks that the resolved store remembers the exact URI it was built from.
    """
    workspace_uri = f"sqlite:///{tmp_path / 'workspace.db'}"
    store = get_workspace_store(workspace_uri=workspace_uri)
    # Check the type first: ``_engine`` is only guaranteed on a SqlAlchemyStore,
    # so the cleanup below must not run against some other store type.
    assert isinstance(store, SqlAlchemyStore)
    try:
        assert store._workspace_uri == workspace_uri
    finally:
        # Always release the engine's pooled SQLite connections, even when the
        # assertion fails; otherwise the open database file can outlive the test
        # and break tmp_path cleanup (notably on Windows).
        store._engine.dispose()
28  
29  
def test_get_workspace_store_resolves_rest():
    """An ``http://`` workspace URI resolves to the REST-backed workspace store."""
    resolved = get_workspace_store(workspace_uri="http://example.com")
    assert isinstance(resolved, RestWorkspaceStore)
33  
34  
def test_get_workspace_store_unsupported_scheme():
    """A URI whose scheme has no registered store is rejected with a clear error."""
    expected_message = "got unsupported URI 'foo://workspace'"
    with pytest.raises(UnsupportedWorkspaceStoreURIException, match=expected_message):
        get_workspace_store(workspace_uri="foo://workspace")