tests/spark/test_spark_model_export.py
   1  import inspect
   2  import json
   3  import logging
   4  import os
   5  from pathlib import Path
   6  from typing import Any, NamedTuple
   7  from unittest import mock
   8  
   9  import numpy as np
  10  import pandas as pd
  11  import pyspark
  12  import pytest
  13  import yaml
  14  from packaging.version import Version
  15  from pyspark.ml.classification import LogisticRegression
  16  from pyspark.ml.feature import VectorAssembler
  17  from pyspark.ml.pipeline import Pipeline
  18  from sklearn import datasets
  19  
  20  import mlflow
  21  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
  22  import mlflow.tracking
  23  import mlflow.utils.file_utils
  24  from mlflow import pyfunc
  25  from mlflow.entities.model_registry import ModelVersion
  26  from mlflow.environment_variables import MLFLOW_DFS_TMP
  27  from mlflow.exceptions import MlflowException
  28  from mlflow.models import Model, ModelSignature
  29  from mlflow.models.utils import _read_example
  30  from mlflow.spark import _add_code_from_conf_to_system_path
  31  from mlflow.store.artifact.databricks_models_artifact_repo import DatabricksModelsArtifactRepository
  32  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
  33  from mlflow.store.artifact.unity_catalog_models_artifact_repo import (
  34      UnityCatalogModelsArtifactRepository,
  35  )
  36  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  37  from mlflow.types import DataType
  38  from mlflow.types.schema import ColSpec, Schema
  39  from mlflow.utils.environment import _get_pip_deps, _mlflow_conda_env
  40  from mlflow.utils.file_utils import TempDir
  41  from mlflow.utils.model_utils import _get_flavor_configuration
  42  
  43  from tests.helper_functions import (
  44      _assert_pip_requirements,
  45      _compare_conda_env_requirements,
  46      _compare_logged_code_paths,
  47      _mlflow_major_version_string,
  48      assert_register_model_called_with_local_model_path,
  49      score_model_in_sagemaker_docker_container,
  50  )
  51  from tests.pyfunc.test_spark import get_spark_session, score_model_as_udf
  52  from tests.store.artifact.constants import MODELS_ARTIFACT_REPOSITORY
  53  
  54  _logger = logging.getLogger(__name__)
  55  
  56  PYSPARK_VERSION = Version(pyspark.__version__)
  57  
  58  
  59  @pytest.fixture
  60  def spark_custom_env(tmp_path):
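           """Conda env YAML whose pip deps pin mlflow from /opt/mlflow, the running pyspark
           version, and pytest; for PySpark < 3.4 it also pins pandas < 2 and numpy < 2.
           """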
  61      conda_env = os.path.join(tmp_path, "conda_env.yml")
  62      additional_pip_deps = ["/opt/mlflow", f"pyspark=={PYSPARK_VERSION}", "pytest"]
  63      if PYSPARK_VERSION < Version("3.4"):
  64          additional_pip_deps.extend([
  65              # Versions of PySpark < 3.4 are incompatible with pandas >= 2
  66              "pandas<2",
  67              # pandas<2.0 is incompatible with numpy>=2.0
  68              "numpy<2.0",
  69          ])
  70      _mlflow_conda_env(conda_env, additional_pip_deps=additional_pip_deps)
  71      return conda_env
  72  
  73  
  74  class SparkModelWithData(NamedTuple):
  75      model: Any
  76      spark_df: Any
  77      pandas_df: Any
  78      predictions: Any
  79  
  80  
  81  def _get_spark_session_with_retry(max_tries=3):
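           """Create a SparkSession via get_spark_session, retrying up to ``max_tries`` times
           on transient startup failures before re-raising the last error.
           """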
  82      conf = pyspark.SparkConf()
  83      for attempt in range(max_tries):
  84          try:
  85              return get_spark_session(conf)
  86          except Exception as e:
  87              if attempt >= max_tries - 1:
  88                  raise
  89              _logger.exception(
   90                  f"Attempt {attempt + 1} to create a SparkSession failed ({e!r}), retrying..."
  91              )
  92  
  93  
  94  # Specify `autouse=True` to ensure that a context is created
  95  # before any tests are executed. This ensures that the Hadoop filesystem
  96  # does not create its own SparkContext.
   97  @pytest.fixture(scope="module", autouse=True)
  98  def spark():
   99      if PYSPARK_VERSION < Version("3.1"):
 100          # A workaround for this issue:
 101          # https://stackoverflow.com/questions/62109276/errorjava-lang-unsupportedoperationexception-for-pyspark-pandas-udf-documenta
 102          spark_home = (
 103              os.environ.get("SPARK_HOME")
 104              if "SPARK_HOME" in os.environ
 105              else os.path.dirname(pyspark.__file__)
 106          )
 107          conf_dir = os.path.join(spark_home, "conf")
 108          os.makedirs(conf_dir, exist_ok=True)
 109          with open(os.path.join(conf_dir, "spark-defaults.conf"), "w") as f:
 110              conf = """
 111  spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"
 112  spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"
 113  """
 114              f.write(conf)
 115  
 116      with _get_spark_session_with_retry() as spark:
 117          yield spark
 118  
 119  
 120  def iris_pandas_df():
 121      iris = datasets.load_iris()
 122      X = iris.data
 123      y = iris.target
 124      feature_names = ["0", "1", "2", "3"]
 125      df = pd.DataFrame(X, columns=feature_names)  # to make spark_udf work
 126      df["label"] = pd.Series(y)
 127      return df
 128  
 129  
 130  @pytest.fixture(scope="module")
 131  def iris_df(spark):
 132      pdf = iris_pandas_df()
 133      feature_names = list(pdf.drop("label", axis=1).columns)
 134      iris_spark_df = spark.createDataFrame(pdf)
 135      return feature_names, pdf, iris_spark_df
 136  
 137  
 138  @pytest.fixture(scope="module")
 139  def iris_signature():
 140      return ModelSignature(
 141          inputs=Schema([
 142              ColSpec(name="0", type=DataType.double),
 143              ColSpec(name="1", type=DataType.double),
 144              ColSpec(name="2", type=DataType.double),
 145              ColSpec(name="3", type=DataType.double),
 146          ]),
 147          outputs=Schema([ColSpec(type=DataType.double)]),
 148      )
 149  
 150  
 151  @pytest.fixture(scope="module")
 152  def spark_model_iris(iris_df):
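           """VectorAssembler + LogisticRegression pipeline fit on the iris data, bundled with
           its Spark/pandas training DataFrames and reference predictions.
           """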
 153      feature_names, iris_pandas_df, iris_spark_df = iris_df
 154      assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
 155      lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
 156      pipeline = Pipeline(stages=[assembler, lr])
 157      # Fit the model
 158      model = pipeline.fit(iris_spark_df)
 159      preds_df = model.transform(iris_spark_df)
 160      preds = [x.prediction for x in preds_df.select("prediction").collect()]
 161      return SparkModelWithData(
 162          model=model, spark_df=iris_spark_df, pandas_df=iris_pandas_df, predictions=preds
 163      )
 164  
 165  
 166  @pytest.fixture(scope="module")
 167  def spark_model_transformer(iris_df):
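           """Plain VectorAssembler transformer (no estimator); ``predictions`` holds the
           assembled feature vectors for the iris data.
           """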
 168      feature_names, iris_pandas_df, iris_spark_df = iris_df
 169      assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
 170      # Fit the model
 171      preds_df = assembler.transform(iris_spark_df)
 172      preds = [x.features for x in preds_df.select("features").collect()]
 173      return SparkModelWithData(
 174          model=assembler, spark_df=iris_spark_df, pandas_df=iris_pandas_df, predictions=preds
 175      )
 176  
 177  
 178  @pytest.fixture(scope="module")
 179  def spark_model_estimator(iris_df):
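           """LogisticRegression model (no pipeline) fit on pre-assembled iris features, with
           its reference predictions.
           """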
 180      feature_names, iris_pandas_df, iris_spark_df = iris_df
 181      assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
 182      features_df = assembler.transform(iris_spark_df)
 183      lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
 184      # Fit the model
 185      model = lr.fit(features_df)
 186      preds_df = model.transform(features_df)
 187      preds = [x.prediction for x in preds_df.select("prediction").collect()]
 188      return SparkModelWithData(
 189          model=model, spark_df=features_df, pandas_df=iris_pandas_df, predictions=preds
 190      )
 191  
 192  
 193  @pytest.fixture
 194  def model_path(tmp_path):
 195      return os.path.join(tmp_path, "model")
 196  
 197  
 198  @pytest.mark.usefixtures("spark")
 199  def test_hadoop_filesystem(tmp_path):
 200      # copy local dir to and back from HadoopFS and make sure the results match
 201      from mlflow.spark import _HadoopFileSystem as FS
 202  
 203      test_dir_0 = os.path.join(tmp_path, "expected")
 204      test_file_0 = os.path.join(test_dir_0, "root", "file_0")
 205      test_dir_1 = os.path.join(test_dir_0, "root", "subdir")
 206      test_file_1 = os.path.join(test_dir_1, "file_1")
 207      os.makedirs(os.path.dirname(test_file_0))
 208      with open(test_file_0, "w") as f:
 209          f.write("test0")
 210      os.makedirs(os.path.dirname(test_file_1))
 211      with open(test_file_1, "w") as f:
 212          f.write("test1")
 213      remote = "/tmp/mlflow/test0"
 214      # File should not be copied in this case
 215      assert os.path.abspath(test_dir_0) == FS.maybe_copy_from_local_file(test_dir_0, remote)
 216      FS.copy_from_local_file(test_dir_0, remote, remove_src=False)
 217      local = os.path.join(tmp_path, "actual")
 218      FS.copy_to_local_file(remote, local, remove_src=True)
 219      assert sorted(os.listdir(os.path.join(local, "root"))) == sorted([
 220          "subdir",
 221          "file_0",
 222          ".file_0.crc",
 223      ])
 224      assert sorted(os.listdir(os.path.join(local, "root", "subdir"))) == sorted([
 225          "file_1",
 226          ".file_1.crc",
 227      ])
 228      # compare the files
 229      with open(os.path.join(test_dir_0, "root", "file_0")) as expected_f:
 230          with open(os.path.join(local, "root", "file_0")) as actual_f:
 231              assert expected_f.read() == actual_f.read()
 232      with open(os.path.join(test_dir_0, "root", "subdir", "file_1")) as expected_f:
 233          with open(os.path.join(local, "root", "subdir", "file_1")) as actual_f:
 234              assert expected_f.read() == actual_f.read()
 235  
 236      # make sure we cleanup
 237      assert not os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix
 238      FS.copy_from_local_file(test_dir_0, remote, remove_src=False)
 239      assert os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix
 240      FS.delete(remote)
 241      assert not os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix
 242  
 243  
 244  def test_model_export(spark_model_iris, model_path, spark_custom_env):
 245      mlflow.spark.save_model(spark_model_iris.model, path=model_path, conda_env=spark_custom_env)
 246      # 1. score and compare reloaded sparkml model
 247      reloaded_model = mlflow.spark.load_model(model_uri=model_path)
 248      preds_df = reloaded_model.transform(spark_model_iris.spark_df)
 249      preds1 = [x.prediction for x in preds_df.select("prediction").collect()]
 250      assert spark_model_iris.predictions == preds1
 251      m = pyfunc.load_model(model_path)
 252      # 2. score and compare reloaded pyfunc
 253      preds2 = m.predict(spark_model_iris.pandas_df)
 254      assert spark_model_iris.predictions == preds2
 255      # 3. score and compare reloaded pyfunc Spark udf
 256      preds3 = score_model_as_udf(model_uri=model_path, pandas_df=spark_model_iris.pandas_df)
 257      assert spark_model_iris.predictions == preds3
 258      assert os.path.exists(MLFLOW_DFS_TMP.get())
 259  
 260  
 261  def test_model_export_with_signature_and_examples(spark_model_iris, iris_signature):
 262      features_df = spark_model_iris.pandas_df.drop("label", axis=1)
 263      example_ = features_df.head(3)
 264      for signature in (None, iris_signature):
 265          for example in (None, example_):
 266              with TempDir() as tmp:
 267                  path = tmp.path("model")
 268                  mlflow.spark.save_model(
 269                      spark_model_iris.model, path=path, signature=signature, input_example=example
 270                  )
 271                  mlflow_model = Model.load(path)
 272                  if example is None and signature is None:
 273                      assert mlflow_model.signature is None
 274                  else:
 275                      assert mlflow_model.signature == iris_signature
 276                  if example is None:
 277                      assert mlflow_model.saved_input_example_info is None
 278                  else:
 279                      assert all((_read_example(mlflow_model, path) == example).all())
 280  
 281  
 282  def test_model_export_raise_when_example_is_spark_dataframe(spark, spark_model_iris, model_path):
 283      features_df = spark_model_iris.pandas_df.drop("label", axis=1)
 284      example = spark.createDataFrame(features_df.head(3))
 285      with pytest.raises(MlflowException, match="Examples can not be provided as Spark Dataframe."):
 286          mlflow.spark.save_model(spark_model_iris.model, path=model_path, input_example=example)
 287  
 288  
 289  def test_log_model_with_signature_and_examples(spark_model_iris, iris_signature):
 290      features_df = spark_model_iris.pandas_df.drop("label", axis=1)
 291      example_ = features_df.head(3)
 292      artifact_path = "model"
 293      for signature in (None, iris_signature):
 294          for example in (None, example_):
 295              with mlflow.start_run():
 296                  model_info = mlflow.spark.log_model(
 297                      spark_model_iris.model,
 298                      artifact_path=artifact_path,
 299                      signature=signature,
 300                      input_example=example,
 301                  )
 302                  mlflow_model = Model.load(model_info.model_uri)
 303                  if example is None and signature is None:
 304                      assert mlflow_model.signature is None
 305                  else:
 306                      assert mlflow_model.signature == iris_signature
 307                  if example is None:
 308                      assert mlflow_model.saved_input_example_info is None
 309                  else:
 310                      assert all((_read_example(mlflow_model, model_info.model_uri) == example).all())
 311  
 312  
 313  def test_estimator_model_export(spark_model_estimator, model_path, spark_custom_env):
 314      mlflow.spark.save_model(
 315          spark_model_estimator.model, path=model_path, conda_env=spark_custom_env
 316      )
 317      # score and compare the reloaded sparkml model
 318      reloaded_model = mlflow.spark.load_model(model_uri=model_path)
 319      preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
 320      preds = [x.prediction for x in preds_df.select("prediction").collect()]
 321      assert spark_model_estimator.predictions == preds
 322      # 2. score and compare reloaded pyfunc
 323      m = pyfunc.load_model(model_path)
 324      preds2 = m.predict(spark_model_estimator.spark_df.toPandas())
 325      assert spark_model_estimator.predictions == preds2
 326  
 327  
 328  def test_transformer_model_export(spark_model_transformer, model_path, spark_custom_env):
 329      mlflow.spark.save_model(
 330          spark_model_transformer.model, path=model_path, conda_env=spark_custom_env
 331      )
 332      # score and compare the reloaded sparkml model
 333      reloaded_model = mlflow.spark.load_model(model_uri=model_path)
 334      preds_df = reloaded_model.transform(spark_model_transformer.spark_df)
 335      preds = [x.features for x in preds_df.select("features").collect()]
 336      assert spark_model_transformer.predictions == preds
 337      # 2. score and compare reloaded pyfunc
 338      m = pyfunc.load_model(model_path)
 339      preds2 = m.predict(spark_model_transformer.spark_df.toPandas())
 340      assert spark_model_transformer.predictions == preds2
 341  
 342  
 343  @pytest.mark.skipif(
 344      PYSPARK_VERSION.is_devrelease, reason="this test does not support PySpark dev version."
 345  )
  346  def test_model_deployment(spark_model_iris, model_path, spark_custom_env):
 347      mlflow.spark.save_model(
 348          spark_model_iris.model,
 349          path=model_path,
 350          conda_env=spark_custom_env,
 351      )
 352      scoring_response = score_model_in_sagemaker_docker_container(
 353          model_uri=model_path,
 354          data=spark_model_iris.pandas_df,
 355          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 356          flavor=mlflow.pyfunc.FLAVOR_NAME,
 357      )
 358      from mlflow.deployments import PredictionsResponse
 359  
 360      np.testing.assert_array_almost_equal(
 361          spark_model_iris.predictions,
 362          PredictionsResponse.from_json(scoring_response.content).get_predictions(
 363              predictions_format="ndarray"
 364          ),
 365          decimal=4,
 366      )
 367  
 368  
 369  @pytest.mark.skipif(
 370      "dev" in pyspark.__version__,
  371      reason="The dev version of pyspark built from source doesn't exist on PyPI or Anaconda",
 372  )
 373  def test_sagemaker_docker_model_scoring_with_default_conda_env(spark_model_iris, model_path):
 374      mlflow.spark.save_model(
 375          spark_model_iris.model, path=model_path, extra_pip_requirements=["/opt/mlflow"]
 376      )
 377  
 378      scoring_response = score_model_in_sagemaker_docker_container(
 379          model_uri=model_path,
 380          data=spark_model_iris.pandas_df,
 381          content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
 382          flavor=mlflow.pyfunc.FLAVOR_NAME,
 383      )
 384      deployed_model_preds = np.array(json.loads(scoring_response.content)["predictions"])
 385  
 386      np.testing.assert_array_almost_equal(
 387          deployed_model_preds, spark_model_iris.predictions, decimal=4
 388      )
 389  
 390  
 391  @pytest.mark.parametrize("should_start_run", [False, True])
 392  @pytest.mark.parametrize("use_dfs_tmpdir", [False, True])
 393  def test_sparkml_model_log(tmp_path, spark_model_iris, should_start_run, use_dfs_tmpdir):
 394      old_tracking_uri = mlflow.get_tracking_uri()
 395      dfs_tmpdir = None if use_dfs_tmpdir else tmp_path.joinpath("test")
 396  
 397      try:
 398          tracking_dir = tmp_path.joinpath("mlruns")
 399          mlflow.set_tracking_uri(f"file://{tracking_dir}")
 400          if should_start_run:
 401              mlflow.start_run()
 402          artifact_path = "model"
 403          model_info = mlflow.spark.log_model(
 404              spark_model_iris.model,
 405              artifact_path=artifact_path,
 406              dfs_tmpdir=dfs_tmpdir,
 407          )
 408  
 409          reloaded_model = mlflow.spark.load_model(
 410              model_uri=model_info.model_uri, dfs_tmpdir=dfs_tmpdir
 411          )
 412          preds_df = reloaded_model.transform(spark_model_iris.spark_df)
 413          preds = [x.prediction for x in preds_df.select("prediction").collect()]
 414          assert spark_model_iris.predictions == preds
 415      finally:
 416          mlflow.end_run()
 417          mlflow.set_tracking_uri(old_tracking_uri)
 418  
 419  
 420  @pytest.mark.parametrize(
 421      ("registry_uri", "artifact_repo_class"),
 422      [
 423          ("databricks-uc", UnityCatalogModelsArtifactRepository),
 424          ("databricks", DatabricksModelsArtifactRepository),
 425      ],
 426  )
 427  def test_load_spark_model_from_models_uri(
 428      tmp_path, spark_model_estimator, registry_uri, artifact_repo_class
 429  ):
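           """Loading a models:/ URI, by version or by alias, should download the full model
           directory through the registry's artifact repository.
           """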
 430      model_dir = str(tmp_path.joinpath("spark_model"))
 431      model_name = "mycatalog.myschema.mymodel"
 432      fake_model_version = ModelVersion(name=model_name, version=str(3), creation_timestamp=0)
 433  
 434      with (
 435          mock.patch(f"{MODELS_ARTIFACT_REPOSITORY}.get_underlying_uri") as mock_get_underlying_uri,
 436          mock.patch.object(
 437              artifact_repo_class, "download_artifacts", return_value=model_dir
 438          ) as mock_download_artifacts,
 439          mock.patch("mlflow.get_registry_uri", return_value=registry_uri),
 440          mock.patch.object(
 441              mlflow.tracking._model_registry.client.ModelRegistryClient,
 442              "get_model_version_by_alias",
 443              return_value=fake_model_version,
 444          ) as get_model_version_by_alias_mock,
 445      ):
 446          mlflow.spark.save_model(
 447              path=model_dir,
 448              spark_model=spark_model_estimator.model,
 449          )
 450          mock_get_underlying_uri.return_value = "nonexistentscheme://fakeuri"
 451          mlflow.spark.load_model(f"models:/{model_name}/1")
 452          # Assert that we downloaded both the MLmodel file and the whole model itself using
 453          # the models:/ URI
 454          kwargs = (
 455              {"lineage_header_info": None}
 456              if artifact_repo_class is UnityCatalogModelsArtifactRepository
 457              else {}
 458          )
 459          mock_download_artifacts.assert_called_once_with("", None, **kwargs)
 460          mock_download_artifacts.reset_mock()
 461          mlflow.spark.load_model(f"models:/{model_name}@Champion")
 462          mock_download_artifacts.assert_called_once_with("", None, **kwargs)
  463          get_model_version_by_alias_mock.assert_called_with(model_name, "Champion")
 464  
 465  
 466  @pytest.mark.parametrize("should_start_run", [False, True])
 467  @pytest.mark.parametrize("use_dfs_tmpdir", [False, True])
 468  def test_sparkml_estimator_model_log(
 469      tmp_path, spark_model_estimator, should_start_run, use_dfs_tmpdir
 470  ):
 471      old_tracking_uri = mlflow.get_tracking_uri()
 472      dfs_tmpdir = None if use_dfs_tmpdir else tmp_path.joinpath("test")
 473  
 474      try:
 475          tracking_dir = tmp_path.joinpath("mlruns")
 476          mlflow.set_tracking_uri(f"file://{tracking_dir}")
 477          if should_start_run:
 478              mlflow.start_run()
 479          artifact_path = "model"
 480          model_info = mlflow.spark.log_model(
 481              spark_model_estimator.model,
 482              artifact_path=artifact_path,
 483              dfs_tmpdir=dfs_tmpdir,
 484          )
 485  
 486          reloaded_model = mlflow.spark.load_model(
 487              model_uri=model_info.model_uri, dfs_tmpdir=dfs_tmpdir
 488          )
 489          preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
 490          preds = [x.prediction for x in preds_df.select("prediction").collect()]
 491          assert spark_model_estimator.predictions == preds
 492      finally:
 493          mlflow.end_run()
 494          mlflow.set_tracking_uri(old_tracking_uri)
 495  
 496  
 497  def test_log_model_calls_register_model(tmp_path, spark_model_iris):
 498      artifact_path = "model"
 499      dfs_tmp_dir = tmp_path.joinpath("test")
 500      register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
 501      with mlflow.start_run(), register_model_patch:
 502          model_info = mlflow.spark.log_model(
 503              spark_model_iris.model,
 504              artifact_path=artifact_path,
 505              dfs_tmpdir=dfs_tmp_dir,
 506              registered_model_name="AdsModel1",
 507          )
 508          assert_register_model_called_with_local_model_path(
 509              register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
 510              model_uri=model_info.model_uri,
 511              registered_model_name="AdsModel1",
 512          )
 513  
 514  
 515  def test_log_model_no_registered_model_name(tmp_path, spark_model_iris):
 516      artifact_path = "model"
 517      dfs_tmp_dir = os.path.join(tmp_path, "test")
 518      register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
 519      with mlflow.start_run(), register_model_patch:
 520          mlflow.spark.log_model(
 521              spark_model_iris.model,
 522              artifact_path=artifact_path,
 523              dfs_tmpdir=dfs_tmp_dir,
 524          )
 525          mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
 526  
 527  
 528  def test_log_model_skips_maybe_save_for_acled_artifact_uri(tmp_path):
 529      """_maybe_save_model should not be called for Databricks ACL-protected artifact URIs
 530      (dbfs:/databricks/mlflow-tracking/...) since Spark cannot write to them directly.
 531      Calling it wastes ~6s per model on a guaranteed Py4JError before falling back.
 532      """
 533      acled_uri = "dbfs:/databricks/mlflow-tracking/abc123/run456/artifacts"
 534  
 535      class FakePipelineModel:
 536          def __init__(self, stages=None):
 537              pass
 538  
 539      mock_model = FakePipelineModel()
 540      with (
 541          mock.patch("mlflow.spark._validate_model"),
 542          mock.patch("mlflow.spark._is_spark_connect_model", return_value=False),
 543          mock.patch("mlflow.spark._maybe_save_model") as mock_maybe_save,
 544          mock.patch("mlflow.get_artifact_uri", return_value=acled_uri),
 545          mock.patch("mlflow.spark._should_use_mlflowdbfs", return_value=False),
 546          mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
 547          mock.patch("pyspark.ml.PipelineModel", FakePipelineModel),
 548          mlflow.start_run(),
 549      ):
 550          mlflow.spark.log_model(
 551              mock_model,
 552              artifact_path="model",
 553              dfs_tmpdir=str(tmp_path),
 554          )
 555          mock_maybe_save.assert_not_called()
 556          mock_log_v2.assert_called_once()
 557  
 558  
 559  def test_sparkml_model_load_from_remote_uri_succeeds(spark_model_iris, model_path, mock_s3_bucket):
 560      mlflow.spark.save_model(spark_model=spark_model_iris.model, path=model_path)
 561  
 562      artifact_root = f"s3://{mock_s3_bucket}"
 563      artifact_path = "model"
 564      artifact_repo = S3ArtifactRepository(artifact_root)
 565      artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)
 566  
 567      model_uri = artifact_root + "/" + artifact_path
 568      reloaded_model = mlflow.spark.load_model(model_uri=model_uri)
 569      preds_df = reloaded_model.transform(spark_model_iris.spark_df)
 570      preds = [x.prediction for x in preds_df.select("prediction").collect()]
 571      assert spark_model_iris.predictions == preds
 572  
 573  
 574  def test_sparkml_model_save_persists_specified_conda_env_in_mlflow_model_directory(
 575      spark_model_iris, model_path, spark_custom_env
 576  ):
 577      mlflow.spark.save_model(
 578          spark_model=spark_model_iris.model, path=model_path, conda_env=spark_custom_env
 579      )
 580  
 581      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 582      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 583      assert os.path.exists(saved_conda_env_path)
 584      assert saved_conda_env_path != spark_custom_env
 585  
 586      with open(spark_custom_env) as f:
 587          spark_custom_env_parsed = yaml.safe_load(f)
 588      with open(saved_conda_env_path) as f:
 589          saved_conda_env_parsed = yaml.safe_load(f)
 590      assert saved_conda_env_parsed == spark_custom_env_parsed
 591  
 592  
 593  def test_sparkml_model_save_persists_requirements_in_mlflow_model_directory(
 594      spark_model_iris, model_path, spark_custom_env
 595  ):
 596      mlflow.spark.save_model(
 597          spark_model=spark_model_iris.model, path=model_path, conda_env=spark_custom_env
 598      )
 599  
 600      saved_pip_req_path = os.path.join(model_path, "requirements.txt")
 601      _compare_conda_env_requirements(spark_custom_env, saved_pip_req_path)
 602  
 603  
 604  def test_log_model_with_pip_requirements(spark_model_iris, tmp_path):
 605      expected_mlflow_version = _mlflow_major_version_string()
 606      # Path to a requirements file
 607      req_file = tmp_path.joinpath("requirements.txt")
 608      req_file.write_text("a")
 609      with mlflow.start_run():
 610          model_info = mlflow.spark.log_model(
 611              spark_model_iris.model, artifact_path="model", pip_requirements=str(req_file)
 612          )
 613          _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)
 614  
 615      # List of requirements
 616      with mlflow.start_run():
 617          model_info = mlflow.spark.log_model(
 618              spark_model_iris.model, artifact_path="model", pip_requirements=[f"-r {req_file}", "b"]
 619          )
 620          _assert_pip_requirements(
 621              model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
 622          )
 623  
 624      # Constraints file
 625      with mlflow.start_run():
 626          model_info = mlflow.spark.log_model(
 627              spark_model_iris.model, artifact_path="model", pip_requirements=[f"-c {req_file}", "b"]
 628          )
 629          _assert_pip_requirements(
 630              model_info.model_uri,
 631              [expected_mlflow_version, "b", "-c constraints.txt"],
 632              ["a"],
 633              strict=True,
 634          )
 635  
 636  
 637  def test_log_model_with_extra_pip_requirements(spark_model_iris, tmp_path):
 638      expected_mlflow_version = _mlflow_major_version_string()
 639      default_reqs = mlflow.spark.get_default_pip_requirements()
 640  
 641      # Path to a requirements file
 642      req_file = tmp_path.joinpath("requirements.txt")
 643      req_file.write_text("a")
 644      with mlflow.start_run():
 645          model_info = mlflow.spark.log_model(
 646              spark_model_iris.model, artifact_path="model", extra_pip_requirements=str(req_file)
 647          )
 648          _assert_pip_requirements(
 649              model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
 650          )
 651  
 652      # List of requirements
 653      with mlflow.start_run():
 654          model_info = mlflow.spark.log_model(
 655              spark_model_iris.model,
 656              artifact_path="model",
 657              extra_pip_requirements=[f"-r {req_file}", "b"],
 658          )
 659          _assert_pip_requirements(
 660              model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
 661          )
 662  
 663      # Constraints file
 664      with mlflow.start_run():
 665          model_info = mlflow.spark.log_model(
 666              spark_model_iris.model,
 667              artifact_path="model",
 668              extra_pip_requirements=[f"-c {req_file}", "b"],
 669          )
 670          _assert_pip_requirements(
 671              model_info.model_uri,
 672              [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
 673              ["a"],
 674          )
 675  
 676  
 677  def test_sparkml_model_save_accepts_conda_env_as_dict(spark_model_iris, model_path):
 678      conda_env = dict(mlflow.spark.get_default_conda_env())
 679      conda_env["dependencies"].append("pytest")
 680      mlflow.spark.save_model(
 681          spark_model=spark_model_iris.model, path=model_path, conda_env=conda_env
 682      )
 683  
 684      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 685      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 686      assert os.path.exists(saved_conda_env_path)
 687  
 688      with open(saved_conda_env_path) as f:
 689          saved_conda_env_parsed = yaml.safe_load(f)
 690      assert saved_conda_env_parsed == conda_env
 691  
 692  
 693  def test_sparkml_model_log_persists_specified_conda_env_in_mlflow_model_directory(
 694      spark_model_iris, model_path, spark_custom_env
 695  ):
 696      artifact_path = "model"
 697      with mlflow.start_run():
 698          model_info = mlflow.spark.log_model(
 699              spark_model_iris.model,
 700              artifact_path=artifact_path,
 701              conda_env=spark_custom_env,
 702          )
 703  
 704      model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
 705      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
 706      saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
 707      assert os.path.exists(saved_conda_env_path)
 708      assert saved_conda_env_path != spark_custom_env
 709  
 710      with open(spark_custom_env) as f:
 711          spark_custom_env_parsed = yaml.safe_load(f)
 712      with open(saved_conda_env_path) as f:
 713          saved_conda_env_parsed = yaml.safe_load(f)
 714      assert saved_conda_env_parsed == spark_custom_env_parsed
 715  
 716  
 717  def test_sparkml_model_log_persists_requirements_in_mlflow_model_directory(
 718      spark_model_iris, model_path, spark_custom_env
 719  ):
 720      artifact_path = "model"
 721      with mlflow.start_run():
 722          model_info = mlflow.spark.log_model(
 723              spark_model_iris.model,
 724              artifact_path=artifact_path,
 725              conda_env=spark_custom_env,
 726          )
 727      model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
 728      saved_pip_req_path = os.path.join(model_path, "requirements.txt")
 729      _compare_conda_env_requirements(spark_custom_env, saved_pip_req_path)
 730  
 731  
 732  def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
 733      spark_model_iris, model_path
 734  ):
 735      mlflow.spark.save_model(spark_model=spark_model_iris.model, path=model_path)
 736      _assert_pip_requirements(model_path, mlflow.spark.get_default_pip_requirements())
 737  
 738  
 739  def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
 740      spark_model_iris,
 741  ):
 742      artifact_path = "model"
 743      with mlflow.start_run():
 744          model_info = mlflow.spark.log_model(spark_model_iris.model, artifact_path=artifact_path)
 745  
 746      _assert_pip_requirements(model_info.model_uri, mlflow.spark.get_default_pip_requirements())
 747  
 748  
 749  def test_pyspark_version_is_logged_without_dev_suffix(spark_model_iris):
 750      expected_mlflow_version = _mlflow_major_version_string()
 751      unsuffixed_version = "2.4.0"
 752      for dev_suffix in [".dev0", ".dev", ".dev1", "dev.a", ".devb"]:
 753          with mock.patch("importlib_metadata.version", return_value=unsuffixed_version + dev_suffix):
 754              with mlflow.start_run():
 755                  model_info = mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
 756              _assert_pip_requirements(
 757                  model_info.model_uri, [expected_mlflow_version, f"pyspark=={unsuffixed_version}"]
 758              )
 759  
 760      for unaffected_version in ["2.0", "2.3.4", "2"]:
 761          with mock.patch("importlib_metadata.version", return_value=unaffected_version):
 762              pip_deps = _get_pip_deps(mlflow.spark.get_default_conda_env())
 763              assert any(x == f"pyspark=={unaffected_version}" for x in pip_deps)
 764  
 765  
 766  def test_model_is_recorded_when_using_direct_save(spark_model_iris):
 767      # Patch `is_local_uri` to enforce direct model serialization to DFS
 768      with mock.patch("mlflow.spark.is_local_uri", return_value=False):
 769          with mlflow.start_run():
 770              mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
 771              current_tags = mlflow.get_run(mlflow.active_run().info.run_id).data.tags
 772              assert mlflow.utils.mlflow_tags.MLFLOW_LOGGED_MODELS in current_tags
 773  
 774  
 775  @pytest.mark.parametrize(
 776      (
 777          "artifact_uri",
 778          "db_runtime_version",
 779          "mlflowdbfs_disabled",
 780          "mlflowdbfs_available",
 781          "dbutils_available",
 782          "expected_uri",
 783          "expect_log_v2",
 784      ),
 785      [
 786          (
 787              "dbfs:/databricks/mlflow-tracking/a/b",
 788              "12.0",
 789              "",
 790              True,
 791              True,
 792              "mlflowdbfs:///artifacts?run_id={}&path=/model/sparkml",
 793              False,
 794          ),
 795          (
 796              "dbfs:/databricks/mlflow-tracking/a/b",
 797              "12.0",
 798              "false",
 799              True,
 800              True,
 801              "mlflowdbfs:///artifacts?run_id={}&path=/model/sparkml",
 802              False,
 803          ),
 804          # ACL-protected paths where mlflowdbfs is unavailable/disabled always route through
 805          # Model._log_v2 because _maybe_save_model is skipped via is_databricks_acled_artifacts_uri.
 806          # In real Databricks, _maybe_save_model always fails with Py4JError for these paths anyway.
 807          (
 808              "dbfs:/databricks/mlflow-tracking/a/b",
 809              "12.0",
 810              "false",
 811              True,
 812              False,
 813              None,
 814              True,
 815          ),
 816          (
 817              "dbfs:/databricks/mlflow-tracking/a/b",
 818              "12.0",
 819              "",
 820              False,
 821              True,
 822              None,
 823              True,
 824          ),
 825          (
 826              "dbfs:/databricks/mlflow-tracking/a/b",
 827              "",
 828              "",
 829              True,
 830              True,
 831              None,
 832              True,
 833          ),
 834          (
 835              "dbfs:/databricks/mlflow-tracking/a/b",
 836              "12.0",
 837              "true",
 838              True,
 839              True,
 840              None,
 841              True,
 842          ),
 843          ("dbfs:/root/a/b", "12.0", "", True, True, "dbfs:/root/a/b/model/sparkml", False),
 844          ("s3://mybucket/a/b", "12.0", "", True, True, "s3://mybucket/a/b/model/sparkml", False),
 845      ],
 846  )
 847  def test_model_logged_via_mlflowdbfs_when_appropriate(
 848      monkeypatch,
 849      spark_model_iris,
 850      artifact_uri,
 851      db_runtime_version,
 852      mlflowdbfs_disabled,
 853      mlflowdbfs_available,
 854      dbutils_available,
 855      expected_uri,
 856      expect_log_v2,
 857  ):
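           """The sparkml artifact should be written via mlflowdbfs:// only when the artifact
           root is a dbfs:/databricks/mlflow-tracking path on a Databricks runtime with the
           mlflowdbfs filesystem and dbutils available and mlflowdbfs not explicitly disabled;
           otherwise the model is saved directly or logged via Model._log_v2.
           """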
 858      def mock_spark_session_load(path):
 859          raise Exception("MlflowDbfsClient operation failed!")
 860  
 861      mock_spark_session = mock.Mock()
 862      mock_read_spark_session = mock.Mock()
 863      mock_read_spark_session.load = mock_spark_session_load
 864  
 865      from mlflow.utils.databricks_utils import _get_dbutils as og_getdbutils
 866  
 867      def mock_get_dbutils():
 868          # _get_dbutils is called during run creation and model logging; to avoid breaking run
 869          # creation, we only mock the output if _get_dbutils is called during spark model logging
 870          caller_fn_name = inspect.stack()[1].function
 871          if caller_fn_name == "_should_use_mlflowdbfs":
 872              if dbutils_available:
 873                  return mock.Mock()
 874              else:
 875                  raise Exception("dbutils not available")
 876          else:
 877              return og_getdbutils()
 878  
 879      with (
 880          mock.patch(
 881              "mlflow.utils._spark_utils._get_active_spark_session", return_value=mock_spark_session
 882          ),
 883          mock.patch("mlflow.get_artifact_uri", return_value=artifact_uri),
 884          mock.patch(
 885              "mlflow.spark._HadoopFileSystem.is_filesystem_available",
 886              return_value=mlflowdbfs_available,
 887          ),
 888          mock.patch("mlflow.utils.databricks_utils.MlflowCredentialContext", autospec=True),
 889          mock.patch("mlflow.utils.databricks_utils._get_dbutils", mock_get_dbutils),
 890          mock.patch.object(spark_model_iris.model, "save") as mock_save,
 891          mock.patch("mlflow.models.infer_pip_requirements", return_value=[]) as mock_infer,
 892          mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
 893      ):
 894          with mlflow.start_run():
 895              if db_runtime_version:
 896                  monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", db_runtime_version)
 897              monkeypatch.setenv("DISABLE_MLFLOWDBFS", mlflowdbfs_disabled)
 898              mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
 899  
 900              if expect_log_v2:
 901                  # ACL-protected paths where mlflowdbfs is unavailable skip _maybe_save_model
 902                  # entirely and fall through to Model._log_v2. In production, _maybe_save_model
 903                  # always raises Py4JError for these paths, so skipping it is correct.
 904                  mock_log_v2.assert_called_once()
 905                  mock_save.assert_not_called()
 906              else:
 907                  mock_save.assert_called_once_with(
 908                      expected_uri.format(mlflow.active_run().info.run_id)
 909                  )
 910  
 911                  if expected_uri.startswith("mlflowdbfs"):
 912                      # If mlflowdbfs is used, infer_pip_requirements should load the model from the
 913                      # remote model path instead of a local tmp path.
 914                      assert (
 915                          mock_infer.call_args[0][0]
 916                          == "dbfs:/databricks/mlflow-tracking/a/b/model/sparkml"
 917                      )
 918  
 919  
 920  @pytest.mark.parametrize("dummy_read_shows_mlflowdbfs_available", [True, False])
 921  def test_model_logging_uses_mlflowdbfs_if_appropriate_when_hdfs_check_fails(
 922      monkeypatch, spark_model_iris, dummy_read_shows_mlflowdbfs_available
 923  ):
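           """When the Hadoop filesystem availability check itself raises, fall back to probing
           mlflowdbfs with a dummy read: an "operation failed" error means the filesystem
           exists, while "filesystem not found" means it does not and Model._log_v2 is used.
           """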
 924      def mock_spark_session_load(path):
 925          if dummy_read_shows_mlflowdbfs_available:
 926              raise Exception("MlflowdbfsClient operation failed!")
 927          else:
 928              raise Exception("mlflowdbfs filesystem not found")
 929  
 930      mock_read_spark_session = mock.Mock()
 931      mock_read_spark_session.load = mock_spark_session_load
 932      mock_spark_session = mock.Mock()
 933      mock_spark_session.read = mock_read_spark_session
 934  
 935      from mlflow.utils.databricks_utils import _get_dbutils as og_getdbutils
 936  
 937      def mock_get_dbutils():
 938          # _get_dbutils is called during run creation and model logging; to avoid breaking run
 939          # creation, we only mock the output if _get_dbutils is called during spark model logging
 940          caller_fn_name = inspect.stack()[1].function
 941          if caller_fn_name == "_should_use_mlflowdbfs":
 942              return mock.Mock()
 943          else:
 944              return og_getdbutils()
 945  
 946      with (
 947          mock.patch(
 948              "mlflow.utils._spark_utils._get_active_spark_session",
 949              return_value=mock_spark_session,
 950          ),
 951          mock.patch(
 952              "mlflow.get_artifact_uri",
 953              return_value="dbfs:/databricks/mlflow-tracking/a/b",
 954          ),
 955          mock.patch(
 956              "mlflow.spark._HadoopFileSystem.is_filesystem_available",
 957              side_effect=Exception("MlflowDbfsClient operation failed!"),
 958          ),
 959          mock.patch("mlflow.utils.databricks_utils.MlflowCredentialContext", autospec=True),
 960          mock.patch(
 961              "mlflow.utils.databricks_utils._get_dbutils",
 962              mock_get_dbutils,
 963          ),
 964          mock.patch.object(spark_model_iris.model, "save") as mock_save,
 965          mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
 966      ):
 967          with mlflow.start_run():
 968              monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.0")
 969              mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
 970              run_id = mlflow.active_run().info.run_id
 971              if dummy_read_shows_mlflowdbfs_available:
 972                  mock_save.assert_called_once_with(
 973                      f"mlflowdbfs:///artifacts?run_id={run_id}&path=/model/sparkml"
 974                  )
 975              else:
 976                  # mlflowdbfs unavailable + ACL-protected path: _maybe_save_model is skipped,
 977                  # Model._log_v2 is called directly. In production, _maybe_save_model always
 978                  # raises Py4JError for these ACL-protected paths, so skipping it is correct.
 979                  mock_log_v2.assert_called_once()
 980                  mock_save.assert_not_called()
 981  
 982  
 983  def test_log_model_with_code_paths(spark_model_iris):
 984      artifact_path = "model"
 985      with (
 986          mlflow.start_run(),
 987          mock.patch(
 988              "mlflow.spark._add_code_from_conf_to_system_path",
 989              wraps=_add_code_from_conf_to_system_path,
 990          ) as add_mock,
 991      ):
 992          model_info = mlflow.spark.log_model(
 993              spark_model_iris.model, artifact_path=artifact_path, code_paths=[__file__]
 994          )
 995          _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.spark.FLAVOR_NAME)
 996          mlflow.spark.load_model(model_info.model_uri)
 997          add_mock.assert_called()
 998  
 999  
1000  def test_virtualenv_subfield_points_to_correct_path(spark_model_iris, model_path):
1001      mlflow.spark.save_model(spark_model_iris.model, path=model_path)
1002      pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
1003      python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
1004      assert python_env_path.exists()
1005      assert python_env_path.is_file()
1006  
1007  
1008  def test_model_save_load_with_metadata(spark_model_iris, model_path):
1009      mlflow.spark.save_model(
1010          spark_model_iris.model, path=model_path, metadata={"metadata_key": "metadata_value"}
1011      )
1012  
1013      reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
1014      assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"
1015  
1016  
1017  def test_model_log_with_metadata(spark_model_iris):
1018      with mlflow.start_run():
1019          model_info = mlflow.spark.log_model(
1020              spark_model_iris.model,
1021              artifact_path="model",
1022              metadata={"metadata_key": "metadata_value"},
1023          )
1024  
1025      reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
1026      assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"
1027  
1028  
1029  _df_input_example = iris_pandas_df().drop("label", axis=1).iloc[[0]]
1030  
1031  
1032  @pytest.mark.parametrize(
1033      "input_example",
 1034      # array and dict input examples are no longer supported because they
 1035      # are not converted to a pandas DataFrame when the input example is saved
1036      [_df_input_example],
1037  )
1038  def test_model_log_with_signature_inference(spark_model_iris, input_example):
1039      artifact_path = "model"
1040  
1041      with mlflow.start_run():
1042          model_info = mlflow.spark.log_model(
1043              spark_model_iris.model, artifact_path=artifact_path, input_example=input_example
1044          )
1045  
1046      mlflow_model = Model.load(model_info.model_uri)
1047      input_columns = mlflow_model.signature.inputs.inputs
1048      assert all(col.type == DataType.double for col in input_columns)
1049      column_names = [col.name for col in input_columns]
1050      if isinstance(input_example, list):
1051          assert column_names == [0, 1, 2, 3]
1052      else:
1053          assert column_names == ["0", "1", "2", "3"]
1054      assert mlflow_model.signature.outputs == Schema([ColSpec(type=DataType.double)])
1055  
1056  
1057  def test_log_model_with_vector_input_type_signature(spark, spark_model_estimator):
1058      from pyspark.ml.functions import vector_to_array
1059  
1060      from mlflow.types.schema import SparkMLVector
1061  
1062      model = spark_model_estimator.model
1063      with mlflow.start_run():
1064          model_info = mlflow.spark.log_model(
1065              model,
1066              artifact_path="model",
1067              signature=ModelSignature(
1068                  inputs=Schema([
1069                      ColSpec(name="features", type=SparkMLVector()),
1070                  ]),
1071                  outputs=Schema([ColSpec(type=DataType.double)]),
1072              ),
1073          )
1074  
1075      model_uri = model_info.model_uri
1076      model_meta = Model.load(model_uri)
1077      input_type = model_meta.signature.inputs.input_dict()["features"].type
1078      assert isinstance(input_type, SparkMLVector)
1079  
1080      pyfunc_model = pyfunc.load_model(model_uri)
1081      infer_data = spark_model_estimator.spark_df.withColumn(
1082          "features", vector_to_array("features")
1083      ).toPandas()
1084      preds = pyfunc_model.predict(infer_data)
1085      assert spark_model_estimator.predictions == preds