# test_meta_dataset.py
"""Tests for ``mlflow.data.meta_dataset.MetaDataset``.

These tests exercise MetaDataset construction from several dataset-source
types (HTTP, Delta, Hugging Face, UC Volume) and from an existing pandas
dataset, and verify the JSON serialization produced by ``to_json()``:
its ``digest``, ``source``, ``source_type``, and (when supplied) ``schema``
fields.
"""

import json
from unittest.mock import patch

import pytest

# Skip this entire module when pandas is unavailable; the import must run
# before the mlflow.data imports below, which is why it sits mid-imports.
pd = pytest.importorskip("pandas")

from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.data.http_dataset_source import HTTPDatasetSource
from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource
from mlflow.data.meta_dataset import MetaDataset
from mlflow.data.pandas_dataset import from_pandas
from mlflow.data.uc_volume_dataset_source import UCVolumeDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.types import DataType
from mlflow.types.schema import ColSpec, Schema


@pytest.mark.parametrize(
    ("dataset_source_class", "path"),
    [
        (HTTPDatasetSource, "test:/my/test/uri"),
        (DeltaDatasetSource, "fake/path/to/delta"),
        (HuggingFaceDatasetSource, "databricks/databricks-dolly-15k"),
    ],
)
def test_create_meta_dataset_from_source(dataset_source_class, path):
    """A MetaDataset built from a bare source serializes digest/source/type.

    For each source class, the JSON emitted by ``to_json()`` must carry a
    non-null digest, embed the original path in ``source``, and report the
    class's own source type.
    """
    source = dataset_source_class(path)
    dataset = MetaDataset(source=source)

    json_str = dataset.to_json()
    parsed_json = json.loads(json_str)

    assert parsed_json["digest"] is not None
    # "source" is a serialized blob; only membership of the path is checked,
    # not its exact representation.
    assert path in parsed_json["source"]
    assert parsed_json["source_type"] == dataset_source_class._get_source_type()


@pytest.mark.parametrize(
    ("dataset_source_class", "path"),
    [
        (HTTPDatasetSource, "test:/my/test/uri"),
        (DeltaDatasetSource, "fake/path/to/delta"),
        (HuggingFaceDatasetSource, "databricks/databricks-dolly-15k"),
    ],
)
def test_create_meta_dataset_from_source_with_schema(dataset_source_class, path):
    """Supplying a Schema round-trips through MetaDataset JSON.

    Same checks as the schemaless test, plus: the ``schema`` field is itself
    a JSON string whose ``mlflow_colspec`` entry equals ``schema.to_dict()``.
    """
    source = dataset_source_class(path)
    schema = Schema([
        ColSpec(type=DataType.long, name="foo"),
        ColSpec(type=DataType.integer, name="bar"),
    ])
    dataset = MetaDataset(source=source, schema=schema)

    json_str = dataset.to_json()
    parsed_json = json.loads(json_str)

    assert parsed_json["digest"] is not None
    assert path in parsed_json["source"]
    assert parsed_json["source_type"] == dataset_source_class._get_source_type()
    # "schema" is doubly encoded: a JSON string inside the outer JSON document.
    assert json.loads(parsed_json["schema"])["mlflow_colspec"] == schema.to_dict()


def test_meta_dataset_digest():
    """Digest reflects both the schema and the source.

    Two datasets sharing a source but differing in schema must have distinct
    digests, as must two datasets with different sources.
    """
    http_source = HTTPDatasetSource("test:/my/test/uri")
    dataset1 = MetaDataset(source=http_source)
    schema = Schema([
        ColSpec(type=DataType.long, name="foo"),
        ColSpec(type=DataType.integer, name="bar"),
    ])
    dataset2 = MetaDataset(source=http_source, schema=schema)

    # Same source, schema added -> digest must change.
    assert dataset1.digest != dataset2.digest

    delta_source = DeltaDatasetSource("fake/path/to/delta")
    dataset3 = MetaDataset(source=delta_source)
    # Different source, no schema on either -> digest must differ too.
    assert dataset1.digest != dataset3.digest


def test_meta_dataset_with_uc_source():
    """UC Volume sources validate their path; valid ones serialize as usual.

    First, a failing path-validation (mocked to raise) must propagate as an
    MlflowException from the constructor. Then, with validation mocked out,
    the MetaDataset JSON must report source_type ``"uc_volume"``.
    """
    path = "/Volumes/dummy_catalog/dummy_schema/dummy_volume/tmp.yaml"

    # Constructor re-raises the validation error verbatim.
    with (
        patch(
            "mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource._verify_uc_path_is_valid",
            side_effect=MlflowException(f"{path} does not exist in Databricks Unified Catalog."),
        ),
        pytest.raises(
            MlflowException, match=f"{path} does not exist in Databricks Unified Catalog."
        ),
    ):
        uc_volume_source = UCVolumeDatasetSource(path)

    # With validation stubbed to a no-op, construction succeeds and the
    # serialized dataset carries the UC-volume source type.
    with patch(
        "mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource._verify_uc_path_is_valid",
    ):
        uc_volume_source = UCVolumeDatasetSource(path)
        dataset = MetaDataset(source=uc_volume_source)
        json_str = dataset.to_json()
        parsed_json = json.loads(json_str)

        assert parsed_json["digest"] is not None
        assert path in parsed_json["source"]
        assert parsed_json["source_type"] == "uc_volume"


def test_create_meta_dataset_from_dataset():
    """A MetaDataset wrapping another Dataset nests that dataset as its source.

    The outer JSON's source_type is the wrapped pandas dataset's own type, and
    the outer "source" field is a JSON blob whose inner source_type matches
    the pandas dataset's underlying source.
    """
    pandas_dataset = from_pandas(
        df=pd.DataFrame({"a": [1, 2, 3]}),
        source="/tmp/test.csv",
    )

    meta_dataset = MetaDataset(source=pandas_dataset)

    parsed_json = json.loads(meta_dataset.to_json())

    assert parsed_json["source_type"] == pandas_dataset._get_source_type()
    # The wrapped dataset is serialized as a JSON string inside "source".
    dataset_json = json.loads(parsed_json["source"])
    assert dataset_json["source_type"] == pandas_dataset.source._get_source_type()