test_spark_dataset.py
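"""Tests for constructing, serializing, and loading ``mlflow.data`` Spark and Delta datasets."""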
import json
import os
from typing import TYPE_CHECKING, Any

import pandas as pd
import pytest
from packaging.version import Version

import mlflow.data
from mlflow.data.code_dataset_source import CodeDatasetSource
from mlflow.data.delta_dataset_source import DeltaDatasetSource
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.spark_dataset import SparkDataset
from mlflow.data.spark_dataset_source import SparkDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.types.schema import Schema
from mlflow.types.utils import _infer_schema

if TYPE_CHECKING:
    from pyspark.sql import SparkSession


@pytest.fixture(scope="module")
def spark_session(tmp_path_factory: pytest.TempPathFactory):
    import pyspark
    from pyspark.sql import SparkSession

    # Select the delta-spark artifact that matches the installed PySpark: Spark 4.x
    # builds against Scala 2.13, while Spark 3.x builds against Scala 2.12.
    pyspark_version = Version(pyspark.__version__)
    if pyspark_version.major >= 4:
        delta_package = "io.delta:delta-spark_2.13:4.0.0"
    else:
        delta_package = "io.delta:delta-spark_2.12:3.0.0"

    tmp_dir = tmp_path_factory.mktemp("spark_tmp")
    with (
        SparkSession.builder
        .master("local[*]")
        .config("spark.jars.packages", delta_package)
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
        .config(
            "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
        )
        .config("spark.sql.warehouse.dir", str(tmp_dir))
        .getOrCreate()
    ) as session:
        yield session


@pytest.fixture(autouse=True)
def drop_tables(spark_session: "SparkSession"):
    yield
    for row in spark_session.sql("SHOW TABLES").collect():
        spark_session.sql(f"DROP TABLE IF EXISTS {row.tableName}")


@pytest.fixture
def df():
    return pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])


def _assert_dataframes_equal(df1, df2):
    assert df1.schema == df2.schema
    # exceptAll returns the rows of df1 that are missing from df2 (counting
    # duplicates); an empty result means every row of df1 is present in df2.
    assert df1.exceptAll(df2).rdd.isEmpty()


def _validate_profile_approx_count(parsed_json: dict[str, Any]) -> None:
    """Validate approx_count in profile data, handling platform/version differences."""
    # On Windows with certain PySpark versions, Spark datasets may return "unknown"
    # for approx_count instead of the actual count. We verify that the profile is
    # valid JSON and contains the expected key, but do not assert on the exact value.
    profile_data = json.loads(parsed_json["profile"])
    assert "approx_count" in profile_data
    assert profile_data["approx_count"] in [1, 2, "unknown"]
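
# NB: `_check_spark_dataset` below is the shared assertion helper for the
# `from_spark` / `load_delta` tests in this suite. A minimal usage sketch
# (mirroring the tests that follow; the names here are illustrative):
#
#   df_spark = spark_session.createDataFrame(df)
#   dataset = mlflow.data.from_spark(df_spark, path=path)
#   _check_spark_dataset(dataset, df, df_spark, SparkDatasetSource)
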
def _check_spark_dataset(dataset, original_df, df_spark, expected_source_type, expected_name=None):
    assert isinstance(dataset, SparkDataset)
    _assert_dataframes_equal(dataset.df, df_spark)
    assert dataset.schema == _infer_schema(original_df)
    assert isinstance(dataset.profile, dict)
    approx_count = dataset.profile.get("approx_count")
    assert isinstance(approx_count, int) or approx_count == "unknown"
    assert isinstance(dataset.source, expected_source_type)
    # NB: In real-world scenarios, Spark dataset sources may not match Spark DataFrames
    # precisely. For example, users may transform Spark DataFrames after loading contents
    # from source files. To ensure that source loading works properly for the purposes of
    # the test cases in this suite, we require the source to match the DataFrame and make
    # the following equality assertion.
    _assert_dataframes_equal(dataset.source.load(), df_spark)
    if expected_name is not None:
        assert dataset.name == expected_name


def test_conversion_to_json_spark_dataset_source(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )

    dataset_json = dataset.to_json()
    parsed_json = json.loads(dataset_json)
    assert parsed_json.keys() <= {"name", "digest", "source", "source_type", "schema", "profile"}
    assert parsed_json["name"] == dataset.name
    assert parsed_json["digest"] == dataset.digest
    assert parsed_json["source"] == dataset.source.to_json()
    assert parsed_json["source_type"] == dataset.source._get_source_type()
    _validate_profile_approx_count(parsed_json)

    schema_json = json.dumps(json.loads(parsed_json["schema"])["mlflow_colspec"])
    assert Schema.from_json(schema_json) == dataset.schema


def test_conversion_to_json_delta_dataset_source(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").save(path)

    source = DeltaDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )

    dataset_json = dataset.to_json()
    parsed_json = json.loads(dataset_json)
    assert parsed_json.keys() <= {"name", "digest", "source", "source_type", "schema", "profile"}
    assert parsed_json["name"] == dataset.name
    assert parsed_json["digest"] == dataset.digest
    assert parsed_json["source"] == dataset.source.to_json()
    assert parsed_json["source_type"] == dataset.source._get_source_type()
    _validate_profile_approx_count(parsed_json)

    schema_json = json.dumps(json.loads(parsed_json["schema"])["mlflow_colspec"])
    assert Schema.from_json(schema_json) == dataset.schema


def test_digest_property_has_expected_value(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset.digest == dataset._compute_digest()
    # Note that digests are stable within a session, but may not be stable across
    # sessions; hence, we do not check the exact digest value here.


def test_df_property_has_expected_value(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset.df == df_spark


def test_targets_property(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)
    dataset_no_targets = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset_no_targets.targets is None

    dataset_with_targets = SparkDataset(
        df=df_spark,
        source=source,
        targets="c",
        name="testname",
    )
    assert dataset_with_targets.targets == "c"

    with pytest.raises(
        MlflowException,
        match="The specified Spark dataset does not contain the specified targets column",
    ):
        SparkDataset(
            df=df_spark,
            source=source,
            targets="nonexistent",
            name="testname",
        )
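
# NB: `targets` and `predictions` are plain column names, so a dataset may
# specify both at once. A minimal sketch (assuming the same fixtures used by
# the surrounding tests):
#
#   dataset = SparkDataset(df=df_spark, source=source, targets="c", predictions="b")
#   assert dataset.targets == "c" and dataset.predictions == "b"
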
def test_predictions_property(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)
    dataset_no_predictions = SparkDataset(
        df=df_spark,
        source=source,
        name="testname",
    )
    assert dataset_no_predictions.predictions is None

    dataset_with_predictions = SparkDataset(
        df=df_spark,
        source=source,
        predictions="b",
        name="testname",
    )
    assert dataset_with_predictions.predictions == "b"

    with pytest.raises(
        MlflowException,
        match="The specified Spark dataset does not contain the specified predictions column",
    ):
        SparkDataset(
            df=df_spark,
            source=source,
            predictions="nonexistent",
            name="testname",
        )


def test_from_spark_no_source_specified(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    mlflow_df = mlflow.data.from_spark(df_spark)

    assert isinstance(mlflow_df, SparkDataset)

    assert isinstance(mlflow_df.source, CodeDatasetSource)
    assert "mlflow.source.name" in mlflow_df.source.to_json()


def test_from_spark_with_sql_and_version(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)
    with pytest.raises(
        MlflowException,
        match="`version` may not be specified when `sql` is specified. `version` may only be"
        " specified when `table_name` or `path` is specified.",
    ):
        mlflow.data.from_spark(df_spark, sql="SELECT * FROM table", version=1)
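
# NB: `from_spark` resolves the source type from its arguments: parquet paths,
# SQL queries, and non-Delta table names yield a SparkDatasetSource, while Delta
# paths and Delta table names yield a DeltaDatasetSource. The detection mechanism
# is internal to mlflow.data; the tests below only assert on the resulting type.
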
def test_from_spark_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    dir_path = str(tmp_path / "df_dir")
    df_spark.write.parquet(dir_path)
    assert os.path.isdir(dir_path)

    mlflow_df_from_dir = mlflow.data.from_spark(df_spark, path=dir_path)
    _check_spark_dataset(mlflow_df_from_dir, df, df_spark, SparkDatasetSource)

    file_path = str(tmp_path / "df.parquet")
    df_spark.toPandas().to_parquet(file_path)
    assert not os.path.isdir(file_path)

    mlflow_df_from_file = mlflow.data.from_spark(df_spark, path=file_path)
    _check_spark_dataset(mlflow_df_from_file, df, df_spark, SparkDatasetSource)


def test_from_spark_delta_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").save(path)

    mlflow_df = mlflow.data.from_spark(df_spark, path=path)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_from_spark_sql(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("table")

    mlflow_df = mlflow.data.from_spark(df_spark, sql="SELECT * FROM table")

    _check_spark_dataset(mlflow_df, df, df_spark, SparkDatasetSource)


def test_from_spark_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("my_spark_table")

    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_spark_table")

    _check_spark_dataset(mlflow_df, df, df_spark, SparkDatasetSource)


def test_from_spark_table_name_with_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.createOrReplaceTempView("my_spark_table")

    with pytest.raises(
        MlflowException,
        match="Version '1' was specified, but could not find a Delta table "
        "with name 'my_spark_table'",
    ):
        mlflow.data.from_spark(df_spark, table_name="my_spark_table", version=1)


def test_from_spark_delta_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_delta_table")

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_from_spark_delta_table_name_and_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

    mlflow_df = mlflow.data.from_spark(df_spark, table_name="my_delta_table", version=0)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_load_delta_with_no_source_info():
    with pytest.raises(
        MlflowException,
        match="Must specify exactly one of `table_name` or `path`.",
    ):
        mlflow.data.load_delta()


def test_load_delta_with_both_table_name_and_path():
    with pytest.raises(
        MlflowException,
        match="Must specify exactly one of `table_name` or `path`.",
    ):
        mlflow.data.load_delta(table_name="my_table", path="my_path")
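
# NB: Delta assigns a monotonically increasing version number to every commit.
# In the versioned tests below, the first `overwrite` write creates version 0
# and the second creates version 1; `load_delta(..., version=...)` reads a
# specific version back via Delta time travel.
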
def test_load_delta_path(spark_session, tmp_path, df):
    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.delta")
    df_spark.write.format("delta").mode("overwrite").save(path)

    mlflow_df = mlflow.data.load_delta(path=path)

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource)


def test_load_delta_path_with_version(spark_session, tmp_path, df):
    path = str(tmp_path / "temp.delta")

    df_v0 = pd.DataFrame([[4, 5, 6], [4, 5, 6]], columns=["a", "b", "c"])
    assert not df_v0.equals(df)
    df_v0_spark = spark_session.createDataFrame(df_v0)
    df_v0_spark.write.format("delta").mode("overwrite").save(path)

    # write again to create a new version
    df_v1_spark = spark_session.createDataFrame(df)
    df_v1_spark.write.format("delta").mode("overwrite").save(path)

    mlflow_df = mlflow.data.load_delta(path=path, version=1)
    _check_spark_dataset(mlflow_df, df, df_v1_spark, DeltaDatasetSource)


def test_load_delta_table_name(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    # write to delta table
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table")

    mlflow_df = mlflow.data.load_delta(table_name="my_delta_table")

    _check_spark_dataset(mlflow_df, df, df_spark, DeltaDatasetSource, "my_delta_table@v0")


def test_load_delta_table_name_with_version(spark_session, df):
    df_spark = spark_session.createDataFrame(df)
    df_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table_versioned")

    df2 = pd.DataFrame([[4, 5, 6], [4, 5, 6]], columns=["a", "b", "c"])
    assert not df2.equals(df)
    df2_spark = spark_session.createDataFrame(df2)
    df2_spark.write.format("delta").mode("overwrite").saveAsTable("my_delta_table_versioned")

    mlflow_df = mlflow.data.load_delta(table_name="my_delta_table_versioned", version=1)

    _check_spark_dataset(
        mlflow_df, df2, df2_spark, DeltaDatasetSource, "my_delta_table_versioned@v1"
    )
    pd.testing.assert_frame_equal(mlflow_df.df.toPandas(), df2)


def test_to_evaluation_dataset(spark_session, tmp_path, df):
    import numpy as np

    df_spark = spark_session.createDataFrame(df)
    path = str(tmp_path / "temp.parquet")
    df_spark.write.parquet(path)

    source = SparkDatasetSource(path=path)

    dataset = SparkDataset(
        df=df_spark,
        source=source,
        targets="c",
        name="testname",
        predictions="b",
    )
    evaluation_dataset = dataset.to_evaluation_dataset()
    assert isinstance(evaluation_dataset, EvaluationDataset)
    # The targets ("c") and predictions ("b") columns are split out into
    # labels_data and predictions_data, leaving "a" as the feature data.
    assert evaluation_dataset.features_data.equals(df_spark.toPandas()[["a"]])
    assert np.array_equal(evaluation_dataset.labels_data, df_spark.toPandas()["c"].values)
    assert np.array_equal(evaluation_dataset.predictions_data, df_spark.toPandas()["b"].values)