"""Tests for ``mlflow.data.dataset_registry`` (test_dataset_registry.py).

Covers the signature validation done by ``DatasetRegistry.register_constructor``,
constructors registered from the mlflow-test-plugin entrypoints, and calling
constructors registered through the module-level ``register_constructor`` helper.
One test also checks the warning raised by ``DatasetSourceRegistry.register_entrypoints``
when registering an entrypoint fails.
"""

from unittest import mock

import pytest

import mlflow.data
from mlflow.data.dataset import Dataset
from mlflow.data.dataset_registry import DatasetRegistry, register_constructor
from mlflow.data.dataset_source_registry import DatasetSourceRegistry, resolve_dataset_source
from mlflow.exceptions import MlflowException

from tests.resources.data.dataset import SampleDataset
from tests.resources.data.dataset_source import SampleDatasetSource


@pytest.fixture
def dataset_source_registry():
    """Yield a fresh ``DatasetSourceRegistry`` standing in for the global one.

    The module-level ``_dataset_source_registry`` is swapped for a mock that wraps the
    fresh instance, so calls made through the global are forwarded to it. The original
    global is restored when the patch context exits at teardown.
    """
    fresh_registry = DatasetSourceRegistry()
    patched_global = mock.patch(
        "mlflow.data.dataset_source_registry._dataset_source_registry", wraps=fresh_registry
    )
    with patched_global:
        yield fresh_registry


@pytest.fixture
def dataset_registry():
    """Yield a fresh ``DatasetRegistry`` standing in for the global one.

    Same approach as ``dataset_source_registry``. ``_dataset_registry`` is replaced by a
    mock wrapping the fresh instance for the duration of the test.
    """
    fresh_registry = DatasetRegistry()
    patched_global = mock.patch(
        "mlflow.data.dataset_registry._dataset_registry", wraps=fresh_registry
    )
    with patched_global:
        yield fresh_registry


def test_register_constructor_function_performs_validation():
    registry = DatasetRegistry()
    name_prefix_error = "Constructor name must start with"
    missing_name_error = "must define an optional parameter named 'name'"
    missing_digest_error = "must define an optional parameter named 'digest'"
    return_type_error = "must have a return type annotation.*Dataset"

    # Satisfies every rule, so registration must succeed.
    def from_good_function(
        path: str,
        name: str | None = None,
        digest: str | None = None,
    ) -> Dataset: ...

    registry.register_constructor(from_good_function)

    # The function's own name lacks the required prefix (presumably ``from_``).
    def bad_name_fn(
        name: str | None = None,
        digest: str | None = None,
    ) -> Dataset: ...

    with pytest.raises(MlflowException, match=name_prefix_error):
        registry.register_constructor(bad_name_fn)

    # An explicit ``constructor_name`` is validated too, even for an otherwise valid function.
    with pytest.raises(MlflowException, match=name_prefix_error):
        registry.register_constructor(
            constructor_fn=from_good_function, constructor_name="bad_name"
        )

    # No optional ``name`` parameter.
    def from_no_name_fn(
        digest: str | None = None,
    ) -> Dataset: ...

    with pytest.raises(MlflowException, match=missing_name_error):
        registry.register_constructor(from_no_name_fn)

    # No optional ``digest`` parameter.
    def from_no_digest_fn(
        name: str | None = None,
    ) -> Dataset: ...

    with pytest.raises(MlflowException, match=missing_digest_error):
        registry.register_constructor(from_no_digest_fn)

    # Return annotation present, but it is ``str`` rather than a Dataset type.
    def from_bad_return_type_fn(
        path: str,
        name: str | None = None,
        digest: str | None = None,
    ) -> str: ...

    with pytest.raises(MlflowException, match=return_type_error):
        registry.register_constructor(from_bad_return_type_fn)

    # No return annotation at all.
    def from_no_return_type_fn(
        path: str,
        name: str | None = None,
        digest: str | None = None,
    ): ...

    with pytest.raises(MlflowException, match=return_type_error):
        registry.register_constructor(from_no_return_type_fn)


def test_register_constructor_from_entrypoints_and_call(dataset_registry, tmp_path):
    # Imported inside the test; this module comes from mlflow-test-plugin rather than mlflow.
    from mlflow_test_plugin.dummy_dataset import DummyDataset

    dataset_registry.register_entrypoints()

    # ``dummy:`` URIs map to DummyDatasetSource from mlflow_test_plugin.dummy_dataset_source,
    # which is registered as an entrypoint whenever mlflow-test-plugin is installed.
    dummy_uri = f"dummy:{tmp_path}"
    dataset = mlflow.data.from_dummy(
        data_list=[1, 2, 3], source=dummy_uri, name="dataset_name", digest="foo"
    )

    assert isinstance(dataset, DummyDataset)
    observed = (dataset.data_list, dataset.name, dataset.digest)
    assert observed == ([1, 2, 3], "dataset_name", "foo")


def test_register_constructor_and_call(dataset_registry, dataset_source_registry, tmp_path):
    dataset_source_registry.register(SampleDatasetSource)

    def from_test(data_list, source, name=None, digest=None) -> SampleDataset:
        return SampleDataset(
            data_list=data_list,
            source=resolve_dataset_source(source, candidate_sources=[SampleDatasetSource]),
            name=name,
            digest=digest,
        )

    # Same function registered twice: under its own name, then under an explicit alias.
    register_constructor(constructor_fn=from_test)
    register_constructor(constructor_name="from_test_2", constructor_fn=from_test)

    # ``test:`` URIs map to SampleDatasetSource.
    sample_uri = f"test:{tmp_path}"

    first = mlflow.data.from_test(
        data_list=[1, 2, 3], source=sample_uri, name="name1", digest="digest1"
    )
    assert isinstance(first, SampleDataset)
    assert (first.data_list, first.name, first.digest) == ([1, 2, 3], "name1", "digest1")

    second = mlflow.data.from_test_2(
        data_list=[4, 5, 6], source=sample_uri, name="name2", digest="digest2"
    )
    assert isinstance(second, SampleDataset)
    assert (second.data_list, second.name, second.digest) == ([4, 5, 6], "name2", "digest2")


def test_dataset_source_registration_failure(dataset_source_registry):
    # Force every ``register`` call to raise ImportError. ``register_entrypoints`` must then
    # emit a UserWarning instead of raising. Assumes at least one dataset-source entrypoint
    # is installed (e.g. by mlflow-test-plugin); otherwise no warning occurs — TODO confirm.
    failing_register = mock.patch.object(
        dataset_source_registry, "register", side_effect=ImportError("Error")
    )
    expected_warning = pytest.warns(
        UserWarning, match="Failure attempting to register dataset constructor"
    )
    with failing_register, expected_warning:
        dataset_source_registry.register_entrypoints()