# tests/data/test_meta_dataset.py
  1  import json
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  pd = pytest.importorskip("pandas")
  7  
  8  from mlflow.data.delta_dataset_source import DeltaDatasetSource
  9  from mlflow.data.http_dataset_source import HTTPDatasetSource
 10  from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource
 11  from mlflow.data.meta_dataset import MetaDataset
 12  from mlflow.data.pandas_dataset import from_pandas
 13  from mlflow.data.uc_volume_dataset_source import UCVolumeDatasetSource
 14  from mlflow.exceptions import MlflowException
 15  from mlflow.types import DataType
 16  from mlflow.types.schema import ColSpec, Schema
 17  
 18  
 19  @pytest.mark.parametrize(
 20      ("dataset_source_class", "path"),
 21      [
 22          (HTTPDatasetSource, "test:/my/test/uri"),
 23          (DeltaDatasetSource, "fake/path/to/delta"),
 24          (HuggingFaceDatasetSource, "databricks/databricks-dolly-15k"),
 25      ],
 26  )
 27  def test_create_meta_dataset_from_source(dataset_source_class, path):
 28      source = dataset_source_class(path)
 29      dataset = MetaDataset(source=source)
 30  
 31      json_str = dataset.to_json()
 32      parsed_json = json.loads(json_str)
 33  
 34      assert parsed_json["digest"] is not None
 35      assert path in parsed_json["source"]
 36      assert parsed_json["source_type"] == dataset_source_class._get_source_type()
 37  
 38  
 39  @pytest.mark.parametrize(
 40      ("dataset_source_class", "path"),
 41      [
 42          (HTTPDatasetSource, "test:/my/test/uri"),
 43          (DeltaDatasetSource, "fake/path/to/delta"),
 44          (HuggingFaceDatasetSource, "databricks/databricks-dolly-15k"),
 45      ],
 46  )
 47  def test_create_meta_dataset_from_source_with_schema(dataset_source_class, path):
 48      source = dataset_source_class(path)
 49      schema = Schema([
 50          ColSpec(type=DataType.long, name="foo"),
 51          ColSpec(type=DataType.integer, name="bar"),
 52      ])
 53      dataset = MetaDataset(source=source, schema=schema)
 54  
 55      json_str = dataset.to_json()
 56      parsed_json = json.loads(json_str)
 57  
 58      assert parsed_json["digest"] is not None
 59      assert path in parsed_json["source"]
 60      assert parsed_json["source_type"] == dataset_source_class._get_source_type()
 61      assert json.loads(parsed_json["schema"])["mlflow_colspec"] == schema.to_dict()
 62  
 63  
 64  def test_meta_dataset_digest():
 65      http_source = HTTPDatasetSource("test:/my/test/uri")
 66      dataset1 = MetaDataset(source=http_source)
 67      schema = Schema([
 68          ColSpec(type=DataType.long, name="foo"),
 69          ColSpec(type=DataType.integer, name="bar"),
 70      ])
 71      dataset2 = MetaDataset(source=http_source, schema=schema)
 72  
 73      assert dataset1.digest != dataset2.digest
 74  
 75      delta_source = DeltaDatasetSource("fake/path/to/delta")
 76      dataset3 = MetaDataset(source=delta_source)
 77      assert dataset1.digest != dataset3.digest
 78  
 79  
 80  def test_meta_dataset_with_uc_source():
 81      path = "/Volumes/dummy_catalog/dummy_schema/dummy_volume/tmp.yaml"
 82  
 83      with (
 84          patch(
 85              "mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource._verify_uc_path_is_valid",
 86              side_effect=MlflowException(f"{path} does not exist in Databricks Unified Catalog."),
 87          ),
 88          pytest.raises(
 89              MlflowException, match=f"{path} does not exist in Databricks Unified Catalog."
 90          ),
 91      ):
 92          uc_volume_source = UCVolumeDatasetSource(path)
 93  
 94      with patch(
 95          "mlflow.data.uc_volume_dataset_source.UCVolumeDatasetSource._verify_uc_path_is_valid",
 96      ):
 97          uc_volume_source = UCVolumeDatasetSource(path)
 98          dataset = MetaDataset(source=uc_volume_source)
 99          json_str = dataset.to_json()
100          parsed_json = json.loads(json_str)
101  
102          assert parsed_json["digest"] is not None
103          assert path in parsed_json["source"]
104          assert parsed_json["source_type"] == "uc_volume"
105  
106  
def test_create_meta_dataset_from_dataset():
    """Wrapping an existing dataset nests its serialized form under the "source" field."""
    inner = from_pandas(
        df=pd.DataFrame({"a": [1, 2, 3]}),
        source="/tmp/test.csv",
    )

    wrapper_json = json.loads(MetaDataset(source=inner).to_json())

    assert wrapper_json["source_type"] == inner._get_source_type()
    # The wrapped dataset's own serialization is embedded as a JSON string.
    nested_json = json.loads(wrapper_json["source"])
    assert nested_json["source_type"] == inner.source._get_source_type()