# tests/spark/test_spark_connect_model_export.py
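"""Tests for saving, logging, loading, and serving ``pyspark.ml.connect`` models
via ``mlflow.spark`` against a local Spark Connect session."""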
import json
import os
from unittest import mock

import numpy as np
import pandas as pd
import pyspark
import pytest
from packaging.version import Version
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import LongType
from sklearn import datasets

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow import pyfunc
from mlflow.pyfunc import spark_udf

from tests.helper_functions import pyfunc_serve_and_score_model
from tests.spark.test_spark_model_export import SparkModelWithData

PYSPARK_VERSION = Version(pyspark.__version__)


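# Helper for building a local Spark Connect session. For Spark 3.x the spark-connect
# artifact has to be pulled in via ``spark.jars.packages``; the version check below
# skips that for Spark 4+ and dev builds.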
def _get_spark_connect_session():
    builder = SparkSession.builder.remote("local[2]").config(
        "spark.connect.copyFromLocalToFs.allowDestLocal", "true"
    )
    if not PYSPARK_VERSION.is_devrelease and PYSPARK_VERSION.major < 4:
        builder.config(
            "spark.jars.packages", f"org.apache.spark:spark-connect_2.12:{pyspark.__version__}"
        )
    return builder.getOrCreate()


@pytest.fixture
def model_path(tmp_path):
    return os.path.join(tmp_path, "model")


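# UDF-scoring helper tailored to Spark Connect: it reuses the active session and wraps
# the "features" column in a struct so the pyfunc UDF receives a single struct argument.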
def score_model_as_udf(model_uri, pandas_df, result_type):
    spark = SparkSession.getActiveSession()
    spark_df = spark.createDataFrame(pandas_df).coalesce(1)
    pyfunc_udf = spark_udf(
        spark=spark, model_uri=model_uri, result_type=result_type, env_manager="local"
    )
    new_df = spark_df.withColumn("prediction", pyfunc_udf(F.struct(F.col("features"))))
    return new_df.toPandas()["prediction"]


@pytest.fixture(scope="module")
def spark():
    spark = _get_spark_connect_session()
    yield spark
    spark.stop()


@pytest.fixture(scope="module")
def iris_df(spark):
    X, y = datasets.load_iris(return_X_y=True)
    spark_df = spark.createDataFrame(zip(X, y), schema="features: array<double>, label: long")
    return spark_df.toPandas(), spark_df


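# Fits a small pyspark.ml.connect pipeline (StandardScaler followed by LogisticRegression)
# on the iris data and keeps the pandas-scored predictions for comparison in the tests below.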
@pytest.fixture(scope="module")
def spark_model(iris_df):
    from pyspark.ml.connect.classification import LogisticRegression
    from pyspark.ml.connect.feature import StandardScaler
    from pyspark.ml.connect.pipeline import Pipeline

    iris_pandas_df, iris_spark_df = iris_df
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
    lr = LogisticRegression(maxIter=10, numTrainWorkers=2, learningRate=0.001)
    pipeline = Pipeline(stages=[scaler, lr])
    # Fit the model
    model = pipeline.fit(iris_spark_df)
    preds_pandas_df = model.transform(iris_pandas_df.copy(deep=False))
    return SparkModelWithData(
        model=model,
        spark_df=None,
        pandas_df=iris_pandas_df,
        predictions=preds_pandas_df,
    )


def test_model_export(spark_model, model_path):
    mlflow.spark.save_model(spark_model.model, path=model_path)
    # 1. score and compare reloaded sparkml model
    reloaded_model = mlflow.spark.load_model(model_uri=model_path)
    preds_df = reloaded_model.transform(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_frame_equal(spark_model.predictions, preds_df, check_dtype=False)

    # 2. score and compare reloaded pyfunc
    m = pyfunc.load_model(model_path)
    preds2 = m.predict(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_series_equal(
        spark_model.predictions["prediction"], preds2, check_dtype=False
    )

    # 3. score and compare reloaded pyfunc Spark udf
    preds3 = score_model_as_udf(
        model_uri=model_path, pandas_df=spark_model.pandas_df, result_type=LongType()
    )
    pd.testing.assert_series_equal(
        spark_model.predictions["prediction"], preds3, check_dtype=False
    )


def test_sparkml_model_log(spark_model):
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model.model,
            artifact_path="model",
        )
    model_uri = model_info.model_uri

    reloaded_model = mlflow.spark.load_model(model_uri=model_uri)
    preds_df = reloaded_model.transform(spark_model.pandas_df.copy(deep=False))
    pd.testing.assert_frame_equal(spark_model.predictions, preds_df, check_dtype=False)


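# Serves the logged model with the local pyfunc scoring server and checks the returned
# predictions against in-process transform() output.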
def test_pyfunc_serve_and_score(spark_model):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(spark_model.model, artifact_path=artifact_path)

    input_data = pd.DataFrame({"features": spark_model.pandas_df["features"].map(list)})
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=input_data,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=["--env-manager", "local"],
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(
        scores, spark_model.model.transform(spark_model.pandas_df)["prediction"].values
    )


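# Patches the Databricks runtime-detection helpers so the UC Volumes save/load path is
# exercised for both the serverless and shared-cluster runtime flavors without an actual
# Databricks environment.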
def test_databricks_serverless_model_save_load(spark_model):
    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_runtime", return_value=True),
        mock.patch("mlflow.spark._is_uc_volume_uri", return_value=True),
    ):
        for mock_fun in [
            "is_in_databricks_serverless_runtime",
            "is_in_databricks_shared_cluster_runtime",
        ]:
            with mock.patch(f"mlflow.utils.databricks_utils.{mock_fun}", return_value=True):
                artifact_path = "model"
                with mlflow.start_run():
                    model_info = mlflow.spark.log_model(
                        spark_model.model, artifact_path=artifact_path
                    )

                mlflow.spark.load_model(model_info.model_uri)