# tests/data/test_spark_dataset_source.py
import json

import pandas as pd
import pytest

from mlflow.data.dataset_source_registry import get_dataset_source_from_json
from mlflow.data.spark_dataset_source import SparkDatasetSource
from mlflow.exceptions import MlflowException

10  
@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped local SparkSession configured with the Delta Lake extension."""
    from pyspark.sql import SparkSession

    builder = SparkSession.builder.master("local[*]")
    # Delta Lake wiring: jar package, SQL extension, and catalog implementation.
    delta_conf = {
        "spark.jars.packages": "io.delta:delta-spark_2.12:3.0.0",
        "spark.sql.extensions": "io.delta.sql.DeltaSparkSessionExtension",
        "spark.sql.catalog.spark_catalog": "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    }
    for key, value in delta_conf.items():
        builder = builder.config(key, value)
    # Using the session as a context manager stops it when the module's tests finish.
    with builder.getOrCreate() as session:
        yield session
26  
27  
def test_spark_dataset_source_from_path(spark_session, tmp_path):
    """A path-based SparkDatasetSource serializes to JSON, loads the data, and round-trips."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    spark_datasource = SparkDatasetSource(path=path)
    assert spark_datasource.to_json() == json.dumps({"path": path})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    reloaded_source = get_dataset_source_from_json(
        spark_datasource.to_json(), source_type=spark_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, SparkDatasetSource)
    # Compare types by identity (`is`), not equality — `type(a) == type(b)` is
    # flagged by flake8 E721 and identity is the intended check here.
    assert type(spark_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == spark_datasource.to_json()
45  
46  
def test_spark_dataset_source_from_table(spark_session, tmp_path):
    """A table-based SparkDatasetSource serializes to JSON, loads the table, and round-trips."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.mode("overwrite").saveAsTable("temp", path=tmp_path)

    spark_datasource = SparkDatasetSource(table_name="temp")
    assert spark_datasource.to_json() == json.dumps({"table_name": "temp"})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    reloaded_source = get_dataset_source_from_json(
        spark_datasource.to_json(), source_type=spark_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, SparkDatasetSource)
    # Compare types by identity (`is`), not equality — `type(a) == type(b)` is
    # flagged by flake8 E721 and identity is the intended check here.
    assert type(spark_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == spark_datasource.to_json()
63  
64  
def test_spark_dataset_source_from_sql(spark_session, tmp_path):
    """A SQL-based SparkDatasetSource serializes to JSON, runs the query, and round-trips."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.mode("overwrite").saveAsTable("temp_sql", path=tmp_path)

    spark_datasource = SparkDatasetSource(sql="SELECT * FROM temp_sql")
    assert spark_datasource.to_json() == json.dumps({"sql": "SELECT * FROM temp_sql"})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    reloaded_source = get_dataset_source_from_json(
        spark_datasource.to_json(), source_type=spark_datasource._get_source_type()
    )
    assert isinstance(reloaded_source, SparkDatasetSource)
    # Compare types by identity (`is`), not equality — `type(a) == type(b)` is
    # flagged by flake8 E721 and identity is the intended check here.
    assert type(spark_datasource) is type(reloaded_source)
    assert reloaded_source.to_json() == spark_datasource.to_json()
81  
82  
def test_spark_dataset_source_too_many_inputs(spark_session, tmp_path):
    """Constructing a SparkDatasetSource with more than one input kind raises MlflowException."""
    pdf = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    sdf = spark_session.createDataFrame(pdf)
    sdf.write.mode("overwrite").saveAsTable("temp", path=tmp_path)

    # Supplying both `path` and `table_name` violates the exactly-one-input contract.
    expected_msg = 'Must specify exactly one of "path", "table_name", or "sql"'
    with pytest.raises(MlflowException, match=expected_msg):
        SparkDatasetSource(path=tmp_path, table_name="temp")