test_polars_dataset.py
1 from __future__ import annotations 2 3 import json 4 import re 5 from datetime import date, datetime 6 from pathlib import Path 7 8 import pandas as pd 9 import polars as pl 10 import pytest 11 12 from mlflow.data.code_dataset_source import CodeDatasetSource 13 from mlflow.data.evaluation_dataset import EvaluationDataset 14 from mlflow.data.filesystem_dataset_source import FileSystemDatasetSource 15 from mlflow.data.polars_dataset import PolarsDataset, from_polars, infer_schema 16 from mlflow.data.pyfunc_dataset_mixin import PyFuncInputsOutputs 17 from mlflow.exceptions import MlflowException 18 from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema 19 20 from tests.resources.data.dataset_source import SampleDatasetSource 21 22 23 @pytest.fixture(name="source", scope="module") 24 def sample_source() -> SampleDatasetSource: 25 source_uri = "test:/my/test/uri" 26 return SampleDatasetSource._resolve(source_uri) 27 28 29 def test_infer_schema() -> None: 30 data = [ 31 [ 32 b"asd", 33 True, 34 datetime(2024, 1, 1, 12, 34, 56, 789), 35 10, 36 10, 37 10, 38 10, 39 10, 40 10, 41 "asd", 42 "😆", 43 "category", 44 "val2", 45 date(2024, 1, 1), 46 10, 47 10, 48 10, 49 [1, 2, 3], 50 [1, 2, 3], 51 {"col1": 1}, 52 ] 53 ] 54 schema = { 55 "Binary": pl.Binary, 56 "Boolean": pl.Boolean, 57 "Datetime": pl.Datetime, 58 "Float32": pl.Float32, 59 "Float64": pl.Float64, 60 "Int8": pl.Int8, 61 "Int16": pl.Int16, 62 "Int32": pl.Int32, 63 "Int64": pl.Int64, 64 "String": pl.String, 65 "Utf8": pl.Utf8, 66 "Categorical": pl.Categorical, 67 "Enum": pl.Enum(["val1", "val2"]), 68 "Date": pl.Date, 69 "UInt8": pl.UInt8, 70 "UInt16": pl.UInt16, 71 "UInt32": pl.UInt32, 72 "List": pl.List(pl.Int8), 73 "Array": pl.Array(pl.Int8, 3), 74 "Struct": pl.Struct({"col1": pl.Int8}), 75 } 76 df = pl.DataFrame(data=data, schema=schema) 77 78 assert infer_schema(df) == Schema([ 79 ColSpec(name="Binary", type=DataType.binary), 80 ColSpec(name="Boolean", type=DataType.boolean), 81 ColSpec(name="Datetime", type=DataType.datetime), 82 ColSpec(name="Float32", type=DataType.float), 83 ColSpec(name="Float64", type=DataType.double), 84 ColSpec(name="Int8", type=DataType.integer), 85 ColSpec(name="Int16", type=DataType.integer), 86 ColSpec(name="Int32", type=DataType.integer), 87 ColSpec(name="Int64", type=DataType.long), 88 ColSpec(name="String", type=DataType.string), 89 ColSpec(name="Utf8", type=DataType.string), 90 ColSpec(name="Categorical", type=DataType.string), 91 ColSpec(name="Enum", type=DataType.string), 92 ColSpec(name="Date", type=DataType.datetime), 93 ColSpec(name="UInt8", type=DataType.integer), 94 ColSpec(name="UInt16", type=DataType.integer), 95 ColSpec(name="UInt32", type=DataType.long), 96 ColSpec(name="List", type=Array(DataType.integer)), 97 ColSpec(name="Array", type=Array(DataType.integer)), 98 ColSpec(name="Struct", type=Object([Property(name="col1", dtype=DataType.integer)])), 99 ]) 100 101 102 def test_conversion_to_json(source: SampleDatasetSource) -> None: 103 dataset = PolarsDataset( 104 df=pl.DataFrame([1, 2, 3], schema=["Numbers"]), source=source, name="testname" 105 ) 106 107 dataset_json = dataset.to_json() 108 parsed_json = json.loads(dataset_json) 109 110 assert parsed_json.keys() <= {"name", "digest", "source", "source_type", "schema", "profile"} 111 assert parsed_json["name"] == dataset.name 112 assert parsed_json["digest"] == dataset.digest 113 assert parsed_json["source"] == dataset.source.to_json() 114 assert parsed_json["source_type"] == dataset.source._get_source_type() 115 assert parsed_json["profile"] == json.dumps(dataset.profile) 116 117 schema_json = json.dumps(json.loads(parsed_json["schema"])["mlflow_colspec"]) 118 assert Schema.from_json(schema_json) == dataset.schema 119 120 121 def test_digest_property_has_expected_value(source: SampleDatasetSource) -> None: 122 dataset = PolarsDataset(df=pl.DataFrame([1, 2, 3], schema=["Numbers"]), source=source) 123 assert dataset.digest == dataset._compute_digest() 124 # Digest value varies across Polars versions due to hash_rows() implementation changes 125 assert re.match(r"^\d+$", dataset.digest) 126 127 128 def test_digest_consistent(source: SampleDatasetSource) -> None: 129 dataset1 = PolarsDataset( 130 df=pl.DataFrame({"numbers": [1, 2, 3], "strs": ["a", "b", "c"]}), source=source 131 ) 132 133 dataset2 = PolarsDataset( 134 df=pl.DataFrame({"numbers": [2, 3, 1], "strs": ["b", "c", "a"]}), source=source 135 ) 136 assert dataset1.digest == dataset2.digest 137 138 139 def test_digest_change(source: SampleDatasetSource) -> None: 140 dataset1 = PolarsDataset( 141 df=pl.DataFrame({"numbers": [1, 2, 3], "strs": ["a", "b", "c"]}), source=source 142 ) 143 144 dataset2 = PolarsDataset( 145 df=pl.DataFrame({"numbers": [10, 20, 30], "strs": ["aa", "bb", "cc"]}), source=source 146 ) 147 assert dataset1.digest != dataset2.digest 148 149 150 def test_df_property(source: SampleDatasetSource) -> None: 151 df = pl.DataFrame({"numbers": [1, 2, 3]}) 152 dataset = PolarsDataset(df=df, source=source) 153 assert dataset.df.equals(df) 154 155 156 def test_targets_none(source: SampleDatasetSource) -> None: 157 df_no_targets = pl.DataFrame({"numbers": [1, 2, 3]}) 158 dataset_no_targets = PolarsDataset(df=df_no_targets, source=source) 159 assert dataset_no_targets._targets is None 160 161 162 def test_targets_not_none(source: SampleDatasetSource) -> None: 163 df_with_targets = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 164 dataset_with_targets = PolarsDataset(df=df_with_targets, source=source, targets="c") 165 assert dataset_with_targets._targets == "c" 166 167 168 def test_targets_invalid(source: SampleDatasetSource) -> None: 169 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 170 with pytest.raises( 171 MlflowException, 172 match="DataFrame does not contain specified targets column: 'd'", 173 ): 174 PolarsDataset(df=df, source=source, targets="d") 175 176 177 def test_to_pyfunc_wo_outputs(source: SampleDatasetSource) -> None: 178 df = pl.DataFrame({"numbers": [1, 2, 3]}) 179 dataset = PolarsDataset(df=df, source=source) 180 181 input_outputs = dataset.to_pyfunc() 182 183 assert isinstance(input_outputs, PyFuncInputsOutputs) 184 assert len(input_outputs.inputs) == 1 185 assert isinstance(input_outputs.inputs[0], pd.DataFrame) 186 assert input_outputs.inputs[0].equals(pd.DataFrame({"numbers": [1, 2, 3]})) 187 188 189 def test_to_pyfunc_with_outputs(source: SampleDatasetSource) -> None: 190 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 191 dataset = PolarsDataset(df=df, source=source, targets="c") 192 193 input_outputs = dataset.to_pyfunc() 194 195 assert isinstance(input_outputs, PyFuncInputsOutputs) 196 assert len(input_outputs.inputs) == 1 197 assert isinstance(input_outputs.inputs[0], pd.DataFrame) 198 assert input_outputs.inputs[0].equals(pd.DataFrame({"a": [1, 1], "b": [2, 2]})) 199 assert len(input_outputs.outputs) == 1 200 assert isinstance(input_outputs.outputs[0], pd.Series) 201 assert input_outputs.outputs[0].equals(pd.Series([3, 3], name="c")) 202 203 204 def test_from_polars_with_targets(tmp_path: Path) -> None: 205 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 206 path = tmp_path / "temp.csv" 207 df.write_csv(path) 208 209 dataset = from_polars(df, targets="c", source=str(path)) 210 input_outputs = dataset.to_pyfunc() 211 212 assert isinstance(input_outputs, PyFuncInputsOutputs) 213 assert len(input_outputs.inputs) == 1 214 assert isinstance(input_outputs.inputs[0], pd.DataFrame) 215 assert input_outputs.inputs[0].equals(pd.DataFrame({"a": [1, 1], "b": [2, 2]})) 216 assert len(input_outputs.outputs) == 1 217 assert isinstance(input_outputs.outputs[0], pd.Series) 218 assert input_outputs.outputs[0].equals(pd.Series([3, 3], name="c")) 219 220 221 def test_from_polars_file_system_datasource(tmp_path: Path) -> None: 222 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 223 path = tmp_path / "temp.csv" 224 df.write_csv(path) 225 226 mlflow_df = from_polars(df, source=str(path)) 227 228 assert isinstance(mlflow_df, PolarsDataset) 229 assert mlflow_df.df.equals(df) 230 assert mlflow_df.schema == infer_schema(df) 231 assert mlflow_df.profile == {"num_rows": 2, "num_elements": 6} 232 assert isinstance(mlflow_df.source, FileSystemDatasetSource) 233 234 235 def test_from_polars_no_source_specified() -> None: 236 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 237 238 mlflow_df = from_polars(df) 239 240 assert isinstance(mlflow_df, PolarsDataset) 241 assert isinstance(mlflow_df.source, CodeDatasetSource) 242 assert "mlflow.source.name" in mlflow_df.source.to_json() 243 244 245 def test_to_evaluation_dataset(source: SampleDatasetSource) -> None: 246 import numpy as np 247 248 df = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}) 249 dataset = PolarsDataset(df=df, source=source, targets="c", name="testname") 250 evaluation_dataset = dataset.to_evaluation_dataset() 251 252 assert evaluation_dataset.name is not None 253 assert evaluation_dataset.digest is not None 254 assert isinstance(evaluation_dataset, EvaluationDataset) 255 assert isinstance(evaluation_dataset.features_data, pd.DataFrame) 256 assert evaluation_dataset.features_data.equals(df.drop("c").to_pandas()) 257 assert isinstance(evaluation_dataset.labels_data, np.ndarray) 258 assert np.array_equal(evaluation_dataset.labels_data, df["c"].to_numpy())