/ tests / data / test_polars_dataset.py
test_polars_dataset.py
  1  from __future__ import annotations
  2  
  3  import json
  4  import re
  5  from datetime import date, datetime
  6  from pathlib import Path
  7  
  8  import pandas as pd
  9  import polars as pl
 10  import pytest
 11  
 12  from mlflow.data.code_dataset_source import CodeDatasetSource
 13  from mlflow.data.evaluation_dataset import EvaluationDataset
 14  from mlflow.data.filesystem_dataset_source import FileSystemDatasetSource
 15  from mlflow.data.polars_dataset import PolarsDataset, from_polars, infer_schema
 16  from mlflow.data.pyfunc_dataset_mixin import PyFuncInputsOutputs
 17  from mlflow.exceptions import MlflowException
 18  from mlflow.types.schema import Array, ColSpec, DataType, Object, Property, Schema
 19  
 20  from tests.resources.data.dataset_source import SampleDatasetSource
 21  
 22  
 23  @pytest.fixture(name="source", scope="module")
 24  def sample_source() -> SampleDatasetSource:
 25      source_uri = "test:/my/test/uri"
 26      return SampleDatasetSource._resolve(source_uri)
 27  
 28  
 29  def test_infer_schema() -> None:
 30      data = [
 31          [
 32              b"asd",
 33              True,
 34              datetime(2024, 1, 1, 12, 34, 56, 789),
 35              10,
 36              10,
 37              10,
 38              10,
 39              10,
 40              10,
 41              "asd",
 42              "😆",
 43              "category",
 44              "val2",
 45              date(2024, 1, 1),
 46              10,
 47              10,
 48              10,
 49              [1, 2, 3],
 50              [1, 2, 3],
 51              {"col1": 1},
 52          ]
 53      ]
 54      schema = {
 55          "Binary": pl.Binary,
 56          "Boolean": pl.Boolean,
 57          "Datetime": pl.Datetime,
 58          "Float32": pl.Float32,
 59          "Float64": pl.Float64,
 60          "Int8": pl.Int8,
 61          "Int16": pl.Int16,
 62          "Int32": pl.Int32,
 63          "Int64": pl.Int64,
 64          "String": pl.String,
 65          "Utf8": pl.Utf8,
 66          "Categorical": pl.Categorical,
 67          "Enum": pl.Enum(["val1", "val2"]),
 68          "Date": pl.Date,
 69          "UInt8": pl.UInt8,
 70          "UInt16": pl.UInt16,
 71          "UInt32": pl.UInt32,
 72          "List": pl.List(pl.Int8),
 73          "Array": pl.Array(pl.Int8, 3),
 74          "Struct": pl.Struct({"col1": pl.Int8}),
 75      }
 76      df = pl.DataFrame(data=data, schema=schema)
 77  
 78      assert infer_schema(df) == Schema([
 79          ColSpec(name="Binary", type=DataType.binary),
 80          ColSpec(name="Boolean", type=DataType.boolean),
 81          ColSpec(name="Datetime", type=DataType.datetime),
 82          ColSpec(name="Float32", type=DataType.float),
 83          ColSpec(name="Float64", type=DataType.double),
 84          ColSpec(name="Int8", type=DataType.integer),
 85          ColSpec(name="Int16", type=DataType.integer),
 86          ColSpec(name="Int32", type=DataType.integer),
 87          ColSpec(name="Int64", type=DataType.long),
 88          ColSpec(name="String", type=DataType.string),
 89          ColSpec(name="Utf8", type=DataType.string),
 90          ColSpec(name="Categorical", type=DataType.string),
 91          ColSpec(name="Enum", type=DataType.string),
 92          ColSpec(name="Date", type=DataType.datetime),
 93          ColSpec(name="UInt8", type=DataType.integer),
 94          ColSpec(name="UInt16", type=DataType.integer),
 95          ColSpec(name="UInt32", type=DataType.long),
 96          ColSpec(name="List", type=Array(DataType.integer)),
 97          ColSpec(name="Array", type=Array(DataType.integer)),
 98          ColSpec(name="Struct", type=Object([Property(name="col1", dtype=DataType.integer)])),
 99      ])
100  
101  
def test_conversion_to_json(source: SampleDatasetSource) -> None:
    """Check that ``to_json`` output agrees with the dataset's own metadata and schema."""
    frame = pl.DataFrame([1, 2, 3], schema=["Numbers"])
    dataset = PolarsDataset(df=frame, source=source, name="testname")

    payload = json.loads(dataset.to_json())
    permitted_keys = {"name", "digest", "source", "source_type", "schema", "profile"}

    # The serialized keys must be a subset of the known fields.
    assert payload.keys() <= permitted_keys
    assert payload["name"] == dataset.name
    assert payload["digest"] == dataset.digest
    assert payload["source"] == dataset.source.to_json()
    assert payload["source_type"] == dataset.source._get_source_type()
    assert payload["profile"] == json.dumps(dataset.profile)

    # "schema" is itself a JSON string that nests the column specs under
    # "mlflow_colspec"; re-serialize that part and rebuild a Schema from it.
    colspec = json.loads(payload["schema"])["mlflow_colspec"]
    rebuilt = Schema.from_json(json.dumps(colspec))
    assert rebuilt == dataset.schema
119  
120  
def test_digest_property_has_expected_value(source: SampleDatasetSource) -> None:
    """Check that ``digest`` matches ``_compute_digest()`` and is a string of decimal digits."""
    dataset = PolarsDataset(df=pl.DataFrame([1, 2, 3], schema=["Numbers"]), source=source)
    assert dataset.digest == dataset._compute_digest()
    # Digest value varies across Polars versions due to hash_rows() implementation
    # changes, so only the format is checked. ``fullmatch`` requires the whole
    # string to be digits; ``re.match(r"^\d+$", ...)`` would also accept a
    # trailing newline because ``$`` matches just before a final "\n".
    assert re.fullmatch(r"\d+", dataset.digest)
126  
127  
def test_digest_consistent(source: SampleDatasetSource) -> None:
    """Check that the same rows in a different order produce the same digest."""
    ordered = {"numbers": [1, 2, 3], "strs": ["a", "b", "c"]}
    permuted = {"numbers": [2, 3, 1], "strs": ["b", "c", "a"]}

    first = PolarsDataset(df=pl.DataFrame(ordered), source=source)
    second = PolarsDataset(df=pl.DataFrame(permuted), source=source)

    assert first.digest == second.digest
137  
138  
def test_digest_change(source: SampleDatasetSource) -> None:
    """Check that datasets with different contents produce different digests."""
    baseline_cols = {"numbers": [1, 2, 3], "strs": ["a", "b", "c"]}
    altered_cols = {"numbers": [10, 20, 30], "strs": ["aa", "bb", "cc"]}

    baseline = PolarsDataset(df=pl.DataFrame(baseline_cols), source=source)
    altered = PolarsDataset(df=pl.DataFrame(altered_cols), source=source)

    assert baseline.digest != altered.digest
148  
149  
def test_df_property(source: SampleDatasetSource) -> None:
    """Check that ``df`` exposes a frame equal to the one passed to the constructor."""
    expected = pl.DataFrame({"numbers": [1, 2, 3]})
    wrapped = PolarsDataset(df=expected, source=source)
    assert wrapped.df.equals(expected)
154  
155  
def test_targets_none(source: SampleDatasetSource) -> None:
    """Check that ``_targets`` is None when no targets column is given."""
    dataset = PolarsDataset(df=pl.DataFrame({"numbers": [1, 2, 3]}), source=source)
    assert dataset._targets is None
160  
161  
def test_targets_not_none(source: SampleDatasetSource) -> None:
    """Check that the given targets column name is stored on ``_targets``."""
    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    dataset = PolarsDataset(df=frame, source=source, targets="c")
    assert dataset._targets == "c"
166  
167  
def test_targets_invalid(source: SampleDatasetSource) -> None:
    """Check that naming a targets column absent from the frame raises MlflowException."""
    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    expected_message = "DataFrame does not contain specified targets column: 'd'"
    with pytest.raises(MlflowException, match=expected_message):
        PolarsDataset(df=frame, source=source, targets="d")
175  
176  
def test_to_pyfunc_wo_outputs(source: SampleDatasetSource) -> None:
    """Check that without targets, ``to_pyfunc`` yields one pandas input frame."""
    dataset = PolarsDataset(df=pl.DataFrame({"numbers": [1, 2, 3]}), source=source)

    converted = dataset.to_pyfunc()
    assert isinstance(converted, PyFuncInputsOutputs)

    inputs = converted.inputs
    assert len(inputs) == 1
    features = inputs[0]
    assert isinstance(features, pd.DataFrame)
    assert features.equals(pd.DataFrame({"numbers": [1, 2, 3]}))
187  
188  
def test_to_pyfunc_with_outputs(source: SampleDatasetSource) -> None:
    """Check that with a targets column, ``to_pyfunc`` splits features and labels.

    The non-target columns become a pandas DataFrame input; the target column
    becomes a pandas Series output.
    """
    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    converted = PolarsDataset(df=frame, source=source, targets="c").to_pyfunc()

    assert isinstance(converted, PyFuncInputsOutputs)

    assert len(converted.inputs) == 1
    features = converted.inputs[0]
    assert isinstance(features, pd.DataFrame)
    assert features.equals(pd.DataFrame({"a": [1, 1], "b": [2, 2]}))

    assert len(converted.outputs) == 1
    labels = converted.outputs[0]
    assert isinstance(labels, pd.Series)
    assert labels.equals(pd.Series([3, 3], name="c"))
202  
203  
def test_from_polars_with_targets(tmp_path: Path) -> None:
    """Check that ``from_polars`` with targets and a file source splits features and labels."""
    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    csv_file = tmp_path / "temp.csv"
    frame.write_csv(csv_file)

    converted = from_polars(frame, targets="c", source=str(csv_file)).to_pyfunc()
    assert isinstance(converted, PyFuncInputsOutputs)

    assert len(converted.inputs) == 1
    features = converted.inputs[0]
    assert isinstance(features, pd.DataFrame)
    assert features.equals(pd.DataFrame({"a": [1, 1], "b": [2, 2]}))

    assert len(converted.outputs) == 1
    labels = converted.outputs[0]
    assert isinstance(labels, pd.Series)
    assert labels.equals(pd.Series([3, 3], name="c"))
219  
220  
def test_from_polars_file_system_datasource(tmp_path: Path) -> None:
    """Check that a local path string as ``source`` yields a FileSystemDatasetSource dataset."""
    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    csv_file = tmp_path / "temp.csv"
    frame.write_csv(csv_file)

    dataset = from_polars(frame, source=str(csv_file))

    assert isinstance(dataset, PolarsDataset)
    assert dataset.df.equals(frame)
    assert dataset.schema == infer_schema(frame)
    # 2 rows x 3 columns.
    assert dataset.profile == {"num_rows": 2, "num_elements": 6}
    assert isinstance(dataset.source, FileSystemDatasetSource)
233  
234  
def test_from_polars_no_source_specified() -> None:
    """Check that omitting ``source`` falls back to a CodeDatasetSource."""
    dataset = from_polars(pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]}))

    assert isinstance(dataset, PolarsDataset)
    code_source = dataset.source
    assert isinstance(code_source, CodeDatasetSource)
    assert "mlflow.source.name" in code_source.to_json()
243  
244  
def test_to_evaluation_dataset(source: SampleDatasetSource) -> None:
    """Check that ``to_evaluation_dataset`` splits features and labels into pandas/NumPy.

    Features should equal the frame without the target column, converted to
    pandas; labels should be a NumPy array equal to the target column.
    """
    import numpy as np

    frame = pl.DataFrame({"a": [1, 1], "b": [2, 2], "c": [3, 3]})
    polars_dataset = PolarsDataset(df=frame, source=source, targets="c", name="testname")
    eval_ds = polars_dataset.to_evaluation_dataset()

    assert eval_ds.name is not None
    assert eval_ds.digest is not None
    assert isinstance(eval_ds, EvaluationDataset)

    features = eval_ds.features_data
    assert isinstance(features, pd.DataFrame)
    assert features.equals(frame.drop("c").to_pandas())

    labels = eval_ds.labels_data
    assert isinstance(labels, np.ndarray)
    assert np.array_equal(labels, frame["c"].to_numpy())