test_spark_connect_model_export.py
import json
import os
from unittest import mock

import numpy as np
import pandas as pd
import pyspark
import pytest
from packaging.version import Version
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import LongType
from sklearn import datasets

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow import pyfunc
from mlflow.pyfunc import spark_udf

from tests.helper_functions import pyfunc_serve_and_score_model
from tests.spark.test_spark_model_export import SparkModelWithData

PYSPARK_VERSION = Version(pyspark.__version__)


def _get_spark_connect_session():
    builder = SparkSession.builder.remote("local[2]").config(
        "spark.connect.copyFromLocalToFs.allowDestLocal", "true"
    )
    if not PYSPARK_VERSION.is_devrelease and PYSPARK_VERSION.major < 4:
        # Pre-4.0 PySpark releases do not bundle the Spark Connect server,
        # so pull in the matching spark-connect package explicitly.
        builder.config(
            "spark.jars.packages",
            f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}",
        )
    return builder.getOrCreate()


@pytest.fixture
def model_path(tmp_path):
    return os.path.join(tmp_path, "model")


def score_model_as_udf(model_uri, pandas_df, result_type):
    spark = SparkSession.getActiveSession()
    spark_df = spark.createDataFrame(pandas_df).coalesce(1)
    pyfunc_udf = spark_udf(
        spark=spark, model_uri=model_uri, result_type=result_type, env_manager="local"
    )
    new_df = spark_df.withColumn("prediction", pyfunc_udf(F.struct(F.col("features"))))
    return new_df.toPandas()["prediction"]


@pytest.fixture(scope="module")
def spark():
    spark = _get_spark_connect_session()
    yield spark
    spark.stop()


@pytest.fixture(scope="module")
def iris_df(spark):
    X, y = datasets.load_iris(return_X_y=True)
    spark_df = spark.createDataFrame(zip(X, y), schema="features: array<double>, label: long")
    return spark_df.toPandas(), spark_df


@pytest.fixture(scope="module")
def spark_model(iris_df):
    from pyspark.ml.connect.classification import LogisticRegression
    from pyspark.ml.connect.feature import StandardScaler
    from pyspark.ml.connect.pipeline import Pipeline

    iris_pandas_df, iris_spark_df = iris_df
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
    lr = LogisticRegression(maxIter=10, numTrainWorkers=2, learningRate=0.001)
    pipeline = Pipeline(stages=[scaler, lr])
    # Fit the pipeline and precompute predictions on the pandas copy so tests
    # can compare reloaded models against them.
    model = pipeline.fit(iris_spark_df)
    preds_pandas_df = model.transform(iris_pandas_df.copy(deep=False))
    return SparkModelWithData(
        model=model,
        spark_df=None,
        pandas_df=iris_pandas_df,
        predictions=preds_pandas_df,
    )


def test_model_export(spark_model, model_path):
    mlflow.spark.save_model(spark_model.model, path=model_path)
    # 1. score and compare the reloaded sparkml model
    reloaded_model = mlflow.spark.load_model(model_uri=model_path)
    preds_df = reloaded_model.transform(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_frame_equal(spark_model.predictions, preds_df, check_dtype=False)

    # 2. score and compare the reloaded pyfunc model
    m = pyfunc.load_model(model_path)
    preds2 = m.predict(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_series_equal(
        spark_model.predictions["prediction"], preds2, check_dtype=False
    )

    # 3. score and compare the reloaded pyfunc Spark UDF
    preds3 = score_model_as_udf(
        model_uri=model_path, pandas_df=spark_model.pandas_df, result_type=LongType()
    )
    pd.testing.assert_series_equal(
        spark_model.predictions["prediction"], preds3, check_dtype=False
    )


def test_sparkml_model_log(spark_model):
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model.model,
            artifact_path="model",
        )
        model_uri = model_info.model_uri

    reloaded_model = mlflow.spark.load_model(model_uri=model_uri)
    preds_df = reloaded_model.transform(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_frame_equal(spark_model.predictions, preds_df, check_dtype=False)


def test_pyfunc_serve_and_score(spark_model):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(spark_model.model, artifact_path=artifact_path)

    input_data = pd.DataFrame({"features": spark_model.pandas_df["features"].map(list)})
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=input_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(
        scores, spark_model.model.transform(spark_model.pandas_df)["prediction"].values
    )


def test_databricks_serverless_model_save_load(spark_model):
    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_runtime", return_value=True),
        mock.patch("mlflow.spark._is_uc_volume_uri", return_value=True),
    ):
        for mock_fun in [
            "is_in_databricks_serverless_runtime",
            "is_in_databricks_shared_cluster_runtime",
        ]:
            with mock.patch(f"mlflow.utils.databricks_utils.{mock_fun}", return_value=True):
                artifact_path = "model"
                with mlflow.start_run():
                    model_info = mlflow.spark.log_model(
                        spark_model.model, artifact_path=artifact_path
                    )

                mlflow.spark.load_model(model_info.model_uri)