import json
import os
from typing import TYPE_CHECKING, Any

import pandas as pd
import pytest
from packaging.version import Version

import mlflow.data
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.spark_dataset import SparkDataset
from mlflow.data.spark_dataset_source import SparkDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.types.schema import Schema
from mlflow.types.utils import _infer_schema

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


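# Module-scoped SparkSession with the Delta Lake extension enabled; the Delta-backed
# tests below need the extension and catalog configured at session creation time.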
@pytest.fixture(scope="module")
def spark_session(tmp_path_factory: pytest.TempPathFactory):
    import pyspark
    from pyspark.sql import SparkSession

    pyspark_version = Version(pyspark.__version__)
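    # The Delta package's Scala suffix must match the Scala build of Spark: Spark 4.x
    # is built against Scala 2.13, while the PyPI pyspark 3.x distribution uses
    # Scala 2.12.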
    if pyspark_version.major >= 4:
        delta_package = "io.delta:delta-spark_2.13:4.0.0"
    else:
        delta_package = "io.delta:delta-spark_2.12:3.0.0"

    tmp_dir = tmp_path_factory.mktemp("spark_tmp")
    with (
        SparkSession.builder
        .master("local[*]")
        .config("spark.jars.packages", delta_package)
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
        )
        .config("spark.sql.warehouse.dir", str(tmp_dir))
        .getOrCreate()
    ) as session:
        yield session


@pytest.fixture(autouse=True)
def drop_tables(spark_session: "SparkSession"):
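    """Drop any tables a test created so that state does not leak through the
    module-scoped Spark session into later tests."""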
    yield
    for row in spark_session.sql("SHOW TABLES").collect():
        spark_session.sql(f"DROP TABLE IF EXISTS {row.tableName}")


@pytest.fixture
def df():
    return pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])


def _assert_dataframes_equal(df1, df2):
    assert df1.schema == df2.schema, f"Schemas differ: {df1.schema} != {df2.schema}"
    # exceptAll computes a multiset difference, so check both directions to catch
    # extra rows on either side
    assert df1.exceptAll(df2).rdd.isEmpty(), "df1 contains rows not present in df2"
    assert df2.exceptAll(df1).rdd.isEmpty(), "df2 contains rows not present in df1"


def _validate_profile_approx_count(parsed_json: dict[str, Any]) -> None:
    """Validate approx_count in profile data, handling platform/version differences."""
    # On Windows with certain PySpark versions, Spark datasets may return "unknown" for
    # approx_count instead of the actual count. We should check that the profile is valid
    # JSON and contains the expected key, but not assert on the exact value.
    profile_data = json.loads(parsed_json["profile"])
    assert "approx_count" in profile_data
    assert profile_data["approx_count"] in [1, 2, "unknown"]


def _check_spark_dataset(dataset, original_df, df_spark, expected_source_type, expected_name=None):
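    """Shared assertions for the tests below: the dataset wraps ``df_spark``, its schema
    matches ``original_df``, its profile carries an approximate row count, its source has
    ``expected_source_type``, and the source reloads to the same rows."""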
    assert isinstance(dataset, SparkDataset)
    _assert_dataframes_equal(dataset.df, df_spark)
    assert dataset.schema == _infer_schema(original_df)
    assert isinstance(dataset.profile, dict)
    approx_count = dataset.profile.get("approx_count")
    assert isinstance(approx_count, int) or approx_count == "unknown"
    assert isinstance(dataset.source, expected_source_type)
    # NB: In real-world scenarios, Spark dataset sources may not match Spark DataFrames
    # precisely. For example, users may transform Spark DataFrames after loading contents
    # from source files. To ensure that source loading works properly for the purpose of
    # the test cases in this suite, we require the source to match the DataFrame and make
    # the following equality assertion.
    _assert_dataframes_equal(dataset.source.load(), df_spark)
    if expected_name is not None:
        assert dataset.name == expected_name


def test_conversion_to_json_spark_dataset_source(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )

    dataset_json = dataset.to_json()
    parsed_json = json.loads(dataset_json)
    assert parsed_json.keys() <= {"name", "digest", "source", "source_type", "schema", "profile"}
    assert parsed_json["name"] == dataset.name
    assert parsed_json["digest"] == dataset.digest
    assert parsed_json["source"] == dataset.source.to_json()
    assert parsed_json["source_type"] == dataset.source._get_source_type()
    _validate_profile_approx_count(parsed_json)

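    # The serialized schema nests the column spec under "mlflow_colspec"; re-serialize
    # just that piece so it can round-trip through Schema.from_json.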
    schema_json = json.dumps(json.loads(parsed_json["schema"])["mlflow_colspec"])
    assert Schema.from_json(schema_json) == dataset.schema


def test_conversion_to_json_delta_dataset_source(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").save(path)

    source = DeltaDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )

    dataset_json = dataset.to_json()
    parsed_json = json.loads(dataset_json)
    assert parsed_json.keys() <= {"name", "digest", "source", "source_type", "schema", "profile"}
    assert parsed_json["name"] == dataset.name
    assert parsed_json["digest"] == dataset.digest
    assert parsed_json["source"] == dataset.source.to_json()
    assert parsed_json["source_type"] == dataset.source._get_source_type()
    _validate_profile_approx_count(parsed_json)

    schema_json = json.dumps(json.loads(parsed_json["schema"])["mlflow_colspec"])
    assert Schema.from_json(schema_json) == dataset.schema


def test_digest_property_has_expected_value(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset.digest == dataset._compute_digest()
    # Note that digests are stable within a session, but may not be stable across sessions.
    # Hence we are not checking the digest value here.


def test_df_property_has_expected_value(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset.df == df_spark


def test_targets_property(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)
    dataset_no_targets = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset_no_targets.targets is None
    dataset_with_targets = SparkDataset(
        df=df_spark,
        source=source,
        targets="c",
        name="testname",
    )
    assert dataset_with_targets.targets == "c"

    with pytest.raises(
        MlflowException,
        match="The specified Spark dataset does not contain the specified targets column",
    ):
        SparkDataset(
            df=df_spark,
            source=source,
            targets="nonexistent",
            name="testname",
        )


def test_predictions_property(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)
    dataset_no_predictions = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset_no_predictions.predictions is None
    dataset_with_predictions = SparkDataset(
        df=df_spark,
        source=source,
        predictions="b",
        name="testname",
    )
    assert dataset_with_predictions.predictions == "b"

    with pytest.raises(
        MlflowException,
        match="The specified Spark dataset does not contain the specified predictions column",
    ):
        SparkDataset(
            df=df_spark,
            source=source,
            predictions="nonexistent",
            name="testname",
        )


def test_from_spark_no_source_specified(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    mlflow_df = mlflow.data.from_spark(df_spark)

    assert isinstance(mlflow_df, SparkDataset)

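    # With no path, table name, or SQL supplied, from_spark falls back to recording the
    # calling code as the dataset source.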
    assert isinstance(mlflow_df.source, CodeDatasetSource)
    assert "mlflow.source.name" in mlflow_df.source.to_json()


def test_from_spark_with_sql_and_version(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)
    with pytest.raises(
        MlflowException,
        match="`version` may not be specified when `sql` is specified. `version` may only be"
        " specified when `table_name` or `path` is specified.",
    ):
        mlflow.data.from_spark(df_spark, sql="SELECT * FROM table", version=1)


def test_from_spark_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    dir_path = str(tmp_path / "df_dir")
    df_spark.write.parquet(dir_path)
    assert os.path.isdir(dir_path)

    mlflow_df_from_dir = mlflow.data.from_spark(df_spark, path=dir_path)
    _check_spark_dataset(mlflow_df_from_dir, df, df_spark, SparkDatasetSource)

    file_path = str(tmp_path / "df.parquet")
    df_spark.toPandas().to_parquet(file_path)
    assert not os.path.isdir(file_path)

    mlflow_df_from_file = mlflow.data.from_spark(df_spark, path=file_path)
    _check_spark_dataset(mlflow_df_from_file, df, df_spark, SparkDatasetSource)


def test_from_spark_delta_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").save(path)

    mlflow_df = mlflow.data.from_spark(df_spark, path=path)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_from_spark_sql(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("table")

    mlflow_df = mlflow.data.from_spark(df_spark, sql="SELECT * FROM table")

    _check_spark_dataset(mlflow_df, df, df_spark, SparkDatasetSource)


def test_from_spark_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("my_spark_table")

    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_spark_table")

    _check_spark_dataset(mlflow_df, df, df_spark, SparkDatasetSource)


def test_from_spark_table_name_with_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("my_spark_table")

    with pytest.raises(
        MlflowException,
        match="Version '1' was specified, but could not find a Delta table "
        "with name 'my_spark_table'",
    ):
        mlflow.data.from_spark(df_spark, table_name="my_spark_table", version=1)


def test_from_spark_delta_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_delta_table")

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_from_spark_delta_table_name_and_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

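    # Delta table versions start at 0, so version=0 refers to the initial write above.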
    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_delta_table", version=0)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_load_delta_with_no_source_info():
    with pytest.raises(
        MlflowException,
        match="Must specify exactly one of `table_name` or `path`.",
    ):
        mlflow.data.load_delta()


def test_load_delta_with_both_table_name_and_path():
    with pytest.raises(
        MlflowException,
        match="Must specify exactly one of `table_name` or `path`.",
    ):
        mlflow.data.load_delta(table_name="my_table", path="my_path")


def test_load_delta_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").mode("overwrite").save(path)

    mlflow_df = mlflow.data.load_delta(path=path)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_load_delta_path_with_version(spark_session, tmp_path, df):
    path = str(tmp_path / "temp.delta")

    df_v0 = pd.DataFrame([[4, 5, 6], [4, 5, 6]], columns=["a", "b", "c"])
    assert not df_v0.equals(df)
    df_v0_spark = spark_session.createDataFrame(df_v0)
    df_v0_spark.write.format("delta").mode("overwrite").save(path)

    # Overwrite again to create version 1 (Delta versions start at 0)
    df_v1_spark = spark_session.createDataFrame(df)
    df_v1_spark.write.format("delta").mode("overwrite").save(path)

    mlflow_df = mlflow.data.load_delta(path=path, version=1)
    _check_spark_dataset(mlflow_df, df, df_v1_spark, DeltaDatasetSource)


def test_load_delta_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

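    # With no version given, load_delta resolves the latest snapshot and encodes it in
    # the dataset name as "<table>@v<version>" (here version 0, the initial write).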
    mlflow_df = mlflow.data.load_delta(table_name="my_delta_table")

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource, "my_delta_table@v0")


def test_load_delta_table_name_with_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table_versioned")

    df2 = pd.DataFrame([[4, 5, 6], [4, 5, 6]], columns=["a", "b", "c"])
    assert not df2.equals(df)
    df2_spark = spark_session.createDataFrame(df2)
    df2_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table_versioned")

    mlflow_df = mlflow.data.load_delta(table_name="my_delta_table_versioned", version=1)

    _check_spark_dataset(
        mlflow_df, df2, df2_spark, DeltaDatasetSource, "my_delta_table_versioned@v1"
    )
    pd.testing.assert_frame_equal(mlflow_df.df.toPandas(), df2)


def test_to_evaluation_dataset(spark_session, tmp_path, df):
    import numpy as np

    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        targets="c",
        name="testname",
        predictions="b",
    )
    evaluation_dataset = dataset.to_evaluation_dataset()
    assert isinstance(evaluation_dataset, EvaluationDataset)
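    # The targets ("c") and predictions ("b") columns are split out of the evaluation
    # data, leaving only column "a" as feature data.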
    assert evaluation_dataset.features_data.equals(df_spark.toPandas()[["a"]])
    assert np.array_equal(evaluation_dataset.labels_data, df_spark.toPandas()["c"].values)
    assert np.array_equal(evaluation_dataset.predictions_data, df_spark.toPandas()["b"].values)