# test_registry.py
"""Tests for ``mlflow.tracking.context.registry``.

Covers ``RunContextProviderRegistry`` (direct and entry-point registration), the
module-level default registry instance, and tag resolution via ``resolve_tags``.
"""

from importlib import reload
from unittest import mock

import pytest

import mlflow.tracking.context.registry
from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext
from mlflow.tracking.context.databricks_notebook_context import DatabricksNotebookRunContext
from mlflow.tracking.context.databricks_repo_context import DatabricksRepoRunContext
from mlflow.tracking.context.default_context import DefaultRunContext
from mlflow.tracking.context.git_context import GitRunContext
from mlflow.tracking.context.jupyter_notebook_context import JupyterNotebookRunContext
from mlflow.tracking.context.registry import RunContextProviderRegistry, resolve_tags


def test_run_context_provider_registry_register():
    """``register`` calls the given class and stores the resulting instance."""
    provider_class = mock.Mock()

    registry = RunContextProviderRegistry()
    registry.register(provider_class)

    # The registry holds the instance produced by calling the class, not the class itself.
    assert set(registry) == {provider_class.return_value}


def test_run_context_provider_registry_register_entrypoints():
    """``register_entrypoints`` loads each entry point and registers the loaded class."""
    provider_class = mock.Mock()
    mock_entrypoint = mock.Mock()
    mock_entrypoint.load.return_value = provider_class

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
    ) as mock_get_entry_points:
        registry = RunContextProviderRegistry()
        registry.register_entrypoints()

    assert set(registry) == {provider_class.return_value}
    mock_entrypoint.load.assert_called_once_with()
    mock_get_entry_points.assert_called_once_with("mlflow.run_context_provider")


@pytest.mark.parametrize(
    "exception", [AttributeError("test exception"), ImportError("test exception")]
)
def test_run_context_provider_registry_register_entrypoints_handles_exception(exception):
    """An entry point whose ``load`` raises is reported as a ``UserWarning``, not raised."""
    mock_entrypoint = mock.Mock()
    mock_entrypoint.load.side_effect = exception

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
    ) as mock_get_entry_points:
        registry = RunContextProviderRegistry()
        # Check that the raised warning contains the message from the original exception
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    mock_entrypoint.load.assert_called_once_with()
    mock_get_entry_points.assert_called_once_with("mlflow.run_context_provider")


def _currently_registered_run_context_provider_classes():
    """Return the set of classes of the providers in the module-level default registry."""
    # Look the registry up through the module on every call: ``reload`` rebinds
    # ``_run_context_provider_registry``, so a reference captured earlier would go stale.
    return {
        provider.__class__
        for provider in mlflow.tracking.context.registry._run_context_provider_registry
    }


def test_registry_instance_defaults():
    """The default registry contains all of mlflow's built-in context providers."""
    expected_classes = {
        DefaultRunContext,
        GitRunContext,
        JupyterNotebookRunContext,
        DatabricksNotebookRunContext,
        DatabricksJobRunContext,
        DatabricksRepoRunContext,
    }
    assert expected_classes.issubset(_currently_registered_run_context_provider_classes())


def test_registry_instance_loads_entrypoints():
    """Entry points are registered into the default registry when the module is imported."""

    class MockRunContext:
        pass

    mock_entrypoint = mock.Mock()
    mock_entrypoint.load.return_value = MockRunContext

    with mock.patch(
        "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
    ) as mock_get_entry_points:
        # Entrypoints are registered at import time, so we need to reload the module to register
        # the entrypoint returned by the mocked ``_get_entry_points``
        reload(mlflow.tracking.context.registry)

    try:
        assert MockRunContext in _currently_registered_run_context_provider_classes()
        mock_get_entry_points.assert_called_once_with("mlflow.run_context_provider")
    finally:
        # The reload above left the module-level registry holding MockRunContext. Reload again,
        # now outside the patch, so later tests see the real entry points instead of this mock.
        reload(mlflow.tracking.context.registry)


def test_run_context_provider_registry_with_installed_plugin(tmp_path, monkeypatch):
    """The provider from the installed ``mlflow_test_plugin`` is registered and used."""
    # NOTE(review): presumably isolates the test from the caller's working directory
    # (e.g. a git checkout picked up by a context provider) - confirm.
    monkeypatch.chdir(tmp_path)

    # Re-run import-time registration so the installed plugin's entry point is picked up.
    reload(mlflow.tracking.context.registry)

    from mlflow_test_plugin.run_context_provider import PluginRunContextProvider

    assert PluginRunContextProvider in _currently_registered_run_context_provider_classes()

    # The test plugin's context provider always returns False from in_context
    # to avoid polluting tags in developers' environments. The following mock overrides this to
    # perform the integration test.
    with mock.patch.object(PluginRunContextProvider, "in_context", return_value=True):
        assert resolve_tags()["test"] == "tag"


@pytest.fixture
def mock_run_context_providers():
    """Replace the default registry with four mock providers for the duration of a test.

    - ``base_provider``: in context; contributes tags ``one``, ``two`` and ``three``.
    - ``skipped_provider``: not in context; its ``tags`` must never be called.
    - ``exception_provider``: in context, but ``tags`` raises; ``resolve_tags`` is
      expected to tolerate the error.
    - ``override_provider``: in context and listed last; overrides ``one`` and adds ``new``.
    """
    base_provider = mock.Mock()
    base_provider.in_context.return_value = True
    base_provider.tags.return_value = {"one": "one-val", "two": "two-val", "three": "three-val"}

    skipped_provider = mock.Mock()
    skipped_provider.in_context.return_value = False

    exception_provider = mock.Mock()
    exception_provider.in_context.return_value = True
    # ``side_effect`` takes precedence over ``return_value``, so these tags are never returned.
    exception_provider.tags.return_value = {
        "random-key": "This val will never make it to tag resolution"
    }
    exception_provider.tags.side_effect = Exception(
        "This should be caught by logic in resolve_tags()"
    )

    override_provider = mock.Mock()
    override_provider.in_context.return_value = True
    override_provider.tags.return_value = {"one": "override", "new": "new-val"}

    providers = [base_provider, skipped_provider, exception_provider, override_provider]

    with mock.patch("mlflow.tracking.context.registry._run_context_provider_registry", providers):
        yield

    # Teardown check: providers reporting ``in_context() == False`` are never asked for tags.
    skipped_provider.tags.assert_not_called()


def test_resolve_tags(mock_run_context_providers):
    """Explicit tags override provider tags; later providers override earlier ones."""
    tags_arg = {"two": "arg-override", "arg": "arg-val"}
    assert resolve_tags(tags_arg) == {
        "one": "override",
        "two": "arg-override",
        "three": "three-val",
        "new": "new-val",
        "arg": "arg-val",
    }


def test_resolve_tags_no_arg(mock_run_context_providers):
    """Without explicit tags, the result is the merged tags of in-context providers."""
    assert resolve_tags() == {
        "one": "override",
        "two": "two-val",
        "three": "three-val",
        "new": "new-val",
    }