# tests/tracking/context/test_registry.py
  1  from importlib import reload
  2  from unittest import mock
  3  
  4  import pytest
  5  
  6  import mlflow.tracking.context.registry
  7  from mlflow.tracking.context.databricks_job_context import DatabricksJobRunContext
  8  from mlflow.tracking.context.databricks_notebook_context import DatabricksNotebookRunContext
  9  from mlflow.tracking.context.databricks_repo_context import DatabricksRepoRunContext
 10  from mlflow.tracking.context.default_context import DefaultRunContext
 11  from mlflow.tracking.context.git_context import GitRunContext
 12  from mlflow.tracking.context.jupyter_notebook_context import JupyterNotebookRunContext
 13  from mlflow.tracking.context.registry import RunContextProviderRegistry, resolve_tags
 14  
 15  
 16  def test_run_context_provider_registry_register():
 17      provider_class = mock.Mock()
 18  
 19      registry = RunContextProviderRegistry()
 20      registry.register(provider_class)
 21  
 22      assert set(registry) == {provider_class.return_value}
 23  
 24  
 25  def test_run_context_provider_registry_register_entrypoints():
 26      provider_class = mock.Mock()
 27      mock_entrypoint = mock.Mock()
 28      mock_entrypoint.load.return_value = provider_class
 29  
 30      with mock.patch(
 31          "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
 32      ) as mock_get_group_all:
 33          registry = RunContextProviderRegistry()
 34          registry.register_entrypoints()
 35  
 36      assert set(registry) == {provider_class.return_value}
 37      mock_entrypoint.load.assert_called_once_with()
 38      mock_get_group_all.assert_called_once_with("mlflow.run_context_provider")
 39  
 40  
 41  @pytest.mark.parametrize(
 42      "exception", [AttributeError("test exception"), ImportError("test exception")]
 43  )
 44  def test_run_context_provider_registry_register_entrypoints_handles_exception(exception):
 45      mock_entrypoint = mock.Mock()
 46      mock_entrypoint.load.side_effect = exception
 47  
 48      with mock.patch(
 49          "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
 50      ) as mock_get_group_all:
 51          registry = RunContextProviderRegistry()
 52          # Check that the raised warning contains the message from the original exception
 53          with pytest.warns(UserWarning, match="test exception"):
 54              registry.register_entrypoints()
 55  
 56      mock_entrypoint.load.assert_called_once_with()
 57      mock_get_group_all.assert_called_once_with("mlflow.run_context_provider")
 58  
 59  
 60  def _currently_registered_run_context_provider_classes():
 61      return {
 62          provider.__class__
 63          for provider in mlflow.tracking.context.registry._run_context_provider_registry
 64      }
 65  
 66  
 67  def test_registry_instance_defaults():
 68      expected_classes = {
 69          DefaultRunContext,
 70          GitRunContext,
 71          JupyterNotebookRunContext,
 72          DatabricksNotebookRunContext,
 73          DatabricksJobRunContext,
 74          DatabricksRepoRunContext,
 75      }
 76      assert expected_classes.issubset(_currently_registered_run_context_provider_classes())
 77  
 78  
 79  def test_registry_instance_loads_entrypoints():
 80      class MockRunContext:
 81          pass
 82  
 83      mock_entrypoint = mock.Mock()
 84      mock_entrypoint.load.return_value = MockRunContext
 85  
 86      with mock.patch(
 87          "mlflow.utils.plugins._get_entry_points", return_value=[mock_entrypoint]
 88      ) as mock_get_group_all:
 89          # Entrypoints are registered at import time, so we need to reload the module to register the
 90          # entrypoint given by the mocked entrypoints.get_group_all
 91          reload(mlflow.tracking.context.registry)
 92  
 93      assert MockRunContext in _currently_registered_run_context_provider_classes()
 94      mock_get_group_all.assert_called_once_with("mlflow.run_context_provider")
 95  
 96  
 97  def test_run_context_provider_registry_with_installed_plugin(tmp_path, monkeypatch):
 98      monkeypatch.chdir(tmp_path)
 99  
100      reload(mlflow.tracking.context.registry)
101  
102      from mlflow_test_plugin.run_context_provider import PluginRunContextProvider
103  
104      assert PluginRunContextProvider in _currently_registered_run_context_provider_classes()
105  
106      # The test plugin's context provider always returns False from in_context
107      # to avoid polluting tags in developers' environments. The following mock overrides this to
108      # perform the integration test.
109      with mock.patch.object(PluginRunContextProvider, "in_context", return_value=True):
110          assert resolve_tags()["test"] == "tag"
111  
112  
@pytest.fixture
def mock_run_context_providers():
    """Replace the module-level registry with four mock providers, in this order:

    * base: in context, contributes ``one``, ``two`` and ``three``;
    * skipped: not in context; teardown asserts its ``tags`` was never called;
    * exception: in context, but calling ``tags`` raises;
    * override: in context, overrides ``one`` and adds ``new``.
    """

    def make_provider(active, tags=None):
        # Build a Mock provider whose ``in_context()`` returns ``active`` and, when given,
        # whose ``tags()`` returns ``tags``.
        provider = mock.Mock()
        provider.in_context.return_value = active
        if tags is not None:
            provider.tags.return_value = tags
        return provider

    base_provider = make_provider(
        True, {"one": "one-val", "two": "two-val", "three": "three-val"}
    )
    skipped_provider = make_provider(False)
    exception_provider = make_provider(
        True, {"random-key": "This val will never make it to tag resolution"}
    )
    # An exception side_effect is raised on call, so the return_value above is never returned.
    exception_provider.tags.side_effect = Exception(
        "This should be caught by logic in resolve_tags()"
    )
    override_provider = make_provider(True, {"one": "override", "new": "new-val"})

    with mock.patch(
        "mlflow.tracking.context.registry._run_context_provider_registry",
        [base_provider, skipped_provider, exception_provider, override_provider],
    ):
        yield

    skipped_provider.tags.assert_not_called()
141  
142  
def test_resolve_tags(mock_run_context_providers):
    """Explicit tags override provider tags, and later providers override earlier ones."""
    explicit_tags = {"two": "arg-override", "arg": "arg-val"}
    expected = {
        "one": "override",
        "two": "arg-override",
        "three": "three-val",
        "new": "new-val",
        "arg": "arg-val",
    }
    assert resolve_tags(explicit_tags) == expected
152  
153  
def test_resolve_tags_no_arg(mock_run_context_providers):
    """With no explicit tags, the result is just the merged tags of in-context providers."""
    expected = {
        "one": "override",
        "two": "two-val",
        "three": "three-val",
        "new": "new-val",
    }
    assert resolve_tags() == expected