# test_dataset_source.py
"""Tests for dataset sources and ``mlflow.data.get_source``."""

import json

import pandas as pd
import pytest

import mlflow
import mlflow.data
from mlflow.exceptions import MlflowException

from tests.resources.data.dataset_source import SampleDatasetSource


def _get_sources_before_and_after_logging(pandas_ds):
    """Resolve the source of ``pandas_ds`` three ways and check each matches the original.

    The sources are obtained with ``mlflow.data.get_source`` from:

    1. the dataset object itself (before it is logged),
    2. the first ``DatasetInput`` recorded on a run after ``mlflow.log_input``,
    3. the ``dataset`` entity attached to that ``DatasetInput``.

    For each one, the parsed ``to_json()`` output is asserted to equal the parsed
    ``pandas_ds.source.to_json()``.

    Returns:
        A tuple ``(source_from_dataset, source_from_input, source_from_entity)``.
    """
    expected_json = json.loads(pandas_ds.source.to_json())

    source_from_dataset = mlflow.data.get_source(pandas_ds)
    assert json.loads(source_from_dataset.to_json()) == expected_json

    with mlflow.start_run() as active_run:
        mlflow.log_input(pandas_ds)

    # Fetch the run back so the remaining lookups operate on the run's recorded
    # inputs rather than on ``pandas_ds`` directly.
    run = mlflow.get_run(active_run.info.run_id)
    ds_input = run.inputs.dataset_inputs[0]

    source_from_input = mlflow.data.get_source(ds_input)
    assert json.loads(source_from_input.to_json()) == expected_json

    source_from_entity = mlflow.data.get_source(ds_input.dataset)
    assert json.loads(source_from_entity.to_json()) == expected_json

    return source_from_dataset, source_from_input, source_from_entity


def test_load(tmp_path):
    """``load()`` on a ``test:<path>`` source returns ``<path>``."""
    assert SampleDatasetSource("test:" + str(tmp_path)).load() == str(tmp_path)


def test_conversion_to_json_and_back():
    """A source resolved from a URI survives a ``to_json`` / ``from_json`` round trip."""
    uri = "test:/my/test/uri"
    source = SampleDatasetSource._resolve(uri)
    source_json = source.to_json()
    assert json.loads(source_json)["uri"] == uri
    reloaded_source = SampleDatasetSource.from_json(source_json)
    assert reloaded_source.uri == source.uri


def test_get_source_obtains_expected_file_source(tmp_path):
    """``get_source`` returns a matching, loadable source for a dataset backed by a CSV file."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    path = tmp_path / "temp.csv"
    df.to_csv(path)
    pandas_ds = mlflow.data.from_pandas(df, source=path)

    source1, source2, source3 = _get_sources_before_and_after_logging(pandas_ds)

    assert source1.load() == source2.load() == source3.load() == str(path)


def test_get_source_obtains_expected_code_source():
    """``get_source`` returns a matching source for a dataset created without an explicit ``source``."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    pandas_ds = mlflow.data.from_pandas(df)

    _get_sources_before_and_after_logging(pandas_ds)


def test_get_source_throws_for_invalid_input(tmp_path):
    """``get_source`` raises ``MlflowException`` when given a plain ``str``."""
    with pytest.raises(MlflowException, match="Unrecognized dataset type.*str"):
        mlflow.data.get_source(str(tmp_path))