# tests/sklearn/test_sklearn_model_export.py
  1  import json
  2  import os
  3  import pickle
  4  import shutil
  5  import tempfile
  6  from pathlib import Path
  7  from typing import Any, NamedTuple
  8  from unittest import mock
  9  
 10  import cloudpickle
 11  import numpy as np
 12  import pandas as pd
 13  import pytest
 14  import sklearn
 15  import sklearn.linear_model as glm
 16  import sklearn.naive_bayes as nb
 17  import sklearn.neighbors as knn
 18  import skops
 19  import yaml
 20  from packaging.version import Version
 21  from sklearn import datasets
 22  from sklearn.pipeline import Pipeline as SKPipeline
 23  from sklearn.pipeline import make_pipeline
 24  from sklearn.preprocessing import FunctionTransformer as SKFunctionTransformer
 25  
 26  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 27  import mlflow.sklearn
 28  from mlflow import pyfunc
 29  from mlflow.entities.model_registry.model_version import ModelVersion, ModelVersionStatus
 30  from mlflow.exceptions import MlflowException
 31  from mlflow.models import Model, ModelSignature
 32  from mlflow.models.utils import _read_example, load_serving_example
 33  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
 34  from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
 35  from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
 36  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 37  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 38  from mlflow.types import DataType
 39  from mlflow.types.schema import ColSpec, Schema
 40  from mlflow.utils.environment import _mlflow_conda_env
 41  from mlflow.utils.file_utils import TempDir
 42  from mlflow.utils.model_utils import _get_flavor_configuration
 43  
 44  from tests.helper_functions import (
 45      _assert_pip_requirements,
 46      _compare_conda_env_requirements,
 47      _compare_logged_code_paths,
 48      _is_available_on_pypi,
 49      _mlflow_major_version_string,
 50      assert_register_model_called_with_local_model_path,
 51      pyfunc_serve_and_score_model,
 52  )
 53  from tests.store._unity_catalog.conftest import (
 54      configure_client_for_uc,  # noqa: F401
 55      mock_databricks_uc_host_creds,  # noqa: F401
 56  )
 57  
# When scikit-learn is not installable from PyPI, fall back to the local
# environment manager for pyfunc serving tests instead of building a new env.
EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("scikit-learn", module="sklearn") else ["--env-manager", "local"]
)
 61  
 62  
class ModelWithData(NamedTuple):
    """Pairs a fitted estimator with the data used for prediction assertions."""

    # Fitted scikit-learn estimator or pipeline.
    model: Any
    # Feature matrix passed to ``predict`` when comparing original vs reloaded models.
    inference_data: Any
 66  
 67  
 68  @pytest.fixture(scope="module")
 69  def iris_df():
 70      iris = datasets.load_iris()
 71      X = iris.data
 72      y = iris.target
 73      X_df = pd.DataFrame(X, columns=iris.feature_names)
 74      X_df = X_df.iloc[:, :2]  # we only take the first two features.
 75      y_series = pd.Series(y)
 76      return X_df, y_series
 77  
 78  
 79  @pytest.fixture(scope="module")
 80  def iris_signature():
 81      return ModelSignature(
 82          inputs=Schema([
 83              ColSpec(name="sepal length (cm)", type=DataType.double),
 84              ColSpec(name="sepal width (cm)", type=DataType.double),
 85          ]),
 86          outputs=Schema([ColSpec(type=DataType.long)]),
 87      )
 88  
 89  
 90  @pytest.fixture(scope="module")
 91  def sklearn_knn_model(iris_df):
 92      X, y = iris_df
 93      knn_model = knn.KNeighborsClassifier()
 94      knn_model.fit(X, y)
 95      return ModelWithData(model=knn_model, inference_data=X)
 96  
 97  
# Loading the sklearn KNN model in skops format requires explicitly marking
# these internal sklearn types as trusted via `skops_trusted_types`.
# Related ticket: https://github.com/skops-dev/skops/issues/498
sklearn_knn_model_skops_trusted_types = [
    "sklearn.metrics._dist_metrics.EuclideanDistance64",
    "sklearn.neighbors._kd_tree.KDTree",
]
105  
106  
@pytest.fixture(scope="module")
def sklearn_logreg_model(iris_df):
    """Logistic-regression classifier fitted on the iris fixture."""
    features, target = iris_df
    classifier = glm.LogisticRegression().fit(features, target)
    return ModelWithData(model=classifier, inference_data=features)
113  
114  
@pytest.fixture(scope="module")
def sklearn_gaussian_model(iris_df):
    """Gaussian naive-Bayes classifier fitted on the iris fixture."""
    features, target = iris_df
    classifier = nb.GaussianNB().fit(features, target)
    return ModelWithData(model=classifier, inference_data=features)
121  
122  
@pytest.fixture(scope="module")
def sklearn_custom_transformer_model(sklearn_knn_model, iris_df):
    """Pipeline chaining a locally-defined transformer into the fitted KNN model."""

    # Defined locally on purpose: plain pickle cannot serialize local
    # functions, which the serialization-format tests rely on.
    def transform(vec):
        return vec + 1

    features, _ = iris_df
    pipeline = SKPipeline([
        ("custom_transformer", SKFunctionTransformer(transform, validate=True)),
        ("knn", sklearn_knn_model.model),
    ])
    return ModelWithData(pipeline, inference_data=features)
132  
133  
@pytest.fixture
def model_path(tmp_path):
    """Return a per-test model destination path (as ``str``); not created on disk."""
    return str(tmp_path.joinpath("model"))
137  
138  
@pytest.fixture
def sklearn_custom_env(tmp_path):
    """Write a conda env YAML with scikit-learn and pytest pip deps; return its path."""
    env_path = os.path.join(tmp_path, "conda_env.yml")
    _mlflow_conda_env(env_path, additional_pip_deps=["scikit-learn", "pytest"])
    return env_path
144  
145  
@pytest.mark.parametrize("serialization_format", mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS)
def test_model_save_load(sklearn_logreg_model, model_path, serialization_format):
    """Round-trip save/load for every supported serialization format.

    Checks that the format is recorded in the flavor config, that the
    format-specific library is pinned in requirements.txt, and that both the
    native and pyfunc reloads predict identically to the original model.
    """
    from mlflow.utils.requirements_utils import _parse_requirements

    sk_model = sklearn_logreg_model.model
    mlflow.sklearn.save_model(
        sk_model=sk_model, path=model_path, serialization_format=serialization_format
    )
    reloaded_model = mlflow.sklearn.load_model(model_uri=model_path)
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    # The chosen serialization format must be persisted in the flavor config.
    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == serialization_format

    # Expected pinned requirement per non-pickle serialization format.
    req_map = {
        mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS: f"skops=={skops.__version__}",
        mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE: f"cloudpickle=={cloudpickle.__version__}",
    }

    logged_reqs = [
        req.req_str
        for req in _parse_requirements(
            os.path.join(model_path, "requirements.txt"), is_constraint=False
        )
    ]
    # Plain pickle is part of the standard library, so no extra requirement
    # is expected for it.
    if serialization_format != mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE:
        assert req_map[serialization_format] in logged_reqs

    np.testing.assert_array_equal(
        sk_model.predict(sklearn_logreg_model.inference_data),
        reloaded_model.predict(sklearn_logreg_model.inference_data),
    )

    np.testing.assert_array_equal(
        reloaded_model.predict(sklearn_logreg_model.inference_data),
        reloaded_pyfunc.predict(sklearn_logreg_model.inference_data),
    )
186  
187  
def test_model_skops_format_trusted_type(sklearn_knn_model, model_path):
    """Skops serialization rejects untrusted internal types unless they are
    declared via ``skops_trusted_types``; with them, the round trip succeeds."""
    sk_model = sklearn_knn_model.model

    # Without trusting the KNN-internal types, saving must fail.
    with pytest.raises(MlflowException, match="The saved sklearn model references untrusted type"):
        mlflow.sklearn.save_model(
            sk_model=sk_model,
            path=model_path,
            serialization_format="skops",
        )

    # Clear the model path (the failed save may have created it) and retry
    # with the trusted-type allowlist.
    shutil.rmtree(model_path)
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        serialization_format="skops",
        skops_trusted_types=sklearn_knn_model_skops_trusted_types,
    )
    reloaded_model = mlflow.sklearn.load_model(model_uri=model_path)
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    np.testing.assert_array_equal(
        sk_model.predict(sklearn_knn_model.inference_data),
        reloaded_model.predict(sklearn_knn_model.inference_data),
    )

    np.testing.assert_array_equal(
        reloaded_model.predict(sklearn_knn_model.inference_data),
        reloaded_pyfunc.predict(sklearn_knn_model.inference_data),
    )
216  
217  
def test_log_model_skops_no_pip_requirements_warning(sklearn_logreg_model, recwarn):
    """Logging a skops-format model must not emit the pip-requirements
    inference fallback warning."""
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_logreg_model.model,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS,
        )
    assert all("Fall back to return" not in str(w.message) for w in recwarn)
226  
227  
def test_model_save_behavior_with_preexisting_folders(sklearn_knn_model, tmp_path):
    """``save_model`` should succeed into an existing-but-empty directory and
    raise ``MlflowException`` when the target directory already has content."""
    # Saving into an existing empty directory is allowed.
    sklearn_model_path = tmp_path / "sklearn_model_empty_exists"
    sklearn_model_path.mkdir()
    # Save the fitted estimator itself, not the ModelWithData wrapper tuple —
    # consistent with every other test in this module.
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=sklearn_model_path)

    # Saving into a non-empty directory must raise.
    sklearn_model_path = tmp_path / "sklearn_model_filled_exists"
    sklearn_model_path.mkdir()
    (sklearn_model_path / "foo.txt").write_text("dummy content")
    with pytest.raises(MlflowException, match="already exists and is not empty"):
        mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=sklearn_model_path)
238  
239  
def test_signature_and_examples_are_saved_correctly(sklearn_knn_model, iris_signature):
    """Every (signature, input_example) combination is persisted as expected."""
    data = sklearn_knn_model.inference_data
    model = sklearn_knn_model.model
    example_ = data[:3]
    for signature in (None, iris_signature):
        for example in (None, example_):
            with TempDir() as tmp:
                path = tmp.path("model")
                mlflow.sklearn.save_model(
                    model, path=path, signature=signature, input_example=example
                )
                mlflow_model = Model.load(path)
                # When an input example is given, the signature matches the
                # expected iris signature even if it was not passed explicitly.
                if signature is None and example is None:
                    assert mlflow_model.signature is None
                else:
                    assert mlflow_model.signature == iris_signature
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    np.testing.assert_array_equal(_read_example(mlflow_model, path), example)
260  
261  
def test_model_load_from_remote_uri_succeeds(sklearn_knn_model, model_path, mock_s3_bucket):
    """A model uploaded to S3 can be reloaded from its remote artifact URI."""
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    repo = S3ArtifactRepository(artifact_root)
    repo.log_artifacts(model_path, artifact_path="model")

    reloaded_knn_model = mlflow.sklearn.load_model(model_uri=artifact_root + "/" + "model")
    expected = sklearn_knn_model.model.predict(sklearn_knn_model.inference_data)
    np.testing.assert_array_equal(
        expected,
        reloaded_knn_model.predict(sklearn_knn_model.inference_data),
    )
276  
277  
def test_model_log(sklearn_logreg_model, model_path):
    """Log a model both with and without an explicitly started run; verify
    native reload and the persisted pyfunc conda environment.

    NOTE(review): the ``model_path`` fixture argument is shadowed by the local
    ``model_path`` assignment below and is otherwise unused.
    """
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        for should_start_run in [False, True]:
            try:
                if should_start_run:
                    mlflow.start_run()

                artifact_path = "linear"
                conda_env = os.path.join(tmp.path(), "conda_env.yaml")
                _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])

                model_info = mlflow.sklearn.log_model(
                    sklearn_logreg_model.model,
                    name=artifact_path,
                    conda_env=conda_env,
                )

                reloaded_logsklearn_knn_model = mlflow.sklearn.load_model(
                    model_uri=model_info.model_uri
                )
                np.testing.assert_array_equal(
                    sklearn_logreg_model.model.predict(sklearn_logreg_model.inference_data),
                    reloaded_logsklearn_knn_model.predict(sklearn_logreg_model.inference_data),
                )

                # The logged model must carry a pyfunc flavor whose conda env
                # file exists inside the downloaded artifact directory.
                model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
                model_config = Model.load(os.path.join(model_path, "MLmodel"))
                assert pyfunc.FLAVOR_NAME in model_config.flavors
                assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
                env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
                assert os.path.exists(os.path.join(model_path, env_path))

            finally:
                # End the run regardless of whether this iteration started one.
                mlflow.end_run()
312  
313  
def test_log_model_calls_register_model(sklearn_logreg_model):
    """Passing ``registered_model_name`` triggers model registration, and the
    registration call receives a local model path."""
    artifact_path = "linear"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch, TempDir(chdr=True, remove_on_exit=True) as tmp:
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])
        model_info = mlflow.sklearn.log_model(
            sklearn_logreg_model.model,
            name=artifact_path,
            conda_env=conda_env,
            registered_model_name="AdsModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=mlflow.tracking._model_registry.fluent._register_model,
            model_uri=model_info.model_uri,
            registered_model_name="AdsModel1",
        )
331  
332  
def test_log_model_call_register_model_to_uc(configure_client_for_uc, sklearn_logreg_model):
    """Registering to Unity Catalog creates a model version with the expected
    positional arguments and a local model path under the temp directory."""
    artifact_path = "linear"
    mock_model_version = ModelVersion(
        name="AdsModel1",
        version=1,
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.READY),
    )
    with (
        mock.patch.object(UcModelRegistryStore, "create_registered_model"),
        mock.patch.object(
            UcModelRegistryStore,
            "create_model_version",
            return_value=mock_model_version,
            autospec=True,
        ) as mock_create_mv,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        with mlflow.start_run() as run:
            conda_env = os.path.join(tmp.path(), "conda_env.yaml")
            _mlflow_conda_env(conda_env, additional_pip_deps=["scikit-learn"])
            model_info = mlflow.sklearn.log_model(
                sklearn_logreg_model.model,
                name=artifact_path,
                conda_env=conda_env,
                registered_model_name="AdsModel1",
            )
            source = model_info.artifact_path
            # With autospec=True, args[0] is the store instance — skip it.
            [(args, kwargs)] = mock_create_mv.call_args_list
            assert args[1:] == ("AdsModel1", source, run.info.run_id, [], None, None)
            assert kwargs["local_model_path"].startswith(tempfile.gettempdir())
364  
365  
def test_log_model_no_registered_model_name(sklearn_logreg_model):
    """Without ``registered_model_name``, logging must not register the model."""
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch, TempDir(chdr=True, remove_on_exit=True) as tmp:
        conda_env_path = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env_path, additional_pip_deps=["scikit-learn"])
        mlflow.sklearn.log_model(
            sklearn_logreg_model.model,
            name=artifact_path,
            conda_env=conda_env_path,
        )
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()
378  
379  
def test_custom_transformer_can_be_saved_and_loaded_with_cloudpickle_format(
    sklearn_custom_transformer_model, tmp_path
):
    """Pipelines with locally-defined transformers require the cloudpickle
    serialization format; plain pickle fails on the local function."""
    custom_transformer_model = sklearn_custom_transformer_model.model

    # Because the model contains a custom transformer that is not defined at the top level of the
    # current test module, we expect pickle to fail when attempting to serialize it. In contrast,
    # we expect cloudpickle to successfully locate the transformer definition and serialize the
    # model successfully.
    pickle_format_model_path = os.path.join(tmp_path, "pickle_model")
    with pytest.raises(AttributeError, match="Can't pickle local object"):
        mlflow.sklearn.save_model(
            sk_model=custom_transformer_model,
            path=pickle_format_model_path,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
        )

    cloudpickle_format_model_path = os.path.join(tmp_path, "cloud_pickle_model")
    mlflow.sklearn.save_model(
        sk_model=custom_transformer_model,
        path=cloudpickle_format_model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )

    reloaded_custom_transformer_model = mlflow.sklearn.load_model(
        model_uri=cloudpickle_format_model_path
    )

    np.testing.assert_array_equal(
        custom_transformer_model.predict(sklearn_custom_transformer_model.inference_data),
        reloaded_custom_transformer_model.predict(sklearn_custom_transformer_model.inference_data),
    )
412  
413  
def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    sklearn_logreg_model, model_path, sklearn_custom_env
):
    """A user-supplied conda env file is copied into the model directory and
    its parsed content matches the original."""
    mlflow.sklearn.save_model(
        sk_model=sklearn_logreg_model.model,
        path=model_path,
        conda_env=sklearn_custom_env,
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    # The env must be a copy, not a reference to the original file.
    assert saved_conda_env_path != sklearn_custom_env

    with open(sklearn_custom_env) as f:
        sklearn_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == sklearn_custom_env_parsed
433  
434  
def test_model_save_persists_requirements_in_mlflow_model_directory(
    sklearn_knn_model, model_path, sklearn_custom_env
):
    """The saved model's requirements.txt mirrors the custom conda env."""
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model, path=model_path, conda_env=sklearn_custom_env
    )
    _compare_conda_env_requirements(
        sklearn_custom_env, os.path.join(model_path, "requirements.txt")
    )
444  
445  
def test_log_model_with_pip_requirements(sklearn_knn_model, tmp_path):
    """``pip_requirements`` accepts a file path, a list (with ``-r`` includes),
    and a constraints file (``-c``); each is logged exactly (strict mode)."""
    expected_mlflow_version = _mlflow_major_version_string()
    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file: "a" ends up in constraints.txt, not requirements.txt.
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )
477  
478  
def test_log_model_with_extra_pip_requirements(sklearn_knn_model, tmp_path):
    """``extra_pip_requirements`` is appended to the default requirements and
    supports a file path, lists with ``-r``, and constraints via ``-c``."""
    expected_mlflow_version = _mlflow_major_version_string()
    default_reqs = mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file: "a" ends up in constraints.txt, not requirements.txt.
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )
513  
514  
def test_model_save_accepts_conda_env_as_dict(sklearn_knn_model, model_path):
    """A conda environment passed as a dict is persisted verbatim with the model."""
    conda_env = dict(mlflow.sklearn.get_default_conda_env())
    conda_env["dependencies"].append("pytest")
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model, path=model_path, conda_env=conda_env
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(env_file)

    with open(env_file) as f:
        persisted_env = yaml.safe_load(f)
    assert persisted_env == conda_env
529  
530  
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
    sklearn_knn_model, sklearn_custom_env
):
    """A user-supplied conda env file is copied into the logged model's
    artifacts and its parsed content matches the original."""
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            conda_env=sklearn_custom_env,
        )
        model_uri = model_info.model_uri

    model_path = _download_artifact_from_uri(artifact_uri=model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    # The env must be a copy, not a reference to the original file.
    assert saved_conda_env_path != sklearn_custom_env

    with open(sklearn_custom_env) as f:
        sklearn_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == sklearn_custom_env_parsed
554  
555  
def test_model_log_persists_requirements_in_mlflow_model_directory(
    sklearn_knn_model, sklearn_custom_env
):
    """The logged model's requirements.txt mirrors the custom conda env."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            conda_env=sklearn_custom_env,
        )

    local_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    _compare_conda_env_requirements(
        sklearn_custom_env, os.path.join(local_path, "requirements.txt")
    )
569  
570  
def test_model_save_throws_exception_if_serialization_format_is_unrecognized(
    sklearn_knn_model, model_path
):
    """An unknown serialization format raises INVALID_PARAMETER_VALUE before
    any on-disk state is created."""
    with pytest.raises(MlflowException, match="Unrecognized serialization format") as exc:
        mlflow.sklearn.save_model(
            sk_model=sklearn_knn_model.model,
            path=model_path,
            serialization_format="not a valid format",
        )
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    # The unsupported serialization format should have been detected prior to the execution of
    # any directory creation or state-mutating persistence logic that would prevent a second
    # serialization call with the same model path from succeeding
    assert not os.path.exists(model_path)
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)
587  
588  
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model, model_path
):
    """Saving with no env falls back to the default pip requirements."""
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)
    expected_reqs = mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)
    _assert_pip_requirements(model_path, expected_reqs)
596  
597  
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model,
):
    """Logging with no env falls back to the default pip requirements."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model.model, name="model")

    expected_reqs = mlflow.sklearn.get_default_pip_requirements(include_cloudpickle=True)
    _assert_pip_requirements(model_info.model_uri, expected_reqs)
607  
608  
def test_model_save_uses_cloudpickle_serialization_format_by_default(sklearn_knn_model, model_path):
    """Omitting ``serialization_format`` defaults to cloudpickle."""
    mlflow.sklearn.save_model(sk_model=sklearn_knn_model.model, path=model_path)

    flavor_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    # ``get`` returns None when the key is missing, failing the comparison.
    assert (
        flavor_conf.get("serialization_format")
        == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE
    )
617  
618  
def test_model_log_uses_cloudpickle_serialization_format_by_default(sklearn_knn_model):
    """Omitting ``serialization_format`` when logging defaults to cloudpickle."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(sklearn_knn_model.model, name="model")

    local_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    flavor_conf = _get_flavor_configuration(
        model_path=local_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    # ``get`` returns None when the key is missing, failing the comparison.
    assert (
        flavor_conf.get("serialization_format")
        == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE
    )
629  
630  
def test_model_save_with_cloudpickle_format_adds_cloudpickle_to_conda_environment(
    sklearn_knn_model, model_path
):
    """Saving with the cloudpickle format records that format in the flavor
    config and adds a cloudpickle pip dependency to the persisted conda env."""
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
    )

    sklearn_conf = _get_flavor_configuration(
        model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
    )
    assert "serialization_format" in sklearn_conf
    assert sklearn_conf["serialization_format"] == mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)

    # Use isinstance rather than `type(...) == dict`: idiomatic, and tolerant
    # of dict subclasses a YAML loader may produce.
    pip_deps = [
        dependency
        for dependency in saved_conda_env_parsed["dependencies"]
        if isinstance(dependency, dict) and "pip" in dependency
    ]
    assert len(pip_deps) == 1
    assert any("cloudpickle" in pip_dep for pip_dep in pip_deps[0]["pip"])
659  
660  
def test_model_save_without_cloudpickle_format_does_not_add_cloudpickle_to_conda_environment(
    sklearn_logreg_model, model_path
):
    """Non-cloudpickle formats must not pull cloudpickle into the conda env."""
    non_cloudpickle_serialization_formats = list(mlflow.sklearn.SUPPORTED_SERIALIZATION_FORMATS)
    non_cloudpickle_serialization_formats.remove(mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE)

    for serialization_format in non_cloudpickle_serialization_formats:
        mlflow.sklearn.save_model(
            sk_model=sklearn_logreg_model.model,
            path=model_path,
            serialization_format=serialization_format,
        )

        sklearn_conf = _get_flavor_configuration(
            model_path=model_path, flavor_name=mlflow.sklearn.FLAVOR_NAME
        )
        assert "serialization_format" in sklearn_conf
        assert sklearn_conf["serialization_format"] == serialization_format

        pyfunc_conf = _get_flavor_configuration(
            model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME
        )
        saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
        assert os.path.exists(saved_conda_env_path)
        with open(saved_conda_env_path) as f:
            saved_conda_env_parsed = yaml.safe_load(f)
        assert all(
            "cloudpickle" not in dependency for dependency in saved_conda_env_parsed["dependencies"]
        )

        # Clear the model directory so the next format can reuse the same path.
        shutil.rmtree(model_path)
692  
693  
def test_load_pyfunc_succeeds_for_older_models_with_pyfunc_data_field(
    sklearn_knn_model, model_path
):
    """
    This test verifies that scikit-learn models saved in older versions of MLflow are loaded
    successfully by ``mlflow.pyfunc.load_model``. These older models specify a pyfunc ``data``
    field referring directly to a serialized scikit-learn model file. In contrast, newer models
    omit the ``data`` field.
    """
    mlflow.sklearn.save_model(
        sk_model=sklearn_knn_model.model,
        path=model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
    )

    model_conf_path = os.path.join(model_path, "MLmodel")
    model_conf = Model.load(model_conf_path)
    pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME)
    sklearn_conf = model_conf.flavors.get(mlflow.sklearn.FLAVOR_NAME)
    assert sklearn_conf is not None
    assert pyfunc_conf is not None
    # Emulate the legacy layout by pointing the pyfunc ``data`` field at the
    # pickled model, then write the edited config back to disk. Without the
    # save, ``pyfunc.load_model`` re-reads the unmodified MLmodel file and the
    # legacy code path is never exercised.
    pyfunc_conf[pyfunc.DATA] = sklearn_conf["pickled_model"]
    model_conf.save(model_conf_path)

    reloaded_knn_pyfunc = pyfunc.load_model(model_uri=model_path)

    np.testing.assert_array_equal(
        sklearn_knn_model.model.predict(sklearn_knn_model.inference_data),
        reloaded_knn_pyfunc.predict(sklearn_knn_model.inference_data),
    )
723  
724  
def test_add_pyfunc_flavor_only_when_model_defines_predict(model_path):
    """Models without a ``predict`` method must not receive a pyfunc flavor."""
    from sklearn.cluster import AgglomerativeClustering

    clusterer = AgglomerativeClustering()
    assert not hasattr(clusterer, "predict")

    mlflow.sklearn.save_model(
        sk_model=clusterer,
        path=model_path,
        serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
    )

    model_conf = Model.load(os.path.join(model_path, "MLmodel"))
    assert pyfunc.FLAVOR_NAME not in model_conf.flavors
740  
741  
def test_pyfunc_serve_and_score(sklearn_knn_model):
    """Serve the logged model and check served predictions match local inference."""
    knn_model, inference_dataframe = sklearn_knn_model
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            knn_model, name="model", input_example=inference_dataframe
        )

    payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    served_predictions = json.loads(response.content.decode("utf-8"))["predictions"]
    served_scores = pd.DataFrame(data=served_predictions).values.squeeze()
    np.testing.assert_array_almost_equal(served_scores, knn_model.predict(inference_dataframe))
761  
762  
def test_sklearn_compatible_with_mlflow_2_4_0(sklearn_knn_model, tmp_path):
    """Verify that a model artifact laid out in the MLflow 2.4.0 on-disk format still
    loads, predicts, and serves correctly under the current MLflow version, and that
    pre-2.5.0 models (no params schema) warn when ``params`` are passed at inference.
    """
    model, inference_dataframe = sklearn_knn_model
    model_predict = model.predict(inference_dataframe)

    # save test model
    # Hand-write the MLmodel / python_env.yaml / requirements.txt files exactly as
    # MLflow 2.4.0 would have produced them, rather than calling save_model, so the
    # legacy layout is pinned regardless of the current save-format.
    tmp_path.joinpath("MLmodel").write_text(
        f"""
artifact_path: model
flavors:
  python_function:
    env:
      conda: conda.yaml
      virtualenv: python_env.yaml
    loader_module: mlflow.sklearn
    model_path: model.pkl
    predict_fn: predict
    python_version: 3.11.14
  sklearn:
    code: null
    pickled_model: model.pkl
    serialization_format: cloudpickle
    sklearn_version: {sklearn.__version__}
mlflow_version: 2.4.0
model_uuid: c9833d74b1ff4013a1c9eff05d39eeef
run_id: 8146a2ae86104f5b853351e600fc9d7b
utc_time_created: '2023-07-04 07:19:43.561797'
"""
    )
    tmp_path.joinpath("python_env.yaml").write_text(
        """
python: 3.11.14
build_dependencies:
   - pip==25.1.1
   - setuptools==80.4.0
   - wheel==0.45.1
dependencies:
   - -r requirements.txt
"""
    )
    tmp_path.joinpath("requirements.txt").write_text(
        f"""
mlflow==2.4.0
cloudpickle
numpy
psutil
scikit-learn=={sklearn.__version__}
scipy
"""
    )
    # Serialize the estimator itself; the MLmodel above points at this file.
    with open(tmp_path / "model.pkl", "wb") as out:
        pickle.dump(model, out, protocol=pickle.DEFAULT_PROTOCOL)

    # Sanity check: this compatibility test is only meaningful on MLflow > 2.4.0.
    assert Version(mlflow.__version__) > Version("2.4.0")
    model_uri = str(tmp_path)
    pyfunc_loaded = mlflow.pyfunc.load_model(model_uri)

    # predict is compatible
    local_predict = pyfunc_loaded.predict(inference_dataframe)
    np.testing.assert_array_almost_equal(local_predict, model_predict)

    # model serving is compatible
    resp = pyfunc_serve_and_score_model(
        model_uri,
        data=pd.DataFrame(inference_dataframe),
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model_predict)

    # Issues a warning if params are specified prior to MLflow support in 2.5.0
    with mock.patch("mlflow.models.utils._logger.warning") as mock_warning:
        pyfunc_loaded.predict(inference_dataframe, params={"top_k": 2})
    mock_warning.assert_called_with(
        "`params` can only be specified at inference time if the model signature defines a params "
        "schema. This model does not define a params schema. Ignoring provided params: "
        "['top_k']"
    )
843  
844  
def test_log_model_with_code_paths(sklearn_knn_model):
    """Logging with ``code_paths`` stores the code and re-adds it to sys.path on load."""
    with (
        mlflow.start_run(),
        mock.patch("mlflow.sklearn._add_code_from_conf_to_system_path") as add_code_mock,
    ):
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.sklearn.FLAVOR_NAME)
        mlflow.sklearn.load_model(model_uri=model_info.model_uri)
        add_code_mock.assert_called()
857  
858  
@pytest.mark.parametrize(
    "predict_fn", ["predict", "predict_proba", "predict_log_proba", "predict_joint_log_proba"]
)
def test_log_model_with_custom_pyfunc_predict_fn(sklearn_gaussian_model, predict_fn):
    """A custom ``pyfunc_predict_fn`` is honored by the loaded pyfunc model."""
    # predict_joint_log_proba only exists in scikit-learn >= 1.2.0.
    if predict_fn == "predict_joint_log_proba" and Version(sklearn.__version__) < Version("1.2.0"):
        pytest.skip("predict_joint_log_proba is not available in scikit-learn < 1.2.0")

    gaussian_model, inference_dataframe = sklearn_gaussian_model
    expected_scores = getattr(gaussian_model, predict_fn)(inference_dataframe)
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            gaussian_model, name="model", pyfunc_predict_fn=predict_fn
        )

    pyfunc_model = pyfunc.load_model(model_info.model_uri)
    np.testing.assert_array_almost_equal(
        expected_scores, pyfunc_model.predict(inference_dataframe)
    )
877  
878  
def test_virtualenv_subfield_points_to_correct_path(sklearn_logreg_model, model_path):
    """The pyfunc env config's ``virtualenv`` entry must name an existing file."""
    mlflow.sklearn.save_model(sklearn_logreg_model.model, path=model_path)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = Path(model_path) / flavor_conf[pyfunc.ENV]["virtualenv"]
    assert env_file.exists()
    assert env_file.is_file()
885  
886  
def test_model_save_load_with_metadata(sklearn_knn_model, model_path):
    """Metadata passed to ``save_model`` round-trips through pyfunc load."""
    metadata = {"metadata_key": "metadata_value"}
    mlflow.sklearn.save_model(sklearn_knn_model.model, path=model_path, metadata=metadata)

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert loaded_model.metadata.metadata["metadata_key"] == "metadata_value"
894  
895  
def test_model_log_with_metadata(sklearn_knn_model):
    """Metadata passed to ``log_model`` round-trips through pyfunc load."""
    metadata = {"metadata_key": "metadata_value"}
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            metadata=metadata,
        )

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert loaded_model.metadata.metadata["metadata_key"] == "metadata_value"
908  
909  
def test_model_log_with_signature_inference(sklearn_knn_model, iris_signature):
    """The model signature is inferred from ``input_example`` at log time."""
    input_example = sklearn_knn_model.inference_data.iloc[[0]]

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", input_example=input_example
        )

    logged_model = Model.load(model_info.model_uri)
    assert logged_model.signature == iris_signature
922  
923  
def test_model_size_bytes(sklearn_logreg_model, tmp_path):
    """``model_size_bytes`` in MLmodel equals the size of the serialized model file."""
    mlflow.sklearn.save_model(sklearn_logreg_model.model, path=tmp_path)

    # Only files written before the MLmodel file itself count toward the size.
    expected_size = tmp_path.joinpath("model.pkl").stat().st_size

    mlmodel = yaml.safe_load(tmp_path.joinpath("MLmodel").read_bytes())
    assert mlmodel["model_size_bytes"] == expected_size
934  
935  
def test_model_registration_metadata_handling(sklearn_knn_model, tmp_path):
    """Downloading a single file from a registered model must not create extra files."""
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name="test",
        )
        model_uri = "models:/test/1"

    repo = get_artifact_repository(model_uri)

    download_dir = tmp_path.joinpath("full")
    download_dir.mkdir()

    repo.download_artifacts("MLmodel", download_dir)
    # This validates that the models artifact repo will not attempt to create a
    # "registered model metadata" file if the source of an artifact download is a file.
    assert os.listdir(download_dir) == ["MLmodel"]
955  
956  
def test_pipeline_predict_proba(sklearn_knn_model, model_path):
    """A pipeline saved with ``pyfunc_predict_fn='predict_proba'`` serves probabilities."""
    underlying_model = sklearn_knn_model.model
    pipeline = make_pipeline(underlying_model)

    mlflow.sklearn.save_model(sk_model=pipeline, path=model_path, pyfunc_predict_fn="predict_proba")
    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    expected = underlying_model.predict_proba(sklearn_knn_model.inference_data)
    np.testing.assert_array_equal(
        expected,
        loaded_pyfunc.predict(sklearn_knn_model.inference_data),
    )
968  
969  
def test_get_raw_model(sklearn_knn_model):
    """``get_raw_model`` returns the underlying estimator, which predicts identically."""
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            sklearn_knn_model.model, name="model", input_example=sklearn_knn_model.inference_data
        )
    pyfunc_model = pyfunc.load_model(model_info.model_uri)
    raw_model = pyfunc_model.get_raw_model()
    # Compare types with ``is`` (identity), not ``==`` (pycodestyle E721): type
    # objects are singletons, and the intent here is an exact-class check.
    assert type(raw_model) is type(sklearn_knn_model.model)
    np.testing.assert_array_equal(
        raw_model.predict(sklearn_knn_model.inference_data),
        sklearn_knn_model.model.predict(sklearn_knn_model.inference_data),
    )