/ tests / data / test_dataset_registry.py
test_dataset_registry.py
  1  from unittest import mock
  2  
  3  import pytest
  4  
  5  import mlflow.data
  6  from mlflow.data.dataset import Dataset
  7  from mlflow.data.dataset_registry import DatasetRegistry, register_constructor
  8  from mlflow.data.dataset_source_registry import DatasetSourceRegistry, resolve_dataset_source
  9  from mlflow.exceptions import MlflowException
 10  
 11  from tests.resources.data.dataset import SampleDataset
 12  from tests.resources.data.dataset_source import SampleDatasetSource
 13  
 14  
 15  @pytest.fixture
 16  def dataset_source_registry():
 17      registry = DatasetSourceRegistry()
 18      with mock.patch("mlflow.data.dataset_source_registry._dataset_source_registry", wraps=registry):
 19          yield registry
 20  
 21  
 22  @pytest.fixture
 23  def dataset_registry():
 24      registry = DatasetRegistry()
 25      with mock.patch("mlflow.data.dataset_registry._dataset_registry", wraps=registry):
 26          yield registry
 27  
 28  
 29  def test_register_constructor_function_performs_validation():
 30      registry = DatasetRegistry()
 31  
 32      def from_good_function(
 33          path: str,
 34          name: str | None = None,
 35          digest: str | None = None,
 36      ) -> Dataset:
 37          pass
 38  
 39      registry.register_constructor(from_good_function)
 40  
 41      def bad_name_fn(
 42          name: str | None = None,
 43          digest: str | None = None,
 44      ) -> Dataset:
 45          pass
 46  
 47      with pytest.raises(MlflowException, match="Constructor name must start with"):
 48          registry.register_constructor(bad_name_fn)
 49  
 50      with pytest.raises(MlflowException, match="Constructor name must start with"):
 51          registry.register_constructor(
 52              constructor_fn=from_good_function, constructor_name="bad_name"
 53          )
 54  
 55      def from_no_name_fn(
 56          digest: str | None = None,
 57      ) -> Dataset:
 58          pass
 59  
 60      with pytest.raises(MlflowException, match="must define an optional parameter named 'name'"):
 61          registry.register_constructor(from_no_name_fn)
 62  
 63      def from_no_digest_fn(
 64          name: str | None = None,
 65      ) -> Dataset:
 66          pass
 67  
 68      with pytest.raises(MlflowException, match="must define an optional parameter named 'digest'"):
 69          registry.register_constructor(from_no_digest_fn)
 70  
 71      def from_bad_return_type_fn(
 72          path: str,
 73          name: str | None = None,
 74          digest: str | None = None,
 75      ) -> str:
 76          pass
 77  
 78      with pytest.raises(MlflowException, match="must have a return type annotation.*Dataset"):
 79          registry.register_constructor(from_bad_return_type_fn)
 80  
 81      def from_no_return_type_fn(
 82          path: str,
 83          name: str | None = None,
 84          digest: str | None = None,
 85      ):
 86          pass
 87  
 88      with pytest.raises(MlflowException, match="must have a return type annotation.*Dataset"):
 89          registry.register_constructor(from_no_return_type_fn)
 90  
 91  
 92  def test_register_constructor_from_entrypoints_and_call(dataset_registry, tmp_path):
 93      from mlflow_test_plugin.dummy_dataset import DummyDataset
 94  
 95      dataset_registry.register_entrypoints()
 96  
 97      dataset = mlflow.data.from_dummy(
 98          data_list=[1, 2, 3],
 99          # Use a DummyDatasetSource URI from mlflow_test_plugin.dummy_dataset_source, which
100          # is registered as an entrypoint whenever mlflow-test-plugin is installed
101          source="dummy:" + str(tmp_path),
102          name="dataset_name",
103          digest="foo",
104      )
105      assert isinstance(dataset, DummyDataset)
106      assert dataset.data_list == [1, 2, 3]
107      assert dataset.name == "dataset_name"
108      assert dataset.digest == "foo"
109  
110  
def test_register_constructor_and_call(dataset_registry, dataset_source_registry, tmp_path):
    """Register one constructor under two names and call both through ``mlflow.data``.

    The same function is registered once without an explicit constructor name
    and once as ``from_test_2``; each is then looked up as an attribute of
    ``mlflow.data`` and its result checked against the arguments passed in.
    """
    dataset_source_registry.register(SampleDatasetSource)

    def from_test(data_list, source, name=None, digest=None) -> SampleDataset:
        return SampleDataset(
            data_list=data_list,
            source=resolve_dataset_source(source, candidate_sources=[SampleDatasetSource]),
            name=name,
            digest=digest,
        )

    register_constructor(constructor_fn=from_test)
    register_constructor(constructor_name="from_test_2", constructor_fn=from_test)

    # Use a SampleDatasetSource URI
    sample_source_uri = "test:" + str(tmp_path)

    # (attribute on mlflow.data, data_list, name, digest), checked in this order.
    expectations = [
        ("from_test", [1, 2, 3], "name1", "digest1"),
        ("from_test_2", [4, 5, 6], "name2", "digest2"),
    ]
    for attr_name, expected_data, expected_name, expected_digest in expectations:
        constructor = getattr(mlflow.data, attr_name)
        dataset = constructor(
            # Pass a copy so the comparison below uses an independent list.
            data_list=list(expected_data),
            source=sample_source_uri,
            name=expected_name,
            digest=expected_digest,
        )
        assert isinstance(dataset, SampleDataset)
        assert dataset.data_list == expected_data
        assert dataset.name == expected_name
        assert dataset.digest == expected_digest
146  
147  
def test_dataset_source_registration_failure(dataset_source_registry):
    """Expect a ``UserWarning`` when registering from entrypoints fails.

    ``register`` is patched to raise ``ImportError``; ``register_entrypoints`` is
    then expected to emit a warning matching the pattern below.
    """
    # NOTE(review): this presumably needs at least one dataset source entrypoint
    # installed, otherwise ``register`` may never be called -- confirm.
    import_failure = ImportError("Error")
    with (
        mock.patch.object(dataset_source_registry, "register", side_effect=import_failure),
        pytest.warns(UserWarning, match="Failure attempting to register dataset constructor"),
    ):
        dataset_source_registry.register_entrypoints()