test_spark_model_export.py
import inspect
import json
import logging
import os
from pathlib import Path
from typing import Any, NamedTuple
from unittest import mock

import numpy as np
import pandas as pd
import pyspark
import pytest
import yaml
from packaging.version import Version
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.pipeline import Pipeline
from sklearn import datasets

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.tracking
import mlflow.utils.file_utils
from mlflow import pyfunc
from mlflow.entities.model_registry import ModelVersion
from mlflow.environment_variables import MLFLOW_DFS_TMP
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature
from mlflow.models.utils import _read_example
from mlflow.spark import _add_code_from_conf_to_system_path
from mlflow.store.artifact.databricks_models_artifact_repo import DatabricksModelsArtifactRepository
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.store.artifact.unity_catalog_models_artifact_repo import (
    UnityCatalogModelsArtifactRepository,
)
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import DataType
from mlflow.types.schema import ColSpec, Schema
from mlflow.utils.environment import _get_pip_deps, _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    score_model_in_sagemaker_docker_container,
)
from tests.pyfunc.test_spark import get_spark_session, score_model_as_udf
from tests.store.artifact.constants import MODELS_ARTIFACT_REPOSITORY

_logger = logging.getLogger(__name__)

PYSPARK_VERSION = Version(pyspark.__version__)


@pytest.fixture
def spark_custom_env(tmp_path):
    conda_env = os.path.join(tmp_path, "conda_env.yml")
    additional_pip_deps = ["/opt/mlflow", f"pyspark=={PYSPARK_VERSION}", "pytest"]
    if PYSPARK_VERSION < Version("3.4"):
        additional_pip_deps.extend([
            # Versions of PySpark < 3.4 are incompatible with pandas >= 2
            "pandas<2",
            # pandas<2.0 is incompatible with numpy>=2.0
            "numpy<2.0",
        ])
    _mlflow_conda_env(conda_env, additional_pip_deps=additional_pip_deps)
    return conda_env


class SparkModelWithData(NamedTuple):
    model: Any
    spark_df: Any
    pandas_df: Any
    predictions: Any


def _get_spark_session_with_retry(max_tries=3):
    conf = pyspark.SparkConf()
    for attempt in range(max_tries):
        try:
            return get_spark_session(conf)
        except Exception as e:
            if attempt >= max_tries - 1:
                raise
            _logger.exception(
                f"Attempt {attempt} to create a SparkSession failed ({e!r}), retrying..."
            )


# Specify `autouse=True` to ensure that a context is created
# before any tests are executed. This ensures that the Hadoop filesystem
# does not create its own SparkContext.
@pytest.fixture(scope="module")
def spark():
    if Version(pyspark.__version__) < Version("3.1"):
        # A workaround for this issue:
        # https://stackoverflow.com/questions/62109276/errorjava-lang-unsupportedoperationexception-for-pyspark-pandas-udf-documenta
        spark_home = (
            os.environ.get("SPARK_HOME")
            if "SPARK_HOME" in os.environ
            else os.path.dirname(pyspark.__file__)
        )
        conf_dir = os.path.join(spark_home, "conf")
        os.makedirs(conf_dir, exist_ok=True)
        with open(os.path.join(conf_dir, "spark-defaults.conf"), "w") as f:
            conf = """
spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"
spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true"
"""
            f.write(conf)

    with _get_spark_session_with_retry() as spark:
        yield spark


def iris_pandas_df():
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    feature_names = ["0", "1", "2", "3"]
    df = pd.DataFrame(X, columns=feature_names)  # to make spark_udf work
    df["label"] = pd.Series(y)
    return df


@pytest.fixture(scope="module")
def iris_df(spark):
    pdf = iris_pandas_df()
    feature_names = list(pdf.drop("label", axis=1).columns)
    iris_spark_df = spark.createDataFrame(pdf)
    return feature_names, pdf, iris_spark_df


@pytest.fixture(scope="module")
def iris_signature():
    return ModelSignature(
        inputs=Schema([
            ColSpec(name="0", type=DataType.double),
            ColSpec(name="1", type=DataType.double),
            ColSpec(name="2", type=DataType.double),
            ColSpec(name="3", type=DataType.double),
        ]),
        outputs=Schema([ColSpec(type=DataType.double)]),
    )


@pytest.fixture(scope="module")
def spark_model_iris(iris_df):
    feature_names, iris_pandas_df, iris_spark_df = iris_df
    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
    lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
    pipeline = Pipeline(stages=[assembler, lr])
    # Fit the model
    model = pipeline.fit(iris_spark_df)
    preds_df = model.transform(iris_spark_df)
    preds = [x.prediction for x in preds_df.select("prediction").collect()]
    return SparkModelWithData(
        model=model, spark_df=iris_spark_df, pandas_df=iris_pandas_df, predictions=preds
    )


@pytest.fixture(scope="module")
def spark_model_transformer(iris_df):
    feature_names, iris_pandas_df, iris_spark_df = iris_df
    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
    # Fit the model
    preds_df = assembler.transform(iris_spark_df)
    preds = [x.features for x in preds_df.select("features").collect()]
    return SparkModelWithData(
        model=assembler, spark_df=iris_spark_df, pandas_df=iris_pandas_df, predictions=preds
    )


@pytest.fixture(scope="module")
def spark_model_estimator(iris_df):
    feature_names, iris_pandas_df, iris_spark_df = iris_df
    assembler = VectorAssembler(inputCols=feature_names, outputCol="features")
    features_df = assembler.transform(iris_spark_df)
    lr = LogisticRegression(maxIter=50, regParam=0.1, elasticNetParam=0.8)
    # Fit the model
    model = lr.fit(features_df)
    preds_df = model.transform(features_df)
    preds = [x.prediction for x in preds_df.select("prediction").collect()]
    return SparkModelWithData(
        model=model, spark_df=features_df, pandas_df=iris_pandas_df, predictions=preds
    )


@pytest.fixture
def model_path(tmp_path):
    return os.path.join(tmp_path, "model")


@pytest.mark.usefixtures("spark")
def test_hadoop_filesystem(tmp_path):
    # copy local dir to and back from HadoopFS and make sure the results match
    from mlflow.spark import _HadoopFileSystem as FS

    test_dir_0 = os.path.join(tmp_path, "expected")
    test_file_0 = os.path.join(test_dir_0, "root", "file_0")
    test_dir_1 = os.path.join(test_dir_0, "root", "subdir")
    test_file_1 = os.path.join(test_dir_1, "file_1")
    os.makedirs(os.path.dirname(test_file_0))
    with open(test_file_0, "w") as f:
        f.write("test0")
    os.makedirs(os.path.dirname(test_file_1))
    with open(test_file_1, "w") as f:
        f.write("test1")
    remote = "/tmp/mlflow/test0"
    # File should not be copied in this case
    assert os.path.abspath(test_dir_0) == FS.maybe_copy_from_local_file(test_dir_0, remote)
    FS.copy_from_local_file(test_dir_0, remote, remove_src=False)
    local = os.path.join(tmp_path, "actual")
    FS.copy_to_local_file(remote, local, remove_src=True)
    assert sorted(os.listdir(os.path.join(local, "root"))) == sorted([
        "subdir",
        "file_0",
        ".file_0.crc",
    ])
    assert sorted(os.listdir(os.path.join(local, "root", "subdir"))) == sorted([
        "file_1",
        ".file_1.crc",
    ])
    # compare the files
    with open(os.path.join(test_dir_0, "root", "file_0")) as expected_f:
        with open(os.path.join(local, "root", "file_0")) as actual_f:
            assert expected_f.read() == actual_f.read()
    with open(os.path.join(test_dir_0, "root", "subdir", "file_1")) as expected_f:
        with open(os.path.join(local, "root", "subdir", "file_1")) as actual_f:
            assert expected_f.read() == actual_f.read()

    # make sure we cleanup
    assert not os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix
    FS.copy_from_local_file(test_dir_0, remote, remove_src=False)
    assert os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix
    FS.delete(remote)
    assert not os.path.exists(FS._remote_path(remote).toString())  # skip file: prefix


def test_model_export(spark_model_iris, model_path, spark_custom_env):
    mlflow.spark.save_model(spark_model_iris.model, path=model_path, conda_env=spark_custom_env)
    # 1. score and compare reloaded sparkml model
    reloaded_model = mlflow.spark.load_model(model_uri=model_path)
    preds_df = reloaded_model.transform(spark_model_iris.spark_df)
    preds1 = [x.prediction for x in preds_df.select("prediction").collect()]
    assert spark_model_iris.predictions == preds1
    m = pyfunc.load_model(model_path)
    # 2. score and compare reloaded pyfunc
    preds2 = m.predict(spark_model_iris.pandas_df)
    assert spark_model_iris.predictions == preds2
    # 3. score and compare reloaded pyfunc Spark udf
    preds3 = score_model_as_udf(model_uri=model_path, pandas_df=spark_model_iris.pandas_df)
    assert spark_model_iris.predictions == preds3
    assert os.path.exists(MLFLOW_DFS_TMP.get())
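

# Illustrative sketch (not used by the tests above): the Spark UDF scoring path exercised via
# `score_model_as_udf` can also be driven directly with `mlflow.pyfunc.spark_udf`. The helper
# name and arguments below are hypothetical; it assumes an active SparkSession and a model
# already saved or logged at `model_uri`.
def _example_score_with_spark_udf(spark, model_uri, pandas_df):
    from pyspark.sql.functions import struct

    # Wrap the saved model as a Spark UDF returning doubles, score a Spark DataFrame built from
    # the pandas input, and collect the predictions back to the driver.
    predict_udf = mlflow.pyfunc.spark_udf(spark, model_uri, result_type="double")
    sdf = spark.createDataFrame(pandas_df)
    scored_df = sdf.withColumn("prediction", predict_udf(struct(*sdf.columns)))
    return [row.prediction for row in scored_df.select("prediction").collect()]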


def test_model_export_with_signature_and_examples(spark_model_iris, iris_signature):
    features_df = spark_model_iris.pandas_df.drop("label", axis=1)
    example_ = features_df.head(3)
    for signature in (None, iris_signature):
        for example in (None, example_):
            with TempDir() as tmp:
                path = tmp.path("model")
                mlflow.spark.save_model(
                    spark_model_iris.model, path=path, signature=signature, input_example=example
                )
                mlflow_model = Model.load(path)
                if example is None and signature is None:
                    assert mlflow_model.signature is None
                else:
                    assert mlflow_model.signature == iris_signature
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    assert all((_read_example(mlflow_model, path) == example).all())


def test_model_export_raise_when_example_is_spark_dataframe(spark, spark_model_iris, model_path):
    features_df = spark_model_iris.pandas_df.drop("label", axis=1)
    example = spark.createDataFrame(features_df.head(3))
    with pytest.raises(MlflowException, match="Examples can not be provided as Spark Dataframe."):
        mlflow.spark.save_model(spark_model_iris.model, path=model_path, input_example=example)


def test_log_model_with_signature_and_examples(spark_model_iris, iris_signature):
    features_df = spark_model_iris.pandas_df.drop("label", axis=1)
    example_ = features_df.head(3)
    artifact_path = "model"
    for signature in (None, iris_signature):
        for example in (None, example_):
            with mlflow.start_run():
                model_info = mlflow.spark.log_model(
                    spark_model_iris.model,
                    artifact_path=artifact_path,
                    signature=signature,
                    input_example=example,
                )
                mlflow_model = Model.load(model_info.model_uri)
                if example is None and signature is None:
                    assert mlflow_model.signature is None
                else:
                    assert mlflow_model.signature == iris_signature
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    assert all(
                        (_read_example(mlflow_model, model_info.model_uri) == example).all()
                    )


def test_estimator_model_export(spark_model_estimator, model_path, spark_custom_env):
    mlflow.spark.save_model(
        spark_model_estimator.model, path=model_path, conda_env=spark_custom_env
    )
    # 1. score and compare the reloaded sparkml model
    reloaded_model = mlflow.spark.load_model(model_uri=model_path)
    preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
    preds = [x.prediction for x in preds_df.select("prediction").collect()]
    assert spark_model_estimator.predictions == preds
    # 2. score and compare reloaded pyfunc
    m = pyfunc.load_model(model_path)
    preds2 = m.predict(spark_model_estimator.spark_df.toPandas())
    assert spark_model_estimator.predictions == preds2


def test_transformer_model_export(spark_model_transformer, model_path, spark_custom_env):
    mlflow.spark.save_model(
        spark_model_transformer.model, path=model_path, conda_env=spark_custom_env
    )
    # 1. score and compare the reloaded sparkml model
    reloaded_model = mlflow.spark.load_model(model_uri=model_path)
    preds_df = reloaded_model.transform(spark_model_transformer.spark_df)
    preds = [x.features for x in preds_df.select("features").collect()]
    assert spark_model_transformer.predictions == preds
    # 2. score and compare reloaded pyfunc
    m = pyfunc.load_model(model_path)
    preds2 = m.predict(spark_model_transformer.spark_df.toPandas())
    assert spark_model_transformer.predictions == preds2


@pytest.mark.skipif(
    PYSPARK_VERSION.is_devrelease, reason="this test does not support PySpark dev version."
)
def test_model_deployment(spark_model_iris, model_path, spark_custom_env, monkeypatch):
    mlflow.spark.save_model(
        spark_model_iris.model,
        path=model_path,
        conda_env=spark_custom_env,
    )
    scoring_response = score_model_in_sagemaker_docker_container(
        model_uri=model_path,
        data=spark_model_iris.pandas_df,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        flavor=mlflow.pyfunc.FLAVOR_NAME,
    )
    from mlflow.deployments import PredictionsResponse

    np.testing.assert_array_almost_equal(
        spark_model_iris.predictions,
        PredictionsResponse.from_json(scoring_response.content).get_predictions(
            predictions_format="ndarray"
        ),
        decimal=4,
    )


@pytest.mark.skipif(
    "dev" in pyspark.__version__,
    reason="The dev version of pyspark built from the source doesn't exist on PyPI or Anaconda",
)
def test_sagemaker_docker_model_scoring_with_default_conda_env(spark_model_iris, model_path):
    mlflow.spark.save_model(
        spark_model_iris.model, path=model_path, extra_pip_requirements=["/opt/mlflow"]
    )

    scoring_response = score_model_in_sagemaker_docker_container(
        model_uri=model_path,
        data=spark_model_iris.pandas_df,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        flavor=mlflow.pyfunc.FLAVOR_NAME,
    )
    deployed_model_preds = np.array(json.loads(scoring_response.content)["predictions"])

    np.testing.assert_array_almost_equal(
        deployed_model_preds, spark_model_iris.predictions, decimal=4
    )


@pytest.mark.parametrize("should_start_run", [False, True])
@pytest.mark.parametrize("use_dfs_tmpdir", [False, True])
def test_sparkml_model_log(tmp_path, spark_model_iris, should_start_run, use_dfs_tmpdir):
    old_tracking_uri = mlflow.get_tracking_uri()
    dfs_tmpdir = None if use_dfs_tmpdir else tmp_path.joinpath("test")

    try:
        tracking_dir = tmp_path.joinpath("mlruns")
        mlflow.set_tracking_uri(f"file://{tracking_dir}")
        if should_start_run:
            mlflow.start_run()
        artifact_path = "model"
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path=artifact_path,
            dfs_tmpdir=dfs_tmpdir,
        )

        reloaded_model = mlflow.spark.load_model(
            model_uri=model_info.model_uri, dfs_tmpdir=dfs_tmpdir
        )
        preds_df = reloaded_model.transform(spark_model_iris.spark_df)
        preds = [x.prediction for x in preds_df.select("prediction").collect()]
        assert spark_model_iris.predictions == preds
    finally:
        mlflow.end_run()
        mlflow.set_tracking_uri(old_tracking_uri)


@pytest.mark.parametrize(
    ("registry_uri", "artifact_repo_class"),
    [
        ("databricks-uc", UnityCatalogModelsArtifactRepository),
        ("databricks", DatabricksModelsArtifactRepository),
    ],
)
def test_load_spark_model_from_models_uri(
    tmp_path, spark_model_estimator, registry_uri, artifact_repo_class
):
    model_dir = str(tmp_path.joinpath("spark_model"))
    model_name = "mycatalog.myschema.mymodel"
    fake_model_version = ModelVersion(name=model_name, version=str(3), creation_timestamp=0)

    with (
        mock.patch(f"{MODELS_ARTIFACT_REPOSITORY}.get_underlying_uri") as mock_get_underlying_uri,
        mock.patch.object(
            artifact_repo_class, "download_artifacts", return_value=model_dir
        ) as mock_download_artifacts,
        mock.patch("mlflow.get_registry_uri", return_value=registry_uri),
        mock.patch.object(
            mlflow.tracking._model_registry.client.ModelRegistryClient,
            "get_model_version_by_alias",
            return_value=fake_model_version,
        ) as get_model_version_by_alias_mock,
    ):
        mlflow.spark.save_model(
            path=model_dir,
            spark_model=spark_model_estimator.model,
        )
        mock_get_underlying_uri.return_value = "nonexistentscheme://fakeuri"
        mlflow.spark.load_model(f"models:/{model_name}/1")
        # Assert that we downloaded both the MLmodel file and the whole model itself using
        # the models:/ URI
        kwargs = (
            {"lineage_header_info": None}
            if artifact_repo_class is UnityCatalogModelsArtifactRepository
            else {}
        )
        mock_download_artifacts.assert_called_once_with("", None, **kwargs)
        mock_download_artifacts.reset_mock()
        mlflow.spark.load_model(f"models:/{model_name}@Champion")
        mock_download_artifacts.assert_called_once_with("", None, **kwargs)
        get_model_version_by_alias_mock.assert_called_with(model_name, "Champion")


@pytest.mark.parametrize("should_start_run", [False, True])
@pytest.mark.parametrize("use_dfs_tmpdir", [False, True])
def test_sparkml_estimator_model_log(
    tmp_path, spark_model_estimator, should_start_run, use_dfs_tmpdir
):
    old_tracking_uri = mlflow.get_tracking_uri()
    dfs_tmpdir = None if use_dfs_tmpdir else tmp_path.joinpath("test")

    try:
        tracking_dir = tmp_path.joinpath("mlruns")
        mlflow.set_tracking_uri(f"file://{tracking_dir}")
        if should_start_run:
            mlflow.start_run()
        artifact_path = "model"
        model_info = mlflow.spark.log_model(
            spark_model_estimator.model,
            artifact_path=artifact_path,
            dfs_tmpdir=dfs_tmpdir,
        )

        reloaded_model = mlflow.spark.load_model(
            model_uri=model_info.model_uri, dfs_tmpdir=dfs_tmpdir
        )
        preds_df = reloaded_model.transform(spark_model_estimator.spark_df)
        preds = [x.prediction for x in preds_df.select("prediction").collect()]
        assert spark_model_estimator.predictions == preds
    finally:
        mlflow.end_run()
        mlflow.set_tracking_uri(old_tracking_uri)


def test_log_model_calls_register_model(tmp_path, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = tmp_path.joinpath("test")
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path=artifact_path,
            dfs_tmpdir=dfs_tmp_dir,
            registered_model_name="AdsModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="AdsModel1",
        )


def test_log_model_no_registered_model_name(tmp_path, spark_model_iris):
    artifact_path = "model"
    dfs_tmp_dir = os.path.join(tmp_path, "test")
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path=artifact_path,
            dfs_tmpdir=dfs_tmp_dir,
        )
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()


def test_log_model_skips_maybe_save_for_acled_artifact_uri(tmp_path):
    """_maybe_save_model should not be called for Databricks ACL-protected artifact URIs
    (dbfs:/databricks/mlflow-tracking/...) since Spark cannot write to them directly.
    Calling it wastes ~6s per model on a guaranteed Py4JError before falling back.
    """
    acled_uri = "dbfs:/databricks/mlflow-tracking/abc123/run456/artifacts"

    class FakePipelineModel:
        def __init__(self, stages=None):
            pass

    mock_model = FakePipelineModel()
    with (
        mock.patch("mlflow.spark._validate_model"),
        mock.patch("mlflow.spark._is_spark_connect_model", return_value=False),
        mock.patch("mlflow.spark._maybe_save_model") as mock_maybe_save,
        mock.patch("mlflow.get_artifact_uri", return_value=acled_uri),
        mock.patch("mlflow.spark._should_use_mlflowdbfs", return_value=False),
        mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
        mock.patch("pyspark.ml.PipelineModel", FakePipelineModel),
        mlflow.start_run(),
    ):
        mlflow.spark.log_model(
            mock_model,
            artifact_path="model",
            dfs_tmpdir=str(tmp_path),
        )
        mock_maybe_save.assert_not_called()
        mock_log_v2.assert_called_once()
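

# For reference, the ACL-protected artifact URI shape that triggers the skip above (and that
# appears again in the mlflowdbfs parametrization below) looks roughly like
#   dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts
# whereas ordinary DFS locations such as dbfs:/root/a/b or s3://mybucket/a/b do not match and
# still go through _maybe_save_model.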


def test_sparkml_model_load_from_remote_uri_succeeds(spark_model_iris, model_path, mock_s3_bucket):
    mlflow.spark.save_model(spark_model=spark_model_iris.model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    model_uri = artifact_root + "/" + artifact_path
    reloaded_model = mlflow.spark.load_model(model_uri=model_uri)
    preds_df = reloaded_model.transform(spark_model_iris.spark_df)
    preds = [x.prediction for x in preds_df.select("prediction").collect()]
    assert spark_model_iris.predictions == preds


def test_sparkml_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    spark_model_iris, model_path, spark_custom_env
):
    mlflow.spark.save_model(
        spark_model=spark_model_iris.model, path=model_path, conda_env=spark_custom_env
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != spark_custom_env

    with open(spark_custom_env) as f:
        spark_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == spark_custom_env_parsed


def test_sparkml_model_save_persists_requirements_in_mlflow_model_directory(
    spark_model_iris, model_path, spark_custom_env
):
    mlflow.spark.save_model(
        spark_model=spark_model_iris.model, path=model_path, conda_env=spark_custom_env
    )

    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(spark_custom_env, saved_pip_req_path)


def test_log_model_with_pip_requirements(spark_model_iris, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )


def test_log_model_with_extra_pip_requirements(spark_model_iris, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    default_reqs = mlflow.spark.get_default_pip_requirements()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path="model",
            extra_pip_requirements=[f"-r {req_file}", "b"],
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path="model",
            extra_pip_requirements=[f"-c {req_file}", "b"],
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )


def test_sparkml_model_save_accepts_conda_env_as_dict(spark_model_iris, model_path):
    conda_env = dict(mlflow.spark.get_default_conda_env())
    conda_env["dependencies"].append("pytest")
    mlflow.spark.save_model(
        spark_model=spark_model_iris.model, path=model_path, conda_env=conda_env
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == conda_env


def test_sparkml_model_log_persists_specified_conda_env_in_mlflow_model_directory(
    spark_model_iris, model_path, spark_custom_env
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path=artifact_path,
            conda_env=spark_custom_env,
        )

    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != spark_custom_env

    with open(spark_custom_env) as f:
        spark_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == spark_custom_env_parsed


def test_sparkml_model_log_persists_requirements_in_mlflow_model_directory(
    spark_model_iris, model_path, spark_custom_env
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path=artifact_path,
            conda_env=spark_custom_env,
        )
    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(spark_custom_env, saved_pip_req_path)


def test_sparkml_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    spark_model_iris, model_path
):
    mlflow.spark.save_model(spark_model=spark_model_iris.model, path=model_path)
    _assert_pip_requirements(model_path, mlflow.spark.get_default_pip_requirements())


def test_sparkml_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    spark_model_iris,
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(spark_model_iris.model, artifact_path=artifact_path)

    _assert_pip_requirements(model_info.model_uri, mlflow.spark.get_default_pip_requirements())


def test_pyspark_version_is_logged_without_dev_suffix(spark_model_iris):
    expected_mlflow_version = _mlflow_major_version_string()
    unsuffixed_version = "2.4.0"
    for dev_suffix in [".dev0", ".dev", ".dev1", "dev.a", ".devb"]:
        with mock.patch("importlib_metadata.version", return_value=unsuffixed_version + dev_suffix):
            with mlflow.start_run():
                model_info = mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
            _assert_pip_requirements(
                model_info.model_uri, [expected_mlflow_version, f"pyspark=={unsuffixed_version}"]
            )

    for unaffected_version in ["2.0", "2.3.4", "2"]:
        with mock.patch("importlib_metadata.version", return_value=unaffected_version):
            pip_deps = _get_pip_deps(mlflow.spark.get_default_conda_env())
            assert any(x == f"pyspark=={unaffected_version}" for x in pip_deps)


def test_model_is_recorded_when_using_direct_save(spark_model_iris):
    # Patch `is_local_uri` to enforce direct model serialization to DFS
    with mock.patch("mlflow.spark.is_local_uri", return_value=False):
        with mlflow.start_run():
            mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
            current_tags = mlflow.get_run(mlflow.active_run().info.run_id).data.tags
            assert mlflow.utils.mlflow_tags.MLFLOW_LOGGED_MODELS in current_tags


@pytest.mark.parametrize(
    (
        "artifact_uri",
        "db_runtime_version",
        "mlflowdbfs_disabled",
        "mlflowdbfs_available",
        "dbutils_available",
        "expected_uri",
        "expect_log_v2",
    ),
    [
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "12.0",
            "",
            True,
            True,
            "mlflowdbfs:///artifacts?run_id={}&path=/model/sparkml",
            False,
        ),
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "12.0",
            "false",
            True,
            True,
            "mlflowdbfs:///artifacts?run_id={}&path=/model/sparkml",
            False,
        ),
        # ACL-protected paths where mlflowdbfs is unavailable/disabled always route through
        # Model._log_v2 because _maybe_save_model is skipped via is_databricks_acled_artifacts_uri.
        # In real Databricks, _maybe_save_model always fails with Py4JError for these paths anyway.
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "12.0",
            "false",
            True,
            False,
            None,
            True,
        ),
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "12.0",
            "",
            False,
            True,
            None,
            True,
        ),
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "",
            "",
            True,
            True,
            None,
            True,
        ),
        (
            "dbfs:/databricks/mlflow-tracking/a/b",
            "12.0",
            "true",
            True,
            True,
            None,
            True,
        ),
        ("dbfs:/root/a/b", "12.0", "", True, True, "dbfs:/root/a/b/model/sparkml", False),
        ("s3://mybucket/a/b", "12.0", "", True, True, "s3://mybucket/a/b/model/sparkml", False),
    ],
)
def test_model_logged_via_mlflowdbfs_when_appropriate(
    monkeypatch,
    spark_model_iris,
    artifact_uri,
    db_runtime_version,
    mlflowdbfs_disabled,
    mlflowdbfs_available,
    dbutils_available,
    expected_uri,
    expect_log_v2,
):
    def mock_spark_session_load(path):
        raise Exception("MlflowDbfsClient operation failed!")

    mock_spark_session = mock.Mock()
    mock_read_spark_session = mock.Mock()
    mock_read_spark_session.load = mock_spark_session_load

    from mlflow.utils.databricks_utils import _get_dbutils as og_getdbutils

    def mock_get_dbutils():
        # _get_dbutils is called during run creation and model logging; to avoid breaking run
        # creation, we only mock the output if _get_dbutils is called during spark model logging
        caller_fn_name = inspect.stack()[1].function
        if caller_fn_name == "_should_use_mlflowdbfs":
            if dbutils_available:
                return mock.Mock()
            else:
                raise Exception("dbutils not available")
        else:
            return og_getdbutils()

    with (
        mock.patch(
            "mlflow.utils._spark_utils._get_active_spark_session", return_value=mock_spark_session
        ),
        mock.patch("mlflow.get_artifact_uri", return_value=artifact_uri),
        mock.patch(
            "mlflow.spark._HadoopFileSystem.is_filesystem_available",
            return_value=mlflowdbfs_available,
        ),
        mock.patch("mlflow.utils.databricks_utils.MlflowCredentialContext", autospec=True),
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", mock_get_dbutils),
        mock.patch.object(spark_model_iris.model, "save") as mock_save,
        mock.patch("mlflow.models.infer_pip_requirements", return_value=[]) as mock_infer,
        mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
    ):
        with mlflow.start_run():
            if db_runtime_version:
                monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", db_runtime_version)
            monkeypatch.setenv("DISABLE_MLFLOWDBFS", mlflowdbfs_disabled)
            mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")

            if expect_log_v2:
                # ACL-protected paths where mlflowdbfs is unavailable skip _maybe_save_model
                # entirely and fall through to Model._log_v2. In production, _maybe_save_model
                # always raises Py4JError for these paths, so skipping it is correct.
                mock_log_v2.assert_called_once()
                mock_save.assert_not_called()
            else:
                mock_save.assert_called_once_with(
                    expected_uri.format(mlflow.active_run().info.run_id)
                )

                if expected_uri.startswith("mlflowdbfs"):
                    # If mlflowdbfs is used, infer_pip_requirements should load the model from the
                    # remote model path instead of a local tmp path.
                    assert (
                        mock_infer.call_args[0][0]
                        == "dbfs:/databricks/mlflow-tracking/a/b/model/sparkml"
                    )
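

# Note on the expected URIs asserted above: when mlflowdbfs is selected, the Spark writer is
# pointed at a URI of the form
#   mlflowdbfs:///artifacts?run_id=<run_id>&path=/<artifact_path>/sparkml
# (run id and artifact path carried as query parameters), whereas the plain DBFS/S3 cases simply
# append "<artifact_path>/sparkml" to the run's artifact root.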


@pytest.mark.parametrize("dummy_read_shows_mlflowdbfs_available", [True, False])
def test_model_logging_uses_mlflowdbfs_if_appropriate_when_hdfs_check_fails(
    monkeypatch, spark_model_iris, dummy_read_shows_mlflowdbfs_available
):
    def mock_spark_session_load(path):
        if dummy_read_shows_mlflowdbfs_available:
            raise Exception("MlflowdbfsClient operation failed!")
        else:
            raise Exception("mlflowdbfs filesystem not found")

    mock_read_spark_session = mock.Mock()
    mock_read_spark_session.load = mock_spark_session_load
    mock_spark_session = mock.Mock()
    mock_spark_session.read = mock_read_spark_session

    from mlflow.utils.databricks_utils import _get_dbutils as og_getdbutils

    def mock_get_dbutils():
        # _get_dbutils is called during run creation and model logging; to avoid breaking run
        # creation, we only mock the output if _get_dbutils is called during spark model logging
        caller_fn_name = inspect.stack()[1].function
        if caller_fn_name == "_should_use_mlflowdbfs":
            return mock.Mock()
        else:
            return og_getdbutils()

    with (
        mock.patch(
            "mlflow.utils._spark_utils._get_active_spark_session",
            return_value=mock_spark_session,
        ),
        mock.patch(
            "mlflow.get_artifact_uri",
            return_value="dbfs:/databricks/mlflow-tracking/a/b",
        ),
        mock.patch(
            "mlflow.spark._HadoopFileSystem.is_filesystem_available",
            side_effect=Exception("MlflowDbfsClient operation failed!"),
        ),
        mock.patch("mlflow.utils.databricks_utils.MlflowCredentialContext", autospec=True),
        mock.patch(
            "mlflow.utils.databricks_utils._get_dbutils",
            mock_get_dbutils,
        ),
        mock.patch.object(spark_model_iris.model, "save") as mock_save,
        mock.patch("mlflow.models.Model._log_v2") as mock_log_v2,
    ):
        with mlflow.start_run():
            monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.0")
            mlflow.spark.log_model(spark_model_iris.model, artifact_path="model")
            run_id = mlflow.active_run().info.run_id
            if dummy_read_shows_mlflowdbfs_available:
                mock_save.assert_called_once_with(
                    f"mlflowdbfs:///artifacts?run_id={run_id}&path=/model/sparkml"
                )
            else:
                # mlflowdbfs unavailable + ACL-protected path: _maybe_save_model is skipped,
                # Model._log_v2 is called directly. In production, _maybe_save_model always
                # raises Py4JError for these ACL-protected paths, so skipping it is correct.
                mock_log_v2.assert_called_once()
                mock_save.assert_not_called()


def test_log_model_with_code_paths(spark_model_iris):
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch(
            "mlflow.spark._add_code_from_conf_to_system_path",
            wraps=_add_code_from_conf_to_system_path,
        ) as add_mock,
    ):
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.spark.FLAVOR_NAME)
        mlflow.spark.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_virtualenv_subfield_points_to_correct_path(spark_model_iris, model_path):
    mlflow.spark.save_model(spark_model_iris.model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


def test_model_save_load_with_metadata(spark_model_iris, model_path):
    mlflow.spark.save_model(
        spark_model_iris.model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata(spark_model_iris):
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model,
            artifact_path="model",
            metadata={"metadata_key": "metadata_value"},
        )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


_df_input_example = iris_pandas_df().drop("label", axis=1).iloc[[0]]


@pytest.mark.parametrize(
    "input_example",
    # array and dict input examples are not supported any more as they
    # won't be converted to pandas dataframe when saving example
    [_df_input_example],
)
def test_model_log_with_signature_inference(spark_model_iris, input_example):
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            spark_model_iris.model, artifact_path=artifact_path, input_example=input_example
        )

    mlflow_model = Model.load(model_info.model_uri)
    input_columns = mlflow_model.signature.inputs.inputs
    assert all(col.type == DataType.double for col in input_columns)
    column_names = [col.name for col in input_columns]
    if isinstance(input_example, list):
        assert column_names == [0, 1, 2, 3]
    else:
        assert column_names == ["0", "1", "2", "3"]
    assert mlflow_model.signature.outputs == Schema([ColSpec(type=DataType.double)])


def test_log_model_with_vector_input_type_signature(spark, spark_model_estimator):
    from pyspark.ml.functions import vector_to_array

    from mlflow.types.schema import SparkMLVector

    model = spark_model_estimator.model
    with mlflow.start_run():
        model_info = mlflow.spark.log_model(
            model,
            artifact_path="model",
            signature=ModelSignature(
                inputs=Schema([
                    ColSpec(name="features", type=SparkMLVector()),
                ]),
                outputs=Schema([ColSpec(type=DataType.double)]),
            ),
        )

    model_uri = model_info.model_uri
    model_meta = Model.load(model_uri)
    input_type = model_meta.signature.inputs.input_dict()["features"].type
    assert isinstance(input_type, SparkMLVector)

    pyfunc_model = pyfunc.load_model(model_uri)
    infer_data = spark_model_estimator.spark_df.withColumn(
        "features", vector_to_array("features")
    ).toPandas()
    preds = pyfunc_model.predict(infer_data)
    assert spark_model_estimator.predictions == preds