/ tests / data / test_dataset_source.py
test_dataset_source.py
 1  import json
 2  
 3  import pandas as pd
 4  import pytest
 5  
 6  import mlflow.data
 7  from mlflow.exceptions import MlflowException
 8  
 9  from tests.resources.data.dataset_source import SampleDatasetSource
10  
11  
def test_load(tmp_path):
    """Check that a source built from ``"test:" + path`` loads back to ``path``."""
    location = str(tmp_path)
    source = SampleDatasetSource("test:" + location)
    assert source.load() == location
14  
15  
def test_conversion_to_json_and_back():
    """Round-trip a resolved source through JSON and check the URI is preserved."""
    expected_uri = "test:/my/test/uri"
    original = SampleDatasetSource._resolve(expected_uri)

    # The serialized form must expose the URI under the "uri" key.
    serialized = original.to_json()
    payload = json.loads(serialized)
    assert payload["uri"] == expected_uri

    # Deserializing must yield a source with the same URI as the original.
    restored = SampleDatasetSource.from_json(serialized)
    assert restored.uri == original.uri
23  
24  
def test_get_source_obtains_expected_file_source(tmp_path):
    """Check ``get_source`` on a dataset, a logged dataset input, and its dataset entity.

    The dataset is created from a DataFrame with a CSV file path as its source.
    Each retrieved source must serialize to the same JSON as the dataset's own
    source, and each must load to the string form of the CSV path.
    """

    def as_dict(src):
        # Compare parsed JSON rather than raw strings.
        return json.loads(src.to_json())

    csv_path = tmp_path / "temp.csv"
    frame = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    frame.to_csv(csv_path)
    dataset = mlflow.data.from_pandas(frame, source=csv_path)

    # 1) Source obtained directly from the in-memory dataset.
    from_dataset = mlflow.data.get_source(dataset)
    assert as_dict(from_dataset) == as_dict(dataset.source)

    with mlflow.start_run() as active_run:
        mlflow.log_input(dataset)

    logged_run = mlflow.get_run(active_run.info.run_id)

    # 2) Source obtained from the dataset input recorded on the run.
    dataset_input = logged_run.inputs.dataset_inputs[0]
    from_input = mlflow.data.get_source(dataset_input)
    assert as_dict(from_input) == as_dict(dataset.source)

    # 3) Source obtained from the dataset entity nested in that input.
    dataset_entity = logged_run.inputs.dataset_inputs[0].dataset
    from_entity = mlflow.data.get_source(dataset_entity)
    assert as_dict(from_entity) == as_dict(dataset.source)

    assert from_dataset.load() == from_input.load() == from_entity.load() == str(csv_path)
48  
49  
def test_get_source_obtains_expected_code_source():
    """Check ``get_source`` when the dataset is created without an explicit source.

    The source retrieved from the dataset, from the logged dataset input, and
    from the dataset entity must each serialize to the same JSON as the
    dataset's own source.
    """

    def as_dict(src):
        # Compare parsed JSON rather than raw strings.
        return json.loads(src.to_json())

    frame = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    dataset = mlflow.data.from_pandas(frame)

    # 1) Source obtained directly from the in-memory dataset.
    from_dataset = mlflow.data.get_source(dataset)
    assert as_dict(from_dataset) == as_dict(dataset.source)

    with mlflow.start_run() as active_run:
        mlflow.log_input(dataset)

    logged_run = mlflow.get_run(active_run.info.run_id)

    # 2) Source obtained from the dataset input recorded on the run.
    dataset_input = logged_run.inputs.dataset_inputs[0]
    from_input = mlflow.data.get_source(dataset_input)
    assert as_dict(from_input) == as_dict(dataset.source)

    # 3) Source obtained from the dataset entity nested in that input.
    dataset_entity = logged_run.inputs.dataset_inputs[0].dataset
    from_entity = mlflow.data.get_source(dataset_entity)
    assert as_dict(from_entity) == as_dict(dataset.source)
69  
70  
def test_get_source_throws_for_invalid_input(tmp_path):
    """Check that ``get_source`` raises ``MlflowException`` when given a plain string."""
    not_a_dataset = str(tmp_path)
    expected_message = "Unrecognized dataset type.*str"
    with pytest.raises(MlflowException, match=expected_message):
        mlflow.data.get_source(not_a_dataset)