test_spark_dataset_source.py
import json

import pandas as pd
import pytest

from mlflow.data.dataset_source_registry import get_dataset_source_from_json
from mlflow.data.spark_dataset_source import SparkDatasetSource
from mlflow.exceptions import MlflowException


@pytest.fixture(scope="module")
def spark_session():
    """Module-scoped local Spark session with the Delta Lake extension enabled."""
    from pyspark.sql import SparkSession

    with (
        SparkSession.builder.master("local[*]")
        .config("spark.jars.packages", "io.delta:delta-spark_2.12:3.0.0")
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
        )
        .getOrCreate()
    ) as session:
        yield session


def _assert_json_round_trip(source: SparkDatasetSource) -> None:
    """Serialize ``source`` to JSON, reload it through the dataset source registry,
    and assert the reloaded source is equivalent to the original."""
    reloaded = get_dataset_source_from_json(
        source.to_json(), source_type=source._get_source_type()
    )
    assert isinstance(reloaded, SparkDatasetSource)
    # Exact-type identity check (idiomatic `is` instead of `type(a) == type(b)`)
    # preserves the original tests' guarantee that no subclass sneaks through.
    assert type(reloaded) is type(source)
    assert reloaded.to_json() == source.to_json()


def test_spark_dataset_source_from_path(spark_session, tmp_path):
    """A path-based source serializes its path and loads the written parquet data."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    spark_datasource = SparkDatasetSource(path=path)
    assert spark_datasource.to_json() == json.dumps({"path": path})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    _assert_json_round_trip(spark_datasource)


def test_spark_dataset_source_from_table(spark_session, tmp_path):
    """A table-based source serializes its table name and loads the saved table."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    # Spark's DataFrameWriter expects a string path, not a pathlib.Path.
    df_spark.write.mode("overwrite").saveAsTable("temp", path=str(tmp_path))

    spark_datasource = SparkDatasetSource(table_name="temp")
    assert spark_datasource.to_json() == json.dumps({"table_name": "temp"})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    _assert_json_round_trip(spark_datasource)


def test_spark_dataset_source_from_sql(spark_session, tmp_path):
    """A SQL-based source serializes its query and loads the query result."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    # Spark's DataFrameWriter expects a string path, not a pathlib.Path.
    df_spark.write.mode("overwrite").saveAsTable("temp_sql", path=str(tmp_path))

    spark_datasource = SparkDatasetSource(sql="SELECT * FROM temp_sql")
    assert spark_datasource.to_json() == json.dumps({"sql": "SELECT * FROM temp_sql"})
    loaded_df_spark = spark_datasource.load()
    assert loaded_df_spark.count() == df_spark.count()

    _assert_json_round_trip(spark_datasource)


def test_spark_dataset_source_too_many_inputs(spark_session, tmp_path):
    """Constructing a source with more than one of path/table_name/sql must raise."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.mode("overwrite").saveAsTable("temp", path=str(tmp_path))

    with pytest.raises(
        MlflowException, match='Must specify exactly one of "path", "table_name", or "sql"'
    ):
        SparkDatasetSource(path=str(tmp_path), table_name="temp")