test_sklearn_model_export.py
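"""Tests for the ``mlflow.sklearn`` flavor.

Covers saving, loading, and logging scikit-learn models across the supported
serialization formats (pickle, cloudpickle, skops), persistence of conda/pip
environments and requirements, pyfunc loading and serving, and model registry
integration.
"""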
import json
import os
import pickle
import shutil
import tempfile
from pathlib import Path
from typing import Any, NamedTuple
from unittest import mock

import cloudpickle
import numpy as np
import pandas as pd
import pytest
import sklearn
import sklearn.linear_model as glm
import sklearn.naive_bayes as nb
import sklearn.neighbors as knn
import skops
import yaml
from packaging.version import Version
from sklearn import datasets
from sklearn.pipeline import Pipeline as SKPipeline
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer as SKFunctionTransformer

import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.sklearn
from mlflow import pyfunc
from mlflow.entities.model_registry.model_version import ModelVersion, ModelVersionStatus
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import DataType
from mlflow.types.schema import ColSpec, Schema
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _is_available_on_pypi,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)
from tests.store._unity_catalog.conftest import (
    configure_client_for_uc,  # noqa: F401
    mock_databricks_uc_host_creds,  # noqa: F401
)

EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("scikit-learn", module="sklearn") else ["--env-manager", "local"]
)


class ModelWithData(NamedTuple):
    model: Any
    inference_data: Any


@pytest.fixture(scope="module")
def iris_df():
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    X_df = pd.DataFrame(X, columns=iris.feature_names)
    X_df = X_df.iloc[:, :2]  # we only take the first two features.
    y_series = pd.Series(y)
    return X_df, y_series


@pytest.fixture(scope="module")
def iris_signature():
    return ModelSignature(
        inputs=Schema([
            ColSpec(name="sepal length (cm)", type=DataType.double),
            ColSpec(name="sepal width (cm)", type=DataType.double),
        ]),
        outputs=Schema([ColSpec(type=DataType.long)]),
    )


@pytest.fixture(scope="module")
def sklearn_knn_model(iris_df):
    X, y = iris_df
    knn_model = knn.KNeighborsClassifier()
    knn_model.fit(X, y)
    return ModelWithData(model=knn_model, inference_data=X)


# To load the sklearn KNN model in skops format, we need to mark these types as
# `skops_trusted_types`.
# Related ticket: https://github.com/skops-dev/skops/issues/498
sklearn_knn_model_skops_trusted_types = [
    "sklearn.metrics._dist_metrics.EuclideanDistance64",
    "sklearn.neighbors._kd_tree.KDTree",
]


@pytest.fixture(scope="module")
def sklearn_logreg_model(iris_df):
    X, y = iris_df
    linear_lr = glm.LogisticRegression()
    linear_lr.fit(X, y)
    return ModelWithData(model=linear_lr, inference_data=X)


@pytest.fixture(scope="module")
def sklearn_gaussian_model(iris_df):
    X, y = iris_df
    gaussian_nb = nb.GaussianNB()
    gaussian_nb.fit(X, y)
    return ModelWithData(model=gaussian_nb, inference_data=X)


@pytest.fixture(scope="module")
def sklearn_custom_transformer_model(sklearn_knn_model, iris_df):
    def transform(vec):
        return vec + 1

    transformer = SKFunctionTransformer(transform, validate=True)
    pipeline = SKPipeline([("custom_transformer", transformer), ("knn", sklearn_knn_model.model)])
    X, _ = iris_df
    return ModelWithData(pipeline, inference_data=X)


@pytest.fixture
def model_path(tmp_path):
    return os.path.join(tmp_path, "model")


@pytest.fixture
def sklearn_custom_env(tmp_path):
    conda_env = os.path.join(tmp_path, "conda_env.yml")
    _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn", "pytest"])
    return conda_env


@pytest.mark.parametrize("serialization_format", mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS)
def test_model_save_load(sklearn_logreg_model, model_path, serialization_format):
    from mlflow.utils.requirements_utils import _parse_requirements

    sk_model = sklearn_logreg_model.model
    mlflow.sklearn.save_model(
        sk_model=sk_model, path=model_path, serialization_format=serialization_format
    )
    reloaded_model = mlflow.sklearn.load_model(model_uri=model_path)
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == serialization_format

    req_map = {
        mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS: f"skops=={skops.__version__}",
        mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE: f"cloudpickle=={cloudpickle.__version__}",
    }

    logged_reqs = [
        req.req_str
        for req in _parse_requirements(
            os.path.join(model_path, "requirements.txt"), is_constraint=False
        )
    ]
    if serialization_format != mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE:
        assert req_map[serialization_format] in logged_reqs

    np.testing.assert_array_equal(
        sk_model.predict(sklearn_logreg_model.inference_data),
        reloaded_model.predict(sklearn_logreg_model.inference_data),
    )

    np.testing.assert_array_equal(
        reloaded_model.predict(sklearn_logreg_model.inference_data),
        reloaded_pyfunc.predict(sklearn_logreg_model.inference_data),
    )


def test_model_skops_format_trusted_type(sklearn_knn_model, model_path):
    sk_model = sklearn_knn_model.model

    with pytest.raises(MlflowException, match="The saved sklearn model references untrusted type"):
        mlflow.sklearn.save_model(
            sk_model=sk_model,
            path=model_path,
            serialization_format="skops",
        )

    shutil.rmtree(model_path)
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        serialization_format="skops",
        skops_trusted_types=sklearn_knn_model_skops_trusted_types,
    )
    reloaded_model = mlflow.sklearn.load_model(model_uri=model_path)
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    np.testing.assert_array_equal(
        sk_model.predict(sklearn_knn_model.inference_data),
        reloaded_model.predict(sklearn_knn_model.inference_data),
    )

    np.testing.assert_array_equal(
        reloaded_model.predict(sklearn_knn_model.inference_data),
        reloaded_pyfunc.predict(sklearn_knn_model.inference_data),
    )


def test_log_model_skops_no_pip_requirements_warning(sklearn_logreg_model, recwarn):
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_logreg_model.model,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS,
        )
    warning_messages = [str(w.message) for w in recwarn]
    assert not any("Fall back to return" in msg for msg in warning_messages)


def test_model_save_behavior_with_preexisting_folders(sklearn_knn_model, tmp_path):
    sklearn_model_path = tmp_path / "sklearn_model_empty_exists"
    sklearn_model_path.mkdir()
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)

    sklearn_model_path = tmp_path / "sklearn_model_filled_exists"
    sklearn_model_path.mkdir()
    (sklearn_model_path / "foo.txt").write_text("dummy content")
    with pytest.raises(MlflowException, match="already exists and is not empty"):
        mlflow.sklearn.save_model(sk_model=sklearn_knn_model, path=sklearn_model_path)


def test_signature_and_examples_are_saved_correctly(sklearn_knn_model, iris_signature):
    data = sklearn_knn_model.inference_data
    model = sklearn_knn_model.model
    example_ = data[:3]
    for signature in (None, iris_signature):
        for example in (None, example_):
            with TempDir() as tmp:
                path = tmp.path("model")
                mlflow.sklearn.save_model(
                    model, path=path, signature=signature, input_example=example
                )
                mlflow_model = Model.load(path)
                if signature is None and example is None:
                    assert mlflow_model.signature is None
                else:
                    assert mlflow_model.signature == iris_signature
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    np.testing.assert_array_equal(_read_example(mlflow_model, path), example)


def test_model_load_from_remote_uri_succeeds(sklearn_knn_model, model_path, mock_s3_bucket):
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    model_uri = artifact_root + "/" + artifact_path
    reloaded_knn_model = mlflow.sklearn.load_model(model_uri=model_uri)
    np.testing.assert_array_equal(
        sklearn_knn_model.model.predict(sklearn_knn_model.inference_data),
        reloaded_knn_model.predict(sklearn_knn_model.inference_data),
    )


def test_model_log(sklearn_logreg_model, model_path):
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        for should_start_run in [False, True]:
            try:
                if should_start_run:
                    mlflow.start_run()

                artifact_path = "linear"
                conda_env = os.path.join(tmp.path(), "conda_env.yaml")
                _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])

                model_info = mlflow.sklearn.log_model(
                    sklearn_logreg_model.model,
                    name=artifact_path,
                    conda_env=conda_env,
                )

                reloaded_logsklearn_knn_model = mlflow.sklearn.load_model(
                    model_uri=model_info.model_uri
                )
                np.testing.assert_array_equal(
                    sklearn_logreg_model.model.predict(sklearn_logreg_model.inference_data),
                    reloaded_logsklearn_knn_model.predict(sklearn_logreg_model.inference_data),
                )

                model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
                model_config = Model.load(os.path.join(model_path, "MLmodel"))
                assert pyfunc.FLAVOR_NAME in model_config.flavors
                assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
                env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
                assert os.path.exists(os.path.join(model_path, env_path))

            finally:
                mlflow.end_run()


def test_log_model_calls_register_model(sklearn_logreg_model):
    artifact_path = "linear"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch, TempDir(chdr=True, remove_on_exit=True) as tmp:
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])
        model_info = mlflow.sklearn.log_model(
            sklearn_logreg_model.model,
            name=artifact_path,
            conda_env=conda_env,
            registered_model_name="AdsModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="AdsModel1",
        )


def test_log_model_call_register_model_to_uc(configure_client_for_uc, sklearn_logreg_model):
    artifact_path = "linear"
    mock_model_version = ModelVersion(
        name="AdsModel1",
        version=1,
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.READY),
    )
    with (
        mock.patch.object(UcModelRegistryStore, "create_registered_model"),
        mock.patch.object(
            UcModelRegistryStore,
            "create_model_version",
            return_value=mock_model_version,
            autospec=True,
        ) as mock_create_mv,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        with mlflow.start_run() as run:
            conda_env = os.path.join(tmp.path(), "conda_env.yaml")
            _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])
            model_info = mlflow.sklearn.log_model(
                sklearn_logreg_model.model,
                name=artifact_path,
                conda_env=conda_env,
                registered_model_name="AdsModel1",
            )
        source = model_info.artifact_path
        [(args, kwargs)] = mock_create_mv.call_args_list
        assert args[1:] == ("AdsModel1", source, run.info.run_id, [], None, None)
        assert kwargs["local_model_path"].startswith(tempfile.gettempdir())


def test_log_model_no_registered_model_name(sklearn_logreg_model):
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
mock.patch("mlflow.tracking._model_registry.fluent._register_model") 369 with mlflow.start_run(), register_model_patch, TempDir(chdr=True, remove_on_exit=True) as tmp: 370 conda_env = os.path.join(tmp.path(), "conda_env.yaml") 371 _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"]) 372 mlflow.sklearn.log_model( 373 sklearn_logreg_model.model, 374 name=artifact_path, 375 conda_env=conda_env, 376 ) 377 mlflow.tracking._model_registry.fluent._register_model.assert_not_called() 378 379 380 def test_custom_transformer_can_be_saved_and_loaded_with_cloudpickle_format( 381 sklearn_custom_transformer_model, tmp_path 382 ): 383 custom_transformer_model = sklearn_custom_transformer_model.model 384 385 # Because the model contains a customer transformer that is not defined at the top level of the 386 # current test module, we expect pickle to fail when attempting to serialize it. In contrast, 387 # we expect cloudpickle to successfully locate the transformer definition and serialize the 388 # model successfully. 389 pickle_format_model_path = os.path.join(tmp_path, "pickle_model") 390 with pytest.raises(AttributeError, match="Can't pickle local object"): 391 mlflow.sklearn.save_model( 392 sk_model=custom_transformer_model, 393 path=pickle_format_model_path, 394 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE, 395 ) 396 397 cloudpickle_format_model_path = os.path.join(tmp_path, "cloud_pickle_model") 398 mlflow.sklearn.save_model( 399 sk_model=custom_transformer_model, 400 path=cloudpickle_format_model_path, 401 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, 402 ) 403 404 reloaded_custom_transformer_model = mlflow.sklearn.load_model( 405 model_uri=cloudpickle_format_model_path 406 ) 407 408 np.testing.assert_array_equal( 409 custom_transformer_model.predict(sklearn_custom_transformer_model.inference_data), 410 reloaded_custom_transformer_model.predict(sklearn_custom_transformer_model.inference_data), 411 ) 412 413 414 def test_model_save_persists_specified_conda_env_in_mlflow_model_directory( 415 sklearn_logreg_model, model_path, sklearn_custom_env 416 ): 417 mlflow.sklearn.save_model( 418 sk_model=sklearn_logreg_model.model, 419 path=model_path, 420 conda_env=sklearn_custom_env, 421 ) 422 423 pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME) 424 saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"]) 425 assert os.path.exists(saved_conda_env_path) 426 assert saved_conda_env_path != sklearn_custom_env 427 428 with open(sklearn_custom_env) as f: 429 sklearn_custom_env_parsed = yaml.safe_load(f) 430 with open(saved_conda_env_path) as f: 431 saved_conda_env_parsed = yaml.safe_load(f) 432 assert saved_conda_env_parsed == sklearn_custom_env_parsed 433 434 435 def test_model_save_persists_requirements_in_mlflow_model_directory( 436 sklearn_knn_model, model_path, sklearn_custom_env 437 ): 438 mlflow.sklearn.save_model( 439 sk_model=sklearn_knn_model.model, path=model_path, conda_env=sklearn_custom_env 440 ) 441 442 saved_pip_req_path = os.path.join(model_path, "requirements.txt") 443 _compare_conda_env_requirements(sklearn_custom_env, saved_pip_req_path) 444 445 446 def test_log_model_with_pip_requirements(sklearn_knn_model, tmp_path): 447 expected_mlflow_version = _mlflow_major_version_string() 448 # Path to a requirements file 449 req_file = tmp_path.joinpath("requirements.txt") 450 req_file.write_text("a") 451 with mlflow.start_run(): 452 model_info = mlflow.sklearn.log_model( 453 
            sklearn_knn_model.model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )


def test_log_model_with_extra_pip_requirements(sklearn_knn_model, tmp_path):
    expected_mlflow_version = _mlflow_major_version_string()
    default_reqs = mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )


def test_model_save_accepts_conda_env_as_dict(sklearn_knn_model, model_path):
    conda_env = dict(mlflow.sklearn.get_default_conda_env())
    conda_env["dependencies"].append("pytest")
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model, path=model_path, conda_env=conda_env
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == conda_env


def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
    sklearn_knn_model, sklearn_custom_env
):
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            conda_env=sklearn_custom_env,
        )
        model_uri = model_info.model_uri

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != sklearn_custom_env

    with open(sklearn_custom_env) as f:
        sklearn_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == sklearn_custom_env_parsed


def test_model_log_persists_requirements_in_mlflow_model_directory(
    sklearn_knn_model, sklearn_custom_env
):
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            conda_env=sklearn_custom_env,
        )

    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(sklearn_custom_env, saved_pip_req_path)


def test_model_save_throws_exception_if_serialization_format_is_unrecognized(
    sklearn_knn_model, model_path
):
    with pytest.raises(MlflowException, match="Unrecognized serialization format") as exc:
        mlflow.sklearn.save_model(
            sk_model=sklearn_knn_model.model,
            path=model_path,
            serialization_format="not a valid format",
        )
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # The unsupported serialization format should have been detected prior to the execution of
    # any directory creation or state-mutating persistence logic that would prevent a second
    # serialization call with the same model path from succeeding
    assert not os.path.exists(model_path)
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)


def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model, model_path
):
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)
    _assert_pip_requirements(
        model_path, mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)
    )


def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model,
):
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model.model, name="model")

    _assert_pip_requirements(
        model_info.model_uri, mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)
    )


def test_model_save_uses_cloudpickle_serialization_format_by_default(sklearn_knn_model, model_path):
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)

    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE


def test_model_log_uses_cloudpickle_serialization_format_by_default(sklearn_knn_model):
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model.model, name="model")

    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE


def test_model_save_with_cloudpickle_format_adds_cloudpickle_to_conda_environment(
    sklearn_knn_model, model_path
):
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )

    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)

    pip_deps = [
        dependency
        for dependency in saved_conda_env_parsed["dependencies"]
        if type(dependency) == dict and "pip" in dependency
    ]
    assert len(pip_deps) == 1
    assert any("cloudpickle" in pip_dep for pip_dep in pip_deps[0]["pip"])


def test_model_save_without_cloudpickle_format_does_not_add_cloudpickle_to_conda_environment(
    sklearn_logreg_model, model_path
):
    non_cloudpickle_serialization_formats = list(mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS)
    non_cloudpickle_serialization_formats.remove(mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE)

    for serialization_format in non_cloudpickle_serialization_formats:
        mlflow.sklearn.save_model(
            sk_model=sklearn_logreg_model.model,
            path=model_path,
            serialization_format=serialization_format,
        )

        sklearn_conf = _get_flavor_configuration(
            model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
        )
        assert "serialization_format" in sklearn_conf
        assert sklearn_conf["serialization_format"] == serialization_format

        pyfunc_conf = _get_flavor_configuration(
            model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
        )
        saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
        assert os.path.exists(saved_conda_env_path)
        with open(saved_conda_env_path) as f:
            saved_conda_env_parsed = yaml.safe_load(f)
        assert all(
            "cloudpickle" not in dependency for dependency in saved_conda_env_parsed["dependencies"]
        )

        shutil.rmtree(model_path)


def test_load_pyfunc_succeeds_for_older_models_with_pyfunc_data_field(
    sklearn_knn_model, model_path
):
    """
    This test verifies that scikit-learn models saved in older versions of MLflow are loaded
    successfully by ``mlflow.pyfunc.load_model``. These older models specify a pyfunc ``data``
    field referring directly to a serialized scikit-learn model file. In contrast, newer models
    omit the ``data`` field.
702 """ 703 mlflow.sklearn.save_model( 704 sk_model=sklearn_knn_model.model, 705 path=model_path, 706 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE, 707 ) 708 709 model_conf_path = os.path.join(model_path, "MLmodel") 710 model_conf = Model.load(model_conf_path) 711 pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME) 712 sklearn_conf = model_conf.flavors.get(mlflow.sklearn.FLAVOR_NAME) 713 assert sklearn_conf is not None 714 assert pyfunc_conf is not None 715 pyfunc_conf[pyfunc.DATA] = sklearn_conf["pickled_model"] 716 717 reloaded_knn_pyfunc = pyfunc.load_model(model_uri=model_path) 718 719 np.testing.assert_array_equal( 720 sklearn_knn_model.model.predict(sklearn_knn_model.inference_data), 721 reloaded_knn_pyfunc.predict(sklearn_knn_model.inference_data), 722 ) 723 724 725 def test_add_pyfunc_flavor_only_when_model_defines_predict(model_path): 726 from sklearn.cluster import AgglomerativeClustering 727 728 sk_model = AgglomerativeClustering() 729 assert not hasattr(sk_model, "predict") 730 731 mlflow.sklearn.save_model( 732 sk_model=sk_model, 733 path=model_path, 734 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE, 735 ) 736 737 model_conf_path = os.path.join(model_path, "MLmodel") 738 model_conf = Model.load(model_conf_path) 739 assert pyfunc.FLAVOR_NAME not in model_conf.flavors 740 741 742 def test_pyfunc_serve_and_score(sklearn_knn_model): 743 model, inference_dataframe = sklearn_knn_model 744 artifact_path = "model" 745 with mlflow.start_run(): 746 model_info = mlflow.sklearn.log_model( 747 model, name=artifact_path, input_example=inference_dataframe 748 ) 749 750 inference_payload = load_serving_example(model_info.model_uri) 751 resp = pyfunc_serve_and_score_model( 752 model_info.model_uri, 753 data=inference_payload, 754 content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON, 755 extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS, 756 ) 757 scores = pd.DataFrame( 758 data=json.loads(resp.content.decode("utf-8"))["predictions"] 759 ).values.squeeze() 760 np.testing.assert_array_almost_equal(scores, model.predict(inference_dataframe)) 761 762 763 def test_sklearn_compatible_with_mlflow_2_4_0(sklearn_knn_model, tmp_path): 764 model, inference_dataframe = sklearn_knn_model 765 model_predict = model.predict(inference_dataframe) 766 767 # save test model 768 tmp_path.joinpath("MLmodel").write_text( 769 f""" 770 artifact_path: model 771 flavors: 772 python_function: 773 env: 774 conda: conda.yaml 775 virtualenv: python_env.yaml 776 loader_module: mlflow.sklearn 777 model_path: model.pkl 778 predict_fn: predict 779 python_version: 3.11.14 780 sklearn: 781 code: null 782 pickled_model: model.pkl 783 serialization_format: cloudpickle 784 sklearn_version: {sklearn.__version__} 785 mlflow_version: 2.4.0 786 model_uuid: c9833d74b1ff4013a1c9eff05d39eeef 787 run_id: 8146a2ae86104f5b853351e600fc9d7b 788 utc_time_created: '2023-07-04 07:19:43.561797' 789 """ 790 ) 791 tmp_path.joinpath("python_env.yaml").write_text( 792 """ 793 python: 3.11.14 794 build_dependencies: 795 - pip==25.1.1 796 - setuptools==80.4.0 797 - wheel==0.45.1 798 dependencies: 799 - -r requirements.txt 800 """ 801 ) 802 tmp_path.joinpath("requirements.txt").write_text( 803 f""" 804 mlflow==2.4.0 805 cloudpickle 806 numpy 807 psutil 808 scikit-learn=={sklearn.__version__} 809 scipy 810 """ 811 ) 812 with open(tmp_path / "model.pkl", "wb") as out: 813 pickle.dump(model, out, protocol=pickle.DEFAULT_PROTOCOL) 814 815 assert Version(mlflow.__version__) > Version("2.4.0") 816 model_uri = str(tmp_path) 817 
    pyfunc_loaded = mlflow.pyfunc.load_model(model_uri)

    # predict is compatible
    local_predict = pyfunc_loaded.predict(inference_dataframe)
    np.testing.assert_array_almost_equal(local_predict, model_predict)

    # model serving is compatible
    resp = pyfunc_serve_and_score_model(
        model_uri,
        data=pd.DataFrame(inference_dataframe),
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model_predict)

    # A warning is issued if params are specified, since params support was only added in
    # MLflow 2.5.0 and this model predates it
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        pyfunc_loaded.predict(inference_dataframe, params={"top_k": 2})
        mock_warning.assert_called_with(
            "`params` can only be specified at inference time if the model signature defines a params "
            "schema. This model does not define a params schema. Ignoring provided params: "
            "['top_k']"
        )


def test_log_model_with_code_paths(sklearn_knn_model):
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.sklearn._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.sklearn.FLAVOR_NAME)
        mlflow.sklearn.load_model(model_uri=model_info.model_uri)
        add_mock.assert_called()


@pytest.mark.parametrize(
    "predict_fn", ["predict", "predict_proba", "predict_log_proba", "predict_joint_log_proba"]
)
def test_log_model_with_custom_pyfunc_predict_fn(sklearn_gaussian_model, predict_fn):
    if Version(sklearn.__version__) < Version("1.2.0") and predict_fn == "predict_joint_log_proba":
        pytest.skip("predict_joint_log_proba is not available in scikit-learn < 1.2.0")

    model, inference_dataframe = sklearn_gaussian_model
    expected_scores = getattr(model, predict_fn)(inference_dataframe)
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            model, name=artifact_path, pyfunc_predict_fn=predict_fn
        )

    loaded_model = pyfunc.load_model(model_info.model_uri)
    actual_scores = loaded_model.predict(inference_dataframe)
    np.testing.assert_array_almost_equal(expected_scores, actual_scores)


def test_virtualenv_subfield_points_to_correct_path(sklearn_logreg_model, model_path):
    mlflow.sklearn.save_model(sklearn_logreg_model.model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


def test_model_save_load_with_metadata(sklearn_knn_model, model_path):
    mlflow.sklearn.save_model(
        sklearn_knn_model.model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata(sklearn_knn_model):
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
metadata={"metadata_key": "metadata_value"}, 904 ) 905 906 reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri) 907 assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value" 908 909 910 def test_model_log_with_signature_inference(sklearn_knn_model, iris_signature): 911 artifact_path = "model" 912 X = sklearn_knn_model.inference_data 913 example = X.iloc[[0]] 914 915 with mlflow.start_run(): 916 model_info = mlflow.sklearn.log_model( 917 sklearn_knn_model.model, name=artifact_path, input_example=example 918 ) 919 920 mlflow_model = Model.load(model_info.model_uri) 921 assert mlflow_model.signature == iris_signature 922 923 924 def test_model_size_bytes(sklearn_logreg_model, tmp_path): 925 mlflow.sklearn.save_model(sklearn_logreg_model.model, path=tmp_path) 926 927 # expected size only counts for files saved before the MLmodel file is saved 928 model_file = tmp_path.joinpath("model.pkl") 929 with model_file.open("rb") as fp: 930 expected_size = len(fp.read()) 931 932 mlmodel = yaml.safe_load(tmp_path.joinpath("MLmodel").read_bytes()) 933 assert mlmodel["model_size_bytes"] == expected_size 934 935 936 def test_model_registration_metadata_handling(sklearn_knn_model, tmp_path): 937 artifact_path = "model" 938 with mlflow.start_run(): 939 mlflow.sklearn.log_model( 940 sklearn_knn_model.model, 941 name=artifact_path, 942 registered_model_name="test", 943 ) 944 model_uri = "models:/test/1" 945 946 artifact_repository = get_artifact_repository(model_uri) 947 948 dst_full = tmp_path.joinpath("full") 949 dst_full.mkdir() 950 951 artifact_repository.download_artifacts("MLmodel", dst_full) 952 # This validates that the models artifact repo will not attempt to create a 953 # "registered model metadata" file if the source of an artifact download is a file. 954 assert os.listdir(dst_full) == ["MLmodel"] 955 956 957 def test_pipeline_predict_proba(sklearn_knn_model, model_path): 958 knn_model = sklearn_knn_model.model 959 pipeline = make_pipeline(knn_model) 960 961 mlflow.sklearn.save_model(sk_model=pipeline, path=model_path, pyfunc_predict_fn="predict_proba") 962 reloaded_knn_pyfunc = pyfunc.load_model(model_uri=model_path) 963 964 np.testing.assert_array_equal( 965 knn_model.predict_proba(sklearn_knn_model.inference_data), 966 reloaded_knn_pyfunc.predict(sklearn_knn_model.inference_data), 967 ) 968 969 970 def test_get_raw_model(sklearn_knn_model): 971 with mlflow.start_run(): 972 model_info = mlflow.sklearn.log_model( 973 sklearn_knn_model.model, name="model", input_example=sklearn_knn_model.inference_data 974 ) 975 pyfunc_model = pyfunc.load_model(model_info.model_uri) 976 raw_model = pyfunc_model.get_raw_model() 977 assert type(raw_model) == type(sklearn_knn_model.model) 978 np.testing.assert_array_equal( 979 raw_model.predict(sklearn_knn_model.inference_data), 980 sklearn_knn_model.model.predict(sklearn_knn_model.inference_data), 981 )