# mlflow/sklearn/__init__.py
   1  """
   2  The ``mlflow.sklearn`` module provides an API for logging and loading scikit-learn models. This
   3  module exports scikit-learn models with the following flavors:
   4  
   5  Python (native) `pickle <https://scikit-learn.org/stable/modules/model_persistence.html>`_ format
   6      This is the main flavor that can be loaded back into scikit-learn.
   7  
   8  :py:mod:`mlflow.pyfunc`
   9      Produced for use by generic pyfunc-based deployment tools and batch inference.
  10      NOTE: The `mlflow.pyfunc` flavor is only added for scikit-learn models that define `predict()`,
  11      since `predict()` is required for pyfunc model inference.
  12  """
  13  
  14  import functools
  15  import inspect
  16  import logging
  17  import os
  18  import pickle
  19  import shutil
  20  import weakref
  21  from collections import OrderedDict, defaultdict
  22  from copy import deepcopy
  23  from typing import Any
  24  
  25  import numpy as np
  26  import yaml
  27  from packaging.version import Version
  28  
  29  import mlflow
  30  from mlflow import pyfunc
  31  from mlflow.data.code_dataset_source import CodeDatasetSource
  32  from mlflow.data.numpy_dataset import from_numpy
  33  from mlflow.data.pandas_dataset import from_pandas
  34  from mlflow.entities.dataset_input import DatasetInput
  35  from mlflow.entities.input_tag import InputTag
  36  from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION
  37  from mlflow.exceptions import MlflowException
  38  from mlflow.models import Model, ModelInputExample, ModelSignature
  39  from mlflow.models.model import MLMODEL_FILE_NAME
  40  from mlflow.models.signature import _infer_signature_from_input_example
  41  from mlflow.models.utils import _save_example
  42  from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
  43  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
  44  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
  45  from mlflow.tracking.client import MlflowClient
  46  from mlflow.utils import _inspect_original_var_name, gorilla
  47  from mlflow.utils.autologging_utils import (
  48      INPUT_EXAMPLE_SAMPLE_ROWS,
  49      MlflowAutologgingQueueingClient,
  50      _get_new_training_session_class,
  51      autologging_integration,
  52      disable_autologging,
  53      get_autologging_config,
  54      get_instance_method_first_arg_value,
  55      resolve_input_example_and_signature,
  56      safe_patch,
  57      update_wrapper_extended,
  58  )
  59  from mlflow.utils.data_utils import is_polars_dataframe
  60  from mlflow.utils.databricks_utils import (
  61      is_in_databricks_model_serving_environment,
  62      is_in_databricks_runtime,
  63  )
  64  from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
  65  from mlflow.utils.environment import (
  66      _CONDA_ENV_FILE_NAME,
  67      _CONSTRAINTS_FILE_NAME,
  68      _PYTHON_ENV_FILE_NAME,
  69      _REQUIREMENTS_FILE_NAME,
  70      _mlflow_conda_env,
  71      _process_conda_env,
  72      _process_pip_requirements,
  73      _PythonEnv,
  74      _validate_env_arguments,
  75  )
  76  from mlflow.utils.file_utils import get_total_file_size, write_to
  77  from mlflow.utils.mlflow_tags import (
  78      MLFLOW_AUTOLOGGING,
  79      MLFLOW_DATASET_CONTEXT,
  80  )
  81  from mlflow.utils.model_utils import (
  82      _add_code_from_conf_to_system_path,
  83      _copy_extra_files,
  84      _get_flavor_configuration,
  85      _validate_and_copy_code_paths,
  86      _validate_and_prepare_target_save_path,
  87  )
  88  from mlflow.utils.requirements_utils import _get_pinned_requirement
  89  
# Name under which this flavor is registered in the MLmodel file.
FLAVOR_NAME = "sklearn"

# Supported model serialization formats. "skops" is the safe, type-checked
# format; "pickle" and "cloudpickle" rely on Python object serialization,
# which can execute arbitrary code during deserialization.
SERIALIZATION_FORMAT_SKOPS = "skops"
SERIALIZATION_FORMAT_PICKLE = "pickle"
SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle"

SUPPORTED_SERIALIZATION_FORMATS = [
    SERIALIZATION_FORMAT_SKOPS,
    SERIALIZATION_FORMAT_PICKLE,
    SERIALIZATION_FORMAT_CLOUDPICKLE,
]

_logger = logging.getLogger(__name__)
_SklearnTrainingSession = _get_new_training_session_class()

# File names of the serialized model data inside the saved model directory.
_PICKLE_MODEL_DATA_SUBPATH = "model.pkl"
_SKOPS_MODEL_DATA_SUBPATH = "model.skops"
 107  
 108  
 109  def _gen_estimators_to_patch():
 110      from mlflow.sklearn.utils import (
 111          _all_estimators,
 112          _get_meta_estimators_for_autologging,
 113      )
 114  
 115      _, estimators_to_patch = zip(*_all_estimators())
 116      # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
 117      # for patching if they are not already included in the output of `all_estimators()`
 118      estimators_to_patch = set(estimators_to_patch).union(
 119          set(_get_meta_estimators_for_autologging())
 120      )
 121      # Exclude certain preprocessing & feature manipulation estimators from patching. These
 122      # estimators represent data manipulation routines (e.g., normalization, label encoding)
 123      # rather than ML algorithms. Accordingly, we should not create MLflow runs and log
 124      # parameters / metrics for these routines, unless they are captured as part of an ML pipeline
 125      # (via `sklearn.pipeline.Pipeline`)
 126      excluded_module_names = [
 127          "sklearn.preprocessing",
 128          "sklearn.impute",
 129          "sklearn.feature_extraction",
 130          "sklearn.feature_selection",
 131      ]
 132  
 133      excluded_class_names = [
 134          "sklearn.compose._column_transformer.ColumnTransformer",
 135      ]
 136  
 137      return [
 138          estimator
 139          for estimator in estimators_to_patch
 140          if not any(
 141              estimator.__module__.startswith(excluded_module_name)
 142              or (estimator.__module__ + "." + estimator.__name__) in excluded_class_names
 143              for excluded_module_name in excluded_module_names
 144          )
 145      ]
 146  
 147  
 148  def get_default_pip_requirements(include_cloudpickle=False, include_skops=False):
 149      """
 150      Returns:
 151          A list of default pip requirements for MLflow Models produced by this flavor.
 152          Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
 153          that, at minimum, contains these requirements.
 154      """
 155      pip_deps = [_get_pinned_requirement("scikit-learn", module="sklearn")]
 156      if include_cloudpickle:
 157          pip_deps += [_get_pinned_requirement("cloudpickle")]
 158      if include_skops:
 159          pip_deps += [_get_pinned_requirement("skops")]
 160  
 161      return pip_deps
 162  
 163  
 164  def get_default_conda_env(include_cloudpickle=False, include_skops=False):
 165      """
 166      Returns:
 167          The default Conda environment for MLflow Models produced by calls to
 168          :func:`save_model()` and :func:`log_model()`.
 169      """
 170      return _mlflow_conda_env(
 171          additional_pip_deps=get_default_pip_requirements(include_cloudpickle, include_skops)
 172      )
 173  
 174  
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def save_model(
    sk_model,
    path,
    conda_env=None,
    code_paths=None,
    mlflow_model=None,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    pip_requirements=None,
    extra_pip_requirements=None,
    pyfunc_predict_fn="predict",
    metadata=None,
    skops_trusted_types=None,
    extra_files=None,
):
    """
    Save a scikit-learn model to a path on the local file system. Produces a MLflow Model
    containing the following flavors:

        - :py:mod:`mlflow.sklearn`
        - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
          that define `predict()`, since `predict()` is required for pyfunc model inference.

    Args:
        sk_model: scikit-learn model to be saved.
        path: Local path where the model is to be saved.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
        serialization_format: The format in which to serialize the model. This should be one of
            the formats "skops", "cloudpickle" or "pickle".
            The "skops" format guarantees safe deserialization.
            The "cloudpickle" format, provides better cross-system compatibility by identifying and
            packaging code dependencies with the serialized model, but requires exercising
            caution because these formats rely on Python's object serialization mechanism,
            which can execute arbitrary code during deserialization.

        signature: {{ signature }}
        input_example: {{ input_example }}
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        pyfunc_predict_fn: The name of the prediction function to use for inference with the
            pyfunc representation of the resulting MLflow Model. Current supported functions
            are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``,
            ``"predict_joint_log_proba"``, and ``"score"``.
        metadata: {{ metadata }}
        skops_trusted_types: A list of trusted types when loading model that is saved as
            the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format.
        extra_files: {{ extra_files }}

    .. code-block:: python
        :caption: Example

        import mlflow.sklearn
        from sklearn.datasets import load_iris
        from sklearn import tree

        iris = load_iris()
        sk_model = tree.DecisionTreeClassifier()
        sk_model = sk_model.fit(iris.data, iris.target)

        # Save the model in cloudpickle format
        # set path to location for persistence
        sk_path_dir_1 = ...
        mlflow.sklearn.save_model(
            sk_model,
            sk_path_dir_1,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE,
        )

        # save the model in pickle format
        # set path to location for persistence
        sk_path_dir_2 = ...
        mlflow.sklearn.save_model(
            sk_model,
            sk_path_dir_2,
            serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE,
        )
    """
    import sklearn

    # Validate the combination of environment arguments before doing any work.
    _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)

    # Fail fast on serialization formats this flavor does not support.
    if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
        raise MlflowException(
            message=(
                f"Unrecognized serialization format: {serialization_format}. Please specify one"
                f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Pickle-based formats can execute arbitrary code during deserialization,
    # so warn users outside of Databricks runtimes that "skops" is safer.
    if serialization_format != SERIALIZATION_FORMAT_SKOPS and not is_in_databricks_runtime():
        _logger.warning(
            "Saving scikit-learn models in the pickle or cloudpickle format requires exercising "
            "caution because these formats rely on Python's object serialization mechanism, "
            "which can execute arbitrary code during deserialization. "
            "The recommended safe alternative is the 'skops' format. "
            "For more information, see: https://scikit-learn.org/stable/model_persistence.html",
        )

    _validate_and_prepare_target_save_path(path)
    code_path_subdir = _validate_and_copy_code_paths(code_paths, path)

    if mlflow_model is None:
        mlflow_model = Model()
    saved_example = _save_example(mlflow_model, input_example, path)

    # Infer a signature from the saved input example when none was supplied;
    # passing `signature=False` explicitly disables signature logging.
    if signature is None and saved_example is not None:
        wrapped_model = _SklearnModelWrapper(sk_model)
        signature = _infer_signature_from_input_example(saved_example, wrapped_model)
    elif signature is False:
        signature = None

    if signature is not None:
        mlflow_model.signature = signature
    if metadata is not None:
        mlflow_model.metadata = metadata

    # skops uses its own data file name; pickle and cloudpickle share "model.pkl".
    if serialization_format == SERIALIZATION_FORMAT_SKOPS:
        model_data_subpath = _SKOPS_MODEL_DATA_SUBPATH
    else:
        model_data_subpath = _PICKLE_MODEL_DATA_SUBPATH
    model_data_path = os.path.join(path, model_data_subpath)
    _save_model(
        sk_model=sk_model,
        output_path=model_data_path,
        serialization_format=serialization_format,
        skops_trusted_types=skops_trusted_types,
    )

    extra_files_config = _copy_extra_files(extra_files, path)

    # `PyFuncModel` only works for sklearn models that define a predict function

    if hasattr(sk_model, pyfunc_predict_fn):
        pyfunc.add_to_model(
            mlflow_model,
            loader_module="mlflow.sklearn",
            model_path=model_data_subpath,
            conda_env=_CONDA_ENV_FILE_NAME,
            python_env=_PYTHON_ENV_FILE_NAME,
            code=code_path_subdir,
            predict_fn=pyfunc_predict_fn,
        )
    else:
        _logger.warning(
            f"Model was missing function: {pyfunc_predict_fn}. Not logging python_function flavor!"
        )

    mlflow_model.add_flavor(
        FLAVOR_NAME,
        pickled_model=model_data_subpath,
        sklearn_version=sklearn.__version__,
        serialization_format=serialization_format,
        code=code_path_subdir,
        skops_trusted_types=skops_trusted_types,
        **extra_files_config,
    )
    # Record the total on-disk size of the saved model directory when computable.
    if size := get_total_file_size(path):
        mlflow_model.model_size_bytes = size
    mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))

    # Build the environment specification (conda.yaml / requirements.txt /
    # constraints.txt / python_env.yaml) from the user-provided arguments, or
    # infer the pip requirements when none were given.
    if conda_env is None:
        if pip_requirements is None:
            include_cloudpickle = serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE
            include_skops = serialization_format == SERIALIZATION_FORMAT_SKOPS
            default_reqs = get_default_pip_requirements(
                include_cloudpickle=include_cloudpickle,
                include_skops=include_skops,
            )
            # To ensure `_load_pyfunc` can successfully load the model during the dependency
            # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file.
            inferred_reqs = mlflow.models.infer_pip_requirements(
                model_data_path,
                FLAVOR_NAME,
                fallback=default_reqs,
            )
            default_reqs = sorted(set(inferred_reqs).union(default_reqs))
        else:
            default_reqs = None
        conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
            default_reqs,
            pip_requirements,
            extra_pip_requirements,
        )
    else:
        conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)

    with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
        yaml.safe_dump(conda_env, stream=f, default_flow_style=False)

    # Save `constraints.txt` if necessary
    if pip_constraints:
        write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))

    # Save `requirements.txt`
    write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))

    _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
 378  
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn"))
def log_model(
    sk_model,
    artifact_path: str | None = None,
    conda_env=None,
    code_paths=None,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    registered_model_name=None,
    signature: ModelSignature = None,
    input_example: ModelInputExample = None,
    await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
    pip_requirements=None,
    extra_pip_requirements=None,
    pyfunc_predict_fn="predict",
    metadata=None,
    extra_files=None,
    params: dict[str, Any] | None = None,
    tags: dict[str, Any] | None = None,
    model_type: str | None = None,
    step: int = 0,
    model_id: str | None = None,
    name: str | None = None,
    skops_trusted_types: list[str] | None = None,
    **kwargs,
):
    """
    Log a scikit-learn model as an MLflow artifact for the current run. Produces an MLflow Model
    containing the following flavors:

        - :py:mod:`mlflow.sklearn`
        - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models
          that define `predict()`, since `predict()` is required for pyfunc model inference.

    Args:
        sk_model: scikit-learn model to be saved.
        artifact_path: Deprecated. Use `name` instead.
        conda_env: {{ conda_env }}
        code_paths: {{ code_paths }}
        serialization_format: The format in which to serialize the model. This should be one of
            the formats "skops", "cloudpickle" or "pickle".
            The "skops" format guarantees safe deserialization.
            The "cloudpickle" format, provides better cross-system compatibility by identifying and
            packaging code dependencies with the serialized model, but requires exercising
            caution because these formats rely on Python's object serialization mechanism,
            which can execute arbitrary code during deserialization.
        registered_model_name: If given, create a model version under
            ``registered_model_name``, also creating a registered model if one
            with the given name does not exist.
        signature: {{ signature }}
        input_example: {{ input_example }}
        await_registration_for: Number of seconds to wait for the model version to finish
            being created and is in ``READY`` status. By default, the function
            waits for five minutes. Specify 0 or None to skip waiting.
        pip_requirements: {{ pip_requirements }}
        extra_pip_requirements: {{ extra_pip_requirements }}
        pyfunc_predict_fn: The name of the prediction function to use for inference with the
            pyfunc representation of the resulting MLflow Model. Current supported functions
            are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``,
            ``"predict_joint_log_proba"``, and ``"score"``.
        metadata: {{ metadata }}
        extra_files: {{ extra_files }}
        params: {{ params }}
        tags: {{ tags }}
        model_type: {{ model_type }}
        step: {{ step }}
        model_id: {{ model_id }}
        name: {{ name }}
        skops_trusted_types: A list of trusted types when loading model that is saved as
            the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format.
        kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`.

    Returns:
        A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
        metadata of the logged model.

    .. code-block:: python
        :caption: Example

        import mlflow
        import mlflow.sklearn
        from mlflow.models import infer_signature
        from sklearn.datasets import load_iris
        from sklearn import tree

        with mlflow.start_run():
            # load dataset and train model
            iris = load_iris()
            sk_model = tree.DecisionTreeClassifier()
            sk_model = sk_model.fit(iris.data, iris.target)

            # log model params
            mlflow.log_param("criterion", sk_model.criterion)
            mlflow.log_param("splitter", sk_model.splitter)
            signature = infer_signature(iris.data, sk_model.predict(iris.data))

            # log model
            mlflow.sklearn.log_model(sk_model, name="sk_models", signature=signature)

    """
    # Delegate to `Model.log`, which handles run/artifact management and
    # optional model registration; serialization is performed by this module
    # via `flavor=mlflow.sklearn`.
    return Model.log(
        artifact_path=artifact_path,
        name=name,
        flavor=mlflow.sklearn,
        sk_model=sk_model,
        conda_env=conda_env,
        code_paths=code_paths,
        serialization_format=serialization_format,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        await_registration_for=await_registration_for,
        pip_requirements=pip_requirements,
        extra_pip_requirements=extra_pip_requirements,
        pyfunc_predict_fn=pyfunc_predict_fn,
        metadata=metadata,
        extra_files=extra_files,
        params=params,
        tags=tags,
        model_type=model_type,
        step=step,
        model_id=model_id,
        skops_trusted_types=skops_trusted_types,
        **kwargs,
    )
 504  
 505  def _load_model_from_local_file(path, serialization_format, skops_trusted_types=None):
 506      """Load a scikit-learn model saved as an MLflow artifact on the local file system.
 507  
 508      Args:
 509          path: Local filesystem path to the MLflow Model saved with the ``sklearn`` flavor
 510          serialization_format: The format in which the model was serialized. This should be one of
 511              the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
 512              ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
 513      """
 514      # TODO: we could validate the scikit-learn version here
 515      if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS:
 516          raise MlflowException(
 517              message=(
 518                  f"Unrecognized serialization format: {serialization_format}. Please specify one"
 519                  f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}."
 520              ),
 521              error_code=INVALID_PARAMETER_VALUE,
 522          )
 523  
 524      if serialization_format != SERIALIZATION_FORMAT_SKOPS:
 525          if (
 526              not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get()
 527              and not is_in_databricks_runtime()
 528              and not is_in_databricks_model_serving_environment()
 529          ):
 530              raise MlflowException(
 531                  "Deserializing model using pickle is disallowed, but this model is saved "
 532                  "in pickle format. To address this issue, you need to set environment variable "
 533                  "'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true', or save the model in "
 534                  "'skops' format."
 535              )
 536  
 537      if serialization_format == SERIALIZATION_FORMAT_SKOPS:
 538          import skops.io
 539  
 540          return skops.io.load(path, trusted=skops_trusted_types)
 541      else:
 542          with open(path, "rb") as f:
 543              # Models serialized with Cloudpickle cannot necessarily be deserialized using Pickle;
 544              # That's why we check the serialization format of the model before deserializing
 545              if serialization_format == SERIALIZATION_FORMAT_PICKLE:
 546                  return pickle.load(f)
 547              elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
 548                  import cloudpickle
 549  
 550                  return cloudpickle.load(f)
 551  
 552  
 553  def _load_pyfunc(path):
 554      """
 555      Load PyFunc implementation. Called by ``pyfunc.load_model``.
 556  
 557      Args:
 558          path: Local filesystem path to the MLflow Model with the ``sklearn`` flavor.
 559      """
 560      # When ``path`` is a file, it refers directly to a serialized scikit-learn model
 561      # object (e.g., model.pkl or model.skops). The MLmodel file in the parent directory
 562      # contains the serialization format and other flavor configuration.
 563      model_dir = os.path.dirname(path) if os.path.isfile(path) else path
 564  
 565      try:
 566          sklearn_flavor_conf = _get_flavor_configuration(
 567              model_path=model_dir, flavor_name=FLAVOR_NAME
 568          )
 569          serialization_format = sklearn_flavor_conf.get(
 570              "serialization_format", SERIALIZATION_FORMAT_PICKLE
 571          )
 572          skops_trusted_types = sklearn_flavor_conf.get("skops_trusted_types", None)
 573      except MlflowException:
 574          _logger.warning(
 575              "Could not find scikit-learn flavor configuration during model loading process."
 576              " Assuming 'pickle' serialization format."
 577          )
 578          serialization_format = SERIALIZATION_FORMAT_PICKLE
 579          skops_trusted_types = None
 580  
 581      if not os.path.isfile(path):
 582          pyfunc_flavor_conf = _get_flavor_configuration(
 583              model_path=path, flavor_name=pyfunc.FLAVOR_NAME
 584          )
 585          path = os.path.join(path, pyfunc_flavor_conf["model_path"])
 586  
 587      return _SklearnModelWrapper(
 588          _load_model_from_local_file(
 589              path=path,
 590              serialization_format=serialization_format,
 591              skops_trusted_types=skops_trusted_types,
 592          )
 593      )
 594  
 595  
 596  class _SklearnModelWrapper:
 597      _SUPPORTED_CUSTOM_PREDICT_FN = [
 598          "predict_proba",
 599          "predict_log_proba",
 600          "predict_joint_log_proba",
 601          "score",
 602      ]
 603  
 604      def __init__(self, sklearn_model):
 605          self.sklearn_model = sklearn_model
 606  
 607          # Patch the model with custom predict functions that can be specified
 608          # via `pyfunc_predict_fn` argument when saving or logging.
 609          for predict_fn in self._SUPPORTED_CUSTOM_PREDICT_FN:
 610              if fn := getattr(self.sklearn_model, predict_fn, None):
 611                  setattr(self, predict_fn, fn)
 612  
 613      def get_raw_model(self):
 614          """
 615          Returns the underlying scikit-learn model.
 616          """
 617          return self.sklearn_model
 618  
 619      def predict(
 620          self,
 621          data,
 622          params: dict[str, Any] | None = None,
 623      ):
 624          """
 625          Args:
 626              data: Model input data.
 627              params: Additional parameters to pass to the model for inference.
 628  
 629          Returns:
 630              Model predictions.
 631          """
 632          return self.sklearn_model.predict(data)
 633  
 634  
 635  class _SklearnCustomModelPicklingError(pickle.PicklingError):
 636      """
 637      Exception for describing error raised during pickling custom sklearn estimator
 638      """
 639  
 640      def __init__(self, sk_model, original_exception):
 641          """
 642          Args:
 643              sk_model: The custom sklearn model to be pickled
 644              original_exception: The original exception raised
 645          """
 646          super().__init__(
 647              f"Pickling custom sklearn model {sk_model.__class__.__name__} failed "
 648              f"when saving model: {original_exception}"
 649          )
 650          self.original_exception = original_exception
 651  
 652  
 653  def _dump_model(pickle_lib, sk_model, out):
 654      try:
 655          # Using python's default protocol to optimize compatibility.
 656          # Otherwise cloudpickle uses latest protocol leading to incompatibilities.
 657          # See https://github.com/mlflow/mlflow/issues/5419
 658          pickle_lib.dump(sk_model, out, protocol=pickle.DEFAULT_PROTOCOL)
 659      except (pickle.PicklingError, TypeError, AttributeError) as e:
 660          if sk_model.__class__ not in _gen_estimators_to_patch():
 661              raise _SklearnCustomModelPicklingError(sk_model, e)
 662          else:
 663              raise
 664  
 665  
 666  def _save_model(sk_model, output_path, serialization_format, skops_trusted_types):
 667      """
 668      Args:
 669          sk_model: The scikit-learn model to serialize.
 670          output_path: The file path to which to write the serialized model.
 671          serialization_format: The format in which to serialize the model. This should be one of
 672              the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or
 673              ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``.
 674          skops_trusted_types: A list of trusted types when loading model that is saved as
 675              the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format.
 676      """
 677      if serialization_format == SERIALIZATION_FORMAT_SKOPS:
 678          import skops.io
 679          from skops.io.exceptions import UntrustedTypesFoundException
 680  
 681          try:
 682              skops.io.dump(sk_model, output_path)
 683              skops.io.load(output_path, trusted=skops_trusted_types)
 684          except UntrustedTypesFoundException as e:
 685              shutil.rmtree(output_path, ignore_errors=True)
 686              raise MlflowException(
 687                  "The saved sklearn model references untrusted types. "
 688                  "If you are sure loading these types is safe, "
 689                  "set the 'skops_trusted_types' parameter when calling 'log_model' or 'save_model' "
 690                  "to the list of trusted types. "
 691                  f"Root error: {e!s}"
 692              )
 693          except Exception as e:
 694              shutil.rmtree(output_path, ignore_errors=True)
 695              raise MlflowException(
 696                  "The sklearn model could not be serialized in the skops serialization format. "
 697                  "skops does not support custom functions or classes that are not defined at the "
 698                  "top level. To work around this limitation, you can set the serialization_format "
 699                  "'cloudpickle', while exercising caution due to the possible arbitrary "
 700                  "code during model deserialization using CloudPickle."
 701              ) from e
 702          return
 703  
 704      with open(output_path, "wb") as out:
 705          if serialization_format == SERIALIZATION_FORMAT_PICKLE:
 706              _dump_model(pickle, sk_model, out)
 707          elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE:
 708              import cloudpickle
 709  
 710              _dump_model(cloudpickle, sk_model, out)
 711          else:
 712              raise MlflowException(
 713                  message=f"Unrecognized serialization format: {serialization_format}",
 714                  error_code=INTERNAL_ERROR,
 715              )
 716  
 717  
 718  def load_model(model_uri, dst_path=None):
 719      """
 720      Load a scikit-learn model from a local file or a run.
 721  
 722      Args:
 723          model_uri: The location, in URI format, of the MLflow model, for example:
 724  
 725              - ``/Users/me/path/to/local/model``
 726              - ``relative/path/to/local/model``
 727              - ``s3://my_bucket/path/to/model``
 728              - ``runs:/<mlflow_run_id>/run-relative/path/to/model``
 729              - ``models:/<model_name>/<model_version>``
 730              - ``models:/<model_name>/<stage>``
 731  
 732              For more information about supported URI schemes, see
 733              `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html#
 734              artifact-locations>`_.
 735          dst_path: The local filesystem path to which to download the model artifact.
 736              This directory must already exist. If unspecified, a local output
 737              path will be created.
 738  
 739      Returns:
 740          A scikit-learn model.
 741  
 742      .. code-block:: python
 743          :caption: Example
 744  
 745          import mlflow.sklearn
 746  
 747          sk_model = mlflow.sklearn.load_model("runs:/96771d893a5e46159d9f3b49bf9013e2/sk_models")
 748  
 749          # use Pandas DataFrame to make predictions
 750          pandas_df = ...
 751          predictions = sk_model.predict(pandas_df)
 752      """
 753      local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
 754      flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
 755      _add_code_from_conf_to_system_path(local_model_path, flavor_conf)
 756      sklearn_model_artifacts_path = os.path.join(local_model_path, flavor_conf["pickled_model"])
 757      serialization_format = flavor_conf.get("serialization_format", SERIALIZATION_FORMAT_PICKLE)
 758      skops_trusted_types = flavor_conf.get("skops_trusted_types", None)
 759      return _load_model_from_local_file(
 760          path=sklearn_model_artifacts_path,
 761          serialization_format=serialization_format,
 762          skops_trusted_types=skops_trusted_types,
 763      )
 764  
 765  
# `_apis_autologging_disabled` lists sklearn APIs that are incompatible with autologging;
# while a user call into any of these APIs is in progress, autologging is temporarily disabled.
_apis_autologging_disabled = [
    "cross_validate",
    "cross_val_predict",
    "cross_val_score",
    "learning_curve",
    "permutation_test_score",
    "validation_curve",
]
 776  
 777  
class _AutologgingMetricsManager:
    """
    Holds the bookkeeping state used by post-training metric autologging.

    It tracks:

    (1) ``_pred_result_id_mapping``: a map from ``id(prediction_result)`` to a tuple of
        ``(eval_dataset_name, run_id, model_id)`` — the dataset that produced the prediction
        result and the run/model it belongs to.
        Note: we key on object ids instead of attaching the run_id to the prediction result
        itself, because the result may be a numpy array which does not support attribute
        assignment.
    (2) ``_log_post_training_metrics_enabled``: a flag toggled around ``model.fit`` and
        ``model.score`` scopes to avoid nested/duplicated metric autologging; while inside
        those scopes, metric autologging is temporarily disabled.
    (3) ``_eval_dataset_info_map``: a two-level map where
        ``_eval_dataset_info_map[run_id][eval_dataset_var_name]`` is a list of ids of
        registered "eval_dataset" instances. It is used to generate a unique dataset-name key
        when autologging metrics: eval datasets sharing the same variable name but with
        different object ids get distinct names (by appending an index to the variable name).
    (4) ``_metric_api_call_info``: a two-level map where
        ``_metric_api_call_info[run_id][metric_name]`` is a list of
        ``(logged_metric_key, metric_call_command_string)`` tuples; each call command string
        looks like ``metric_fn(arg1, arg2, ...)``. It stores the call arguments for each
        metric call so they can be logged into the ``metric_info.json`` artifact.

    Note: this class is not thread-safe.
    Design rule: because a single instance of this class is global, to prevent memory leaks it
    must only hold ids and other small objects — never references to user dataset variables or
    model objects.
    """

    def __init__(self):
        # See the class docstring for the role of each of these structures.
        self._pred_result_id_mapping = {}
        self._eval_dataset_info_map = defaultdict(lambda: defaultdict(list))
        self._metric_api_call_info = defaultdict(lambda: defaultdict(list))
        self._log_post_training_metrics_enabled = True
        # run_id -> whether the "metric_info.json" artifact for that run is stale.
        self._metric_info_artifact_need_update = defaultdict(lambda: False)
        # id(model) -> model_id, so metrics can later be logged against the model.
        self._model_id_mapping = {}

    def should_log_post_training_metrics(self):
        """
        Check whether we should run patching code for autologging post training metrics.
        This checking should surround the whole patched code due to the safe guard checking,
        See following note.

        Note: It includes checking `_SklearnTrainingSession.is_active()`, This is a safe guarding
        for meta-estimator (e.g. GridSearchCV) case:
          running GridSearchCV.fit, the nested `estimator.fit` will be called in parallel,
          but, the _autolog_training_status is a global status without thread-safe lock protecting.
          This safe guarding will prevent code run into this case.
        """
        return not _SklearnTrainingSession.is_active() and self._log_post_training_metrics_enabled

    def disable_log_post_training_metrics(self):
        """
        Return a context manager that temporarily disables post-training metric autologging
        within its scope, restoring the previous enabled/disabled state on exit.
        """

        class LogPostTrainingMetricsDisabledScope:
            # `inner_self` is the scope object; `self` (closed over) is the manager.
            def __enter__(inner_self):
                inner_self.old_status = self._log_post_training_metrics_enabled
                self._log_post_training_metrics_enabled = False

            def __exit__(inner_self, exc_type, exc_val, exc_tb):
                # Restore the saved state rather than unconditionally enabling, so
                # nested scopes compose correctly.
                self._log_post_training_metrics_enabled = inner_self.old_status

        return LogPostTrainingMetricsDisabledScope()

    @staticmethod
    def get_run_id_for_model(model):
        # Returns the run_id attached by `register_model`, or None if the model was
        # never registered.
        return getattr(model, "_mlflow_run_id", None)

    @staticmethod
    def is_metric_value_loggable(metric_value):
        """
        Check whether the specified `metric_value` is a numeric value which can be logged
        as an MLflow metric.
        """
        # `bool` is a subclass of `int`, so it must be excluded explicitly.
        return isinstance(metric_value, (int, float, np.number)) and not isinstance(
            metric_value, bool
        )

    def register_model(self, model, run_id):
        """
        In `patched_fit`, we need register the model with the run_id used in `patched_fit`
        So that in following metric autologging, the metric will be logged into the registered
        run_id
        """
        model._mlflow_run_id = run_id

    def record_model_id(self, model, model_id):
        """
        Record the id(model) -> model_id mapping so that we can log metrics to the
        model later.
        """
        self._model_id_mapping[id(model)] = model_id

    def get_model_id_for_model(self, model) -> str | None:
        """Return the model_id previously recorded for `model`, or None if unknown."""
        return self._model_id_mapping.get(id(model))

    @staticmethod
    def gen_name_with_index(name, index):
        """
        Return `name` unchanged for index 0, otherwise `name` suffixed with the 1-based
        human-readable index (so the second occurrence is named "{name}-2").
        """
        assert index >= 0
        if index == 0:
            return name
        else:
            # Use '-' as the separator between name and index,
            # The '-' is not valid character in python var name
            # so it can prevent name conflicts after appending index.
            return f"{name}-{index + 1}"

    def register_prediction_input_dataset(self, model, eval_dataset):
        """
        Register prediction input dataset into eval_dataset_info_map, it will do:
         1. inspect eval dataset var name.
         2. check whether eval_dataset_info_map already registered this eval dataset.
            will check by object id.
         3. register eval dataset with id.
         4. return eval dataset name with index.

        Note: this method include inspecting argument variable name.
         So should be called directly from the "patched method", to ensure it capture
         correct argument variable name.
        """
        eval_dataset_name = _inspect_original_var_name(
            eval_dataset, fallback_name="unknown_dataset"
        )
        eval_dataset_id = id(eval_dataset)

        run_id = self.get_run_id_for_model(model)
        registered_dataset_list = self._eval_dataset_info_map[run_id][eval_dataset_name]

        # Find the position of this dataset instance among datasets already registered
        # under the same variable name; `for..else` sets the append position when absent.
        for i, id_i in enumerate(registered_dataset_list):
            if eval_dataset_id == id_i:
                index = i
                break
        else:
            index = len(registered_dataset_list)

        if index == len(registered_dataset_list):
            # register new eval dataset
            registered_dataset_list.append(eval_dataset_id)

        return self.gen_name_with_index(eval_dataset_name, index)

    def register_prediction_result(self, run_id, eval_dataset_name, predict_result, model_id=None):
        """
        Register the relationship
         id(prediction_result) --> (eval_dataset_name, run_id, model_id)
        into map `_pred_result_id_mapping`
        """
        value = (eval_dataset_name, run_id, model_id)
        prediction_result_id = id(predict_result)
        self._pred_result_id_mapping[prediction_result_id] = value

        def clean_id(id_):
            # Reference the global manager (not `self`) so the finalizer does not keep
            # this instance alive through the closure.
            _AUTOLOGGING_METRICS_MANAGER._pred_result_id_mapping.pop(id_, None)

        # When the `predict_result` object being GCed, its ID may be reused, so register a finalizer
        # to clear the ID from the dict for preventing wrong ID mapping.
        weakref.finalize(predict_result, clean_id, prediction_result_id)

    @staticmethod
    def gen_metric_call_command(self_obj, metric_fn, *call_pos_args, **call_kwargs):
        """
        Generate metric function call command string like `metric_fn(arg1, arg2, ...)`
        Note: this method include inspecting argument variable name.
         So should be called directly from the "patched method", to ensure it capture
         correct argument variable name.

        Args:
            self_obj: If the metric_fn is a method of an instance (e.g. `model.score`),
                the `self_obj` represent the instance.
            metric_fn: metric function.
            call_pos_args: the positional arguments of the metric function call. If `metric_fn`
                is instance method, then the `call_pos_args` should exclude the first `self`
                argument.
            call_kwargs: the keyword arguments of the metric function call.
        """

        arg_list = []

        def arg_to_str(arg):
            # Scalars (and None) are rendered via repr; everything else (datasets, arrays)
            # is rendered by inspecting the caller's variable name.
            if arg is None or np.isscalar(arg):
                if isinstance(arg, str) and len(arg) > 32:
                    # truncate too long string
                    return repr(arg[:32] + "...")
                return repr(arg)
            else:
                # dataset arguments or other non-scalar type argument
                return _inspect_original_var_name(arg, fallback_name=f"<{arg.__class__.__name__}>")

        param_sig = inspect.signature(metric_fn).parameters
        arg_names = list(param_sig.keys())

        if self_obj is not None:
            # If metric_fn is a method of an instance, e.g. `model.score`,
            # then the first argument is `self` which we need exclude it.
            arg_names.pop(0)

        if self_obj is not None:
            call_fn_name = f"{self_obj.__class__.__name__}.{metric_fn.__name__}"
        else:
            call_fn_name = metric_fn.__name__

        # Attach param signature key for positional param values
        for arg_name, arg in zip(arg_names, call_pos_args):
            arg_list.append(f"{arg_name}={arg_to_str(arg)}")

        for arg_name, arg in call_kwargs.items():
            arg_list.append(f"{arg_name}={arg_to_str(arg)}")

        arg_list_str = ", ".join(arg_list)

        return f"{call_fn_name}({arg_list_str})"

    def register_metric_api_call(self, run_id, metric_name, dataset_name, call_command):
        """
        This method will do:
        (1) Generate and return metric key, format is:
          {metric_name}[-{call_index}]_{eval_dataset_name}
          metric_name is generated by metric function name, if multiple calls on the same
          metric API happen, the following calls will be assigned with an increasing "call index".
        (2) Register the metric key with the "call command" information into
          `_AUTOLOGGING_METRICS_MANAGER`. See doc of `gen_metric_call_command` method for
          details of "call command".
        """

        call_cmd_list = self._metric_api_call_info[run_id][metric_name]

        # The call index is simply the number of earlier calls to the same metric API.
        index = len(call_cmd_list)
        metric_name_with_index = self.gen_name_with_index(metric_name, index)
        metric_key = f"{metric_name_with_index}_{dataset_name}"

        call_cmd_list.append((metric_key, call_command))

        # Set the flag to true, represent the metric info in this run need update.
        # Later when `log_eval_metric` called, it will generate a new metric_info artifact
        # and overwrite the old artifact.
        self._metric_info_artifact_need_update[run_id] = True
        return metric_key

    def get_info_for_metric_api_call(self, call_pos_args, call_kwargs):
        """
        Given the arguments of a metric API call, look for an argument that is a registered
        prediction result and return its associated ``(run_id, eval_dataset_name, model_id)``
        tuple; return ``(None, None, None)`` if no argument matches.
        """
        call_arg_list = list(call_pos_args) + list(call_kwargs.values())

        dataset_id_list = self._pred_result_id_mapping.keys()

        # Note: some metric API the arguments is not like `y_true`, `y_pred`
        #  e.g.
        #    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
        #    https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html#sklearn.metrics.silhouette_score
        for arg in call_arg_list:
            if arg is not None and not np.isscalar(arg) and id(arg) in dataset_id_list:
                dataset_name, run_id, model_id = self._pred_result_id_mapping[id(arg)]
                break
        else:
            return None, None, None

        return run_id, dataset_name, model_id

    def log_post_training_metric(self, run_id, key, value, model_id=None):
        """
        Log the metric into the specified mlflow run.
        and it will also update the metric_info artifact if needed.
        If model_id is not None, metrics are logged into the model as well.
        """
        # Note: if the case log the same metric key multiple times,
        #  newer value will overwrite old value
        client = MlflowClient()
        client.log_metric(run_id=run_id, key=key, value=value, model_id=model_id)
        if self._metric_info_artifact_need_update[run_id]:
            # Rebuild the full metric_info mapping for this run and overwrite the
            # previously logged artifact.
            call_commands_list = []
            for v in self._metric_api_call_info[run_id].values():
                call_commands_list.extend(v)

            call_commands_list.sort(key=lambda x: x[0])
            dict_to_log = OrderedDict(call_commands_list)
            client.log_dict(run_id=run_id, dictionary=dict_to_log, artifact_file="metric_info.json")
            self._metric_info_artifact_need_update[run_id] = False
1062  
1063  
# The global `_AutologgingMetricsManager` instance which holds information used in
# post-training metric autologging. See doc of class `_AutologgingMetricsManager` for details.
_AUTOLOGGING_METRICS_MANAGER = _AutologgingMetricsManager()


# Helper APIs in `sklearn.metrics` that do not compute a metric value themselves and are
# therefore excluded from post-training metric autologging.
_metric_api_excluding_list = ["check_scoring", "get_scorer", "make_scorer", "get_scorer_names"]
1070  
1071  
def _get_metric_name_list():
    """
    Return the list of metric function names exposed by the ``sklearn.metrics`` module,
    excluding classes (e.g. ``ConfusionMatrixDisplay``), ``plot_*`` helpers, and the
    non-metric helper APIs listed in ``_metric_api_excluding_list``.
    """
    from sklearn import metrics

    def _is_metric_fn(name):
        # Keep only callable, non-class, non-plotting, non-excluded attributes.
        attr = getattr(metrics, name)
        return (
            name not in _metric_api_excluding_list
            and not name.startswith("plot_")
            and not inspect.isclass(attr)
            and callable(attr)
        )

    return [name for name in metrics.__all__ if _is_metric_fn(name)]
1091  
1092  
def _patch_estimator_method_if_available(
    flavor_name, class_def, func_name, patched_fn, manage_run, extra_tags=None
):
    """
    Patch `func_name` on `class_def` with `patched_fn` via `safe_patch`, handling both
    plain/property methods and sklearn's delegated methods; estimators that lack the
    method, or expose it through an unsupported attribute type, are left untouched.
    """
    if not hasattr(class_def, func_name):
        return

    # Resolve the attribute twice: once through the descriptor protocol and once raw,
    # so delegated-method wrapper objects can be distinguished from plain methods.
    resolved_attr = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=False
    )
    raw_attr = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=True
    )

    if raw_attr == resolved_attr and (callable(resolved_attr) or isinstance(resolved_attr, property)):
        # Normal method or property-decorated method: patch it directly on the class.
        patch_target, patch_attr = class_def, func_name
    elif hasattr(raw_attr, "delegate_names") or hasattr(raw_attr, "check"):
        # sklearn delegated method: patch the wrapped function on the delegator object.
        patch_target, patch_attr = raw_attr, "fn"
    else:
        # Unsupported attribute type; skip patching.
        return

    safe_patch(
        flavor_name,
        patch_target,
        patch_attr,
        patched_fn,
        manage_run=manage_run,
        extra_tags=extra_tags,
    )
1129  
1130  
1131  @autologging_integration(FLAVOR_NAME)
1132  def autolog(
1133      log_input_examples=False,
1134      log_model_signatures=True,
1135      log_models=True,
1136      log_datasets=True,
1137      disable=False,
1138      exclusive=False,
1139      disable_for_unsupported_versions=False,
1140      silent=False,
1141      max_tuning_runs=5,
1142      log_post_training_metrics=True,
1143      serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
1144      registered_model_name=None,
1145      pos_label=None,
1146      extra_tags=None,
1147  ):
1148      """
1149      Enables (or disables) and configures autologging for scikit-learn estimators.
1150  
1151      **When is autologging performed?**
1152        Autologging is performed when you call:
1153  
1154        - ``estimator.fit()``
1155        - ``estimator.fit_predict()``
1156        - ``estimator.fit_transform()``
1157  
1158      **Logged information**
1159        **Parameters**
1160          - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
1161            is called with ``deep=True``. This means when you fit a meta estimator that chains
1162            a series of estimators, the parameters of these child estimators are also logged.
1163  
1164        **Training metrics**
1165          - A training score obtained by ``estimator.score``. Note that the training score is
1166            computed using parameters given to ``fit()``.
1167          - Common metrics for classifier:
1168  
1169            - `precision score`_
1170  
1171            .. _precision score:
1172                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html
1173  
1174            - `recall score`_
1175  
1176            .. _recall score:
1177                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html
1178  
1179            - `f1 score`_
1180  
1181            .. _f1 score:
1182                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
1183  
1184            - `accuracy score`_
1185  
1186            .. _accuracy score:
1187                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
1188  
1189            If the classifier has method ``predict_proba``, we additionally log:
1190  
1191            - `log loss`_
1192  
1193            .. _log loss:
1194                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html
1195  
1196            - `roc auc score`_
1197  
1198            .. _roc auc score:
1199                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
1200  
1201          - Common metrics for regressor:
1202  
1203            - `mean squared error`_
1204  
1205            .. _mean squared error:
1206                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html
1207  
1208            - root mean squared error
1209  
1210            - `mean absolute error`_
1211  
1212            .. _mean absolute error:
1213                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html
1214  
1215            - `r2 score`_
1216  
1217            .. _r2 score:
1218                https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html
1219  
1220        .. _post training metrics:
1221  
1222        **Post training metrics**
1223          When users call metric APIs after model training, MLflow tries to capture the metric API
1224          results and log them as MLflow metrics to the Run associated with the model. The following
1225          types of scikit-learn metric APIs are supported:
1226  
1227          - model.score
1228          - metric APIs defined in the `sklearn.metrics` module
1229  
1230          For post training metrics autologging, the metric key format is:
1231          "{metric_name}[-{call_index}]_{dataset_name}"
1232  
1233          - If the metric function is from `sklearn.metrics`, the MLflow "metric_name" is the
1234            metric function name. If the metric function is `model.score`, then "metric_name" is
1235            "{model_class_name}_score".
1236          - If multiple calls are made to the same scikit-learn metric API, each subsequent call
1237            adds a "call_index" (starting from 2) to the metric key.
1238          - MLflow uses the prediction input dataset variable name as the "dataset_name" in the
1239            metric key. The "prediction input dataset variable" refers to the variable which was
1240            used as the first argument of the associated `model.predict` or `model.score` call.
1241            Note: MLflow captures the "prediction input dataset" instance in the outermost call
1242            frame and fetches the variable name in the outermost call frame. If the "prediction
1243            input dataset" instance is an intermediate expression without a defined variable
1244            name, the dataset name is set to "unknown_dataset". If multiple "prediction input
1245            dataset" instances have the same variable name, then subsequent ones will append an
1246            index (starting from 2) to the inspected dataset name.
1247  
1248          **Limitations**
1249             - MLflow can only map the original prediction result object returned by a model
1250               prediction API (including predict / predict_proba / predict_log_proba / transform,
1251               but excluding fit_predict / fit_transform.) to an MLflow run.
1252               MLflow cannot find run information
1253               for other objects derived from a given prediction result (e.g. by copying or selecting
1254               a subset of the prediction result). scikit-learn metric APIs invoked on derived objects
1255               do not log metrics to MLflow.
1256             - Autologging must be enabled before scikit-learn metric APIs are imported from
1257               `sklearn.metrics`. Metric APIs imported before autologging is enabled do not log
1258               metrics to MLflow runs.
1259             - If user define a scorer which is not based on metric APIs in `sklearn.metrics`, then
1260               then post training metric autologging for the scorer is invalid.
1261  
1262          **Tags**
1263            - An estimator class name (e.g. "LinearRegression").
1264            - A fully qualified estimator class name
1265              (e.g. "sklearn.linear_model._base.LinearRegression").
1266  
1267          **Artifacts**
1268            - An MLflow Model with the :py:mod:`mlflow.sklearn` flavor containing a fitted estimator
1269              (logged by :py:func:`mlflow.sklearn.log_model()`). The Model also contains the
1270              :py:mod:`mlflow.pyfunc` flavor when the scikit-learn estimator defines `predict()`.
1271            - For post training metrics API calls, a "metric_info.json" artifact is logged. This is a
1272              JSON object whose keys are MLflow post training metric names
1273              (see "Post training metrics" section for the key format) and whose values are the
1274              corresponding metric call commands that produced the metrics, e.g.
1275              ``accuracy_score(y_true=test_iris_y, y_pred=pred_iris_y, normalize=False)``.
1276  
1277      **How does autologging work for meta estimators?**
1278        When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
1279        ``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
1280        ``fit()`` calls.
1281  
1282        **Parameter search**
1283            In addition to recording the information discussed above, autologging for parameter
1284            search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
1285            with metrics for each set of explored parameters, as well as artifacts and parameters
1286            for the best model (if available).
1287  
1288      **Supported estimators**
1289        - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
1290        - `Pipeline`_
1291        - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)
1292  
1293      .. _sklearn.utils.all_estimators:
1294          https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html
1295  
1296      .. _Pipeline:
1297          https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
1298  
1299      .. _GridSearchCV:
1300          https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
1301  
1302      .. _RandomizedSearchCV:
1303          https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html
1304  
1305      **Example**
1306  
1307      `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_
1308  
1309      .. code-block:: python
1310  
1311          from pprint import pprint
1312          import numpy as np
1313          from sklearn.linear_model import LinearRegression
1314          import mlflow
1315          from mlflow import MlflowClient
1316  
1317  
1318          def fetch_logged_data(run_id):
1319              client = MlflowClient()
1320              data = client.get_run(run_id).data
1321              tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
1322              artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
1323              return data.params, data.metrics, tags, artifacts
1324  
1325  
1326          # enable autologging
1327          mlflow.sklearn.autolog()
1328  
1329          # prepare training data
1330          X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
1331          y = np.dot(X, np.array([1, 2])) + 3
1332  
1333          # train a model
1334          model = LinearRegression()
1335          with mlflow.start_run() as run:
1336              model.fit(X, y)
1337  
1338          # fetch logged data
1339          params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)
1340  
1341          pprint(params)
1342          # {'copy_X': 'True',
1343          #  'fit_intercept': 'True',
1344          #  'n_jobs': 'None',
1345          #  'normalize': 'False'}
1346  
1347          pprint(metrics)
1348          # {'training_score': 1.0,
1349          #  'training_mean_absolute_error': 2.220446049250313e-16,
1350          #  'training_mean_squared_error': 1.9721522630525295e-31,
1351          #  'training_r2_score': 1.0,
1352          #  'training_root_mean_squared_error': 4.440892098500626e-16}
1353  
1354          pprint(tags)
1355          # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
1356          #  'estimator_name': 'LinearRegression'}
1357  
1358          pprint(artifacts)
1359          # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']
1360  
1361      Args:
1362          log_input_examples: If ``True``, input examples from training datasets are collected and
1363              logged along with scikit-learn model artifacts during training. If
1364              ``False``, input examples are not logged.
1365              Note: Input examples are MLflow model attributes
1366              and are only collected if ``log_models`` is also ``True``.
1367          log_model_signatures: If ``True``,
1368              :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
1369              describing model inputs and outputs are collected and logged along
1370              with scikit-learn model artifacts during training. If ``False``,
1371              signatures are not logged.
1372              Note: Model signatures are MLflow model attributes
1373              and are only collected if ``log_models`` is also ``True``.
1374          log_models: If ``True``, trained models are logged as MLflow model artifacts.
1375              If ``False``, trained models are not logged.
1376              Input examples and model signatures, which are attributes of MLflow models,
1377              are also omitted when ``log_models`` is ``False``.
1378          log_datasets: If ``True``, train and validation dataset information is logged to MLflow
1379              Tracking if applicable. If ``False``, dataset information is not logged.
1380          disable: If ``True``, disables the scikit-learn autologging integration. If ``False``,
1381              enables the scikit-learn autologging integration.
1382          exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
1383              If ``False``, autologged content is logged to the active fluent run,
1384              which may be user-created.
1385          disable_for_unsupported_versions: If ``True``, disable autologging for versions of
1386              scikit-learn that have not been tested against this version of the MLflow
1387              client or are incompatible.
1388          silent: If ``True``, suppress all event logs and warnings from MLflow during scikit-learn
1389              autologging. If ``False``, show all events and warnings during scikit-learn
1390              autologging.
1391          max_tuning_runs: The maximum number of child MLflow runs created for hyperparameter
1392              search estimators. To create child runs for the best `k` results from
1393              the search, set `max_tuning_runs` to `k`. The default value is to track
1394              the best 5 search parameter sets. If `max_tuning_runs=None`, then
1395              a child run is created for each search parameter set. Note: The best k
1396              results is based on ordering in `rank_test_score`. In the case of
1397              multi-metric evaluation with a custom scorer, the first scorer's
1398              `rank_test_score_<scorer_name>` will be used to select the best k
1399              results. To change metric used for selecting best k results, change
1400              ordering of dict passed as `scoring` parameter for estimator.
1401          log_post_training_metrics: If ``True``, post training metrics are logged. Defaults to
1402              ``True``. See the `post training metrics`_ section for more
1403              details.
        serialization_format: The format in which to serialize the model. This should be one of
            the following: "pickle" or "cloudpickle".
1406          registered_model_name: If given, each time a model is trained, it is registered as a
1407              new model version of the registered model with this name.
1408              The registered model is created if it does not already exist.
        pos_label: If given, used as the positive label to compute binary classification
            training metrics such as precision, recall, f1, etc. This parameter should
            only be set for binary classification models. If used for multi-label models,
            the training metrics calculation will fail and the training metrics won't
            be logged. If used for regression models, the parameter will be ignored.
1414          extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
1415      """
1416      _autolog(
1417          flavor_name=FLAVOR_NAME,
1418          log_input_examples=log_input_examples,
1419          log_model_signatures=log_model_signatures,
1420          log_models=log_models,
1421          log_datasets=log_datasets,
1422          disable=disable,
1423          exclusive=exclusive,
1424          disable_for_unsupported_versions=disable_for_unsupported_versions,
1425          silent=silent,
1426          max_tuning_runs=max_tuning_runs,
1427          log_post_training_metrics=log_post_training_metrics,
1428          serialization_format=serialization_format,
1429          pos_label=pos_label,
1430          extra_tags=extra_tags,
1431      )
1432  
1433  
1434  def _autolog(
1435      flavor_name=FLAVOR_NAME,
1436      log_input_examples=False,
1437      log_model_signatures=True,
1438      log_models=True,
1439      log_datasets=True,
1440      disable=False,
1441      exclusive=False,
1442      disable_for_unsupported_versions=False,
1443      silent=False,
1444      max_tuning_runs=5,
1445      log_post_training_metrics=True,
1446      serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
1447      pos_label=None,
1448      extra_tags=None,
1449  ):
1450      """
1451      Internal autologging function for scikit-learn models.
1452  
1453      Args:
1454          flavor_name: A string value. Enable a ``mlflow.sklearn`` autologging routine
1455              for a flavor. By default it enables autologging for original
1456              scikit-learn models, as ``mlflow.sklearn.autolog()`` does. If
1457              the argument is `xgboost`, autologging for XGBoost scikit-learn
1458              models is enabled.
1459      """
1460      import pandas as pd
1461      import sklearn.metrics
1462      import sklearn.model_selection
1463  
1464      from mlflow.models import infer_signature
1465      from mlflow.sklearn.utils import (
1466          _TRAINING_PREFIX,
1467          _create_child_runs_for_parameter_search,
1468          _gen_lightgbm_sklearn_estimators_to_patch,
1469          _gen_xgboost_sklearn_estimators_to_patch,
1470          _get_estimator_info_tags,
1471          _get_X_y_and_sample_weight,
1472          _is_parameter_search_estimator,
1473          _log_estimator_content,
1474          _log_parameter_search_results_as_artifact,
1475      )
1476      from mlflow.tracking.context import registry as context_registry
1477  
1478      if max_tuning_runs is not None and max_tuning_runs < 0:
1479          raise MlflowException(
1480              message=f"`max_tuning_runs` must be non-negative, instead got {max_tuning_runs}.",
1481              error_code=INVALID_PARAMETER_VALUE,
1482          )
1483  
1484      def fit_mlflow_xgboost_and_lightgbm(original, self, *args, **kwargs):
1485          """
1486          Autologging function for XGBoost and LightGBM scikit-learn models
1487          """
1488          # Obtain a copy of a model input example from the training dataset prior to model training
1489          # for subsequent use during model logging, ensuring that the input example and inferred
1490          # model signature to not include any mutations from model training
1491          input_example_exc = None
1492          try:
1493              input_example = deepcopy(
1494                  _get_X_y_and_sample_weight(self.fit, args, kwargs)[0][:INPUT_EXAMPLE_SAMPLE_ROWS]
1495              )
1496          except Exception as e:
1497              input_example_exc = e
1498  
1499          def get_input_example():
1500              if input_example_exc is not None:
1501                  raise input_example_exc
1502              else:
1503                  return input_example
1504  
1505          # parameter, metric, and non-model artifact logging are done in
1506          # `train()` in `mlflow.xgboost.autolog()` and `mlflow.lightgbm.autolog()`
1507          fit_output = original(self, *args, **kwargs)
1508          # log models after training
1509          if log_models:
1510              input_example, signature = resolve_input_example_and_signature(
1511                  get_input_example,
1512                  lambda input_example: infer_signature(
1513                      input_example,
1514                      # Copy the input example so that it is not mutated by the call to
1515                      # predict() prior to signature inference
1516                      self.predict(deepcopy(input_example)),
1517                  ),
1518                  log_input_examples,
1519                  log_model_signatures,
1520                  _logger,
1521              )
1522              log_model_func = (
1523                  mlflow.xgboost.log_model
1524                  if flavor_name == mlflow.xgboost.FLAVOR_NAME
1525                  else mlflow.lightgbm.log_model
1526              )
1527              registered_model_name = get_autologging_config(
1528                  flavor_name, "registered_model_name", None
1529              )
1530              if flavor_name == mlflow.xgboost.FLAVOR_NAME:
1531                  model_format = get_autologging_config(flavor_name, "model_format", "ubj")
1532                  model_info = log_model_func(
1533                      self,
1534                      "model",
1535                      signature=signature,
1536                      input_example=input_example,
1537                      registered_model_name=registered_model_name,
1538                      model_format=model_format,
1539                  )
1540              else:
1541                  model_info = log_model_func(
1542                      self,
1543                      "model",
1544                      signature=signature,
1545                      input_example=input_example,
1546                      registered_model_name=registered_model_name,
1547                  )
1548              _AUTOLOGGING_METRICS_MANAGER.record_model_id(self, model_info.model_id)
1549          return fit_output
1550  
1551      def fit_mlflow(original, self, *args, **kwargs):
1552          """
1553          Autologging function that performs model training by executing the training method
1554          referred to be `func_name` on the instance of `clazz` referred to by `self` & records
1555          MLflow parameters, metrics, tags, and artifacts to a corresponding MLflow Run.
1556          """
1557          # Obtain a copy of the training dataset prior to model training for subsequent
1558          # use during model logging & input example extraction, ensuring that we don't
1559          # attempt to infer input examples on data that was mutated during training
1560          (X, y_true, sample_weight) = _get_X_y_and_sample_weight(self.fit, args, kwargs)
1561          autologging_client = MlflowAutologgingQueueingClient()
1562          _log_pretraining_metadata(autologging_client, self, X, y_true)
1563          params_logging_future = autologging_client.flush(synchronous=False)
1564          fit_output = original(self, *args, **kwargs)
1565          _log_posttraining_metadata(autologging_client, self, X, y_true, sample_weight)
1566          autologging_client.flush(synchronous=True)
1567          params_logging_future.await_completion()
1568          return fit_output
1569  
1570      def _log_pretraining_metadata(autologging_client, estimator, X, y):
1571          """
1572          Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
1573          This is intended to be invoked within a patched scikit-learn training routine
1574          (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
1575          MLflow run that can be referenced via the fluent Tracking API.
1576  
1577          Args:
1578              autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
1579                  efficiently logging run data to MLflow Tracking.
1580              estimator: The scikit-learn estimator for which to log metadata.
1581          """
1582          # Deep parameter logging includes parameters from children of a given
1583          # estimator. For some meta estimators (e.g., pipelines), recording
1584          # these parameters is desirable. For parameter search estimators,
1585          # however, child estimators act as seeds for the parameter search
1586          # process; accordingly, we avoid logging initial, untuned parameters
1587          # for these seed estimators.
1588          should_log_params_deeply = not _is_parameter_search_estimator(estimator)
1589          run_id = mlflow.active_run().info.run_id
1590          autologging_client.log_params(
1591              run_id=mlflow.active_run().info.run_id,
1592              params=estimator.get_params(deep=should_log_params_deeply),
1593          )
1594          autologging_client.set_tags(
1595              run_id=run_id,
1596              tags=_get_estimator_info_tags(estimator),
1597          )
1598  
1599          if log_datasets:
1600              try:
1601                  context_tags = context_registry.resolve_tags()
1602                  source = CodeDatasetSource(context_tags)
1603  
1604                  if dataset := _create_dataset(X, source, y):
1605                      tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="train")]
1606                      dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags)
1607  
1608                      autologging_client.log_inputs(
1609                          run_id=mlflow.active_run().info.run_id, datasets=[dataset_input]
1610                      )
1611              except Exception as e:
1612                  _logger.warning(
1613                      "Failed to log training dataset information to MLflow Tracking. Reason: %s", e
1614                  )
1615  
    def _log_posttraining_metadata(autologging_client, estimator, X, y, sample_weight):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        Args:
            autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
                efficiently logging run data to MLflow Tracking.
            estimator: The scikit-learn estimator for which to log metadata.
            X: The training dataset samples passed to the ``estimator.fit()`` function.
            y: The training dataset labels passed to the ``estimator.fit()`` function.
            sample_weight: Sample weights passed to the ``estimator.fit()`` function.
        """
        # Fetch an input example using the first several rows of the array-like
        # training data supplied to the training routine (e.g., `fit()`). Copy the
        # example to avoid mutation during subsequent metric computations
        input_example_exc = None
        try:
            input_example = deepcopy(X[:INPUT_EXAMPLE_SAMPLE_ROWS])
        except Exception as e:
            input_example_exc = e

        def get_input_example():
            # Defer raising the extraction failure until the example is actually needed
            if input_example_exc is not None:
                raise input_example_exc
            else:
                return input_example

        def infer_model_signature(input_example):
            # Prefer `predict` output for signature inference; fall back to `transform`
            # for transformer-style estimators that have no `predict`
            if hasattr(estimator, "predict"):
                # Copy the input example so that it is not mutated by the call to
                # predict() prior to signature inference
                model_output = estimator.predict(deepcopy(input_example))
            elif hasattr(estimator, "transform"):
                model_output = estimator.transform(deepcopy(input_example))
            else:
                raise Exception(
                    "the trained model does not have a `predict` or `transform` "
                    "function, which is required in order to infer the signature"
                )

            return infer_signature(input_example, model_output)

        def _log_model_with_except_handling(*args, **kwargs):
            # Best-effort model logging: custom models that cannot be pickled are
            # skipped with a warning (returning None) instead of failing training
            try:
                return log_model(*args, **kwargs)
            except _SklearnCustomModelPicklingError as e:
                _logger.warning(str(e))

        # Tracks the ID of the logged model (if any) so that metrics logged below
        # can be linked to it
        model_id = None
        if log_models:
            # Will only resolve `input_example` and `signature` if `log_models` is `True`.
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                infer_model_signature,
                log_input_examples,
                log_model_signatures,
                _logger,
            )
            registered_model_name = get_autologging_config(
                FLAVOR_NAME, "registered_model_name", None
            )
            should_log_params_deeply = not _is_parameter_search_estimator(estimator)
            params = estimator.get_params(deep=should_log_params_deeply)
            if hasattr(estimator, "best_params_"):
                # For parameter search estimators, also record the tuned best params
                # on the logged model, prefixed with "best_"
                params |= {
                    f"best_{param_name}": param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
            if logged_model := _log_model_with_except_handling(
                estimator,
                name="model",
                signature=signature,
                input_example=input_example,
                serialization_format=serialization_format,
                registered_model_name=registered_model_name,
                params=params,
            ):
                model_id = logged_model.model_id
                _AUTOLOGGING_METRICS_MANAGER.record_model_id(estimator, logged_model.model_id)

        # log common metrics and artifacts for estimators (classifier, regressor)
        context_tags = context_registry.resolve_tags()
        source = CodeDatasetSource(context_tags)
        try:
            # Dataset creation is best-effort; metrics are still logged without one
            dataset = _create_dataset(X, source, y)
        except Exception:
            _logger.debug("Failed to create dataset for logging.", exc_info=True)
            dataset = None
        logged_metrics = _log_estimator_content(
            autologging_client=autologging_client,
            estimator=estimator,
            prefix=_TRAINING_PREFIX,
            run_id=mlflow.active_run().info.run_id,
            X=X,
            y_true=y,
            sample_weight=sample_weight,
            pos_label=pos_label,
            dataset=dataset,
            model_id=model_id,
        )
        if y is None and not logged_metrics:
            _logger.warning(
                "Training metrics will not be recorded because training labels were not specified."
                " To automatically record training metrics, provide training labels as inputs to"
                " the model training function."
            )

        # For parameter search estimators (e.g., GridSearchCV), additionally log the
        # refit best estimator, the best CV score/params, and child runs for the
        # searched parameter sets
        best_estimator_model_id = None
        best_estimator_params = None
        if _is_parameter_search_estimator(estimator):
            if hasattr(estimator, "best_estimator_") and log_models:
                # NOTE: `signature` and `input_example` are guaranteed to be bound here
                # because this branch also requires `log_models`
                best_estimator_params = estimator.best_estimator_.get_params(deep=True)
                if model_info := _log_model_with_except_handling(
                    estimator.best_estimator_,
                    name="best_estimator",
                    signature=signature,
                    input_example=input_example,
                    serialization_format=serialization_format,
                    params=best_estimator_params,
                ):
                    best_estimator_model_id = model_info.model_id

            if hasattr(estimator, "best_score_"):
                autologging_client.log_metrics(
                    run_id=mlflow.active_run().info.run_id,
                    metrics={"best_cv_score": estimator.best_score_},
                    dataset=dataset,
                    model_id=model_id,
                )

            if hasattr(estimator, "best_params_"):
                best_params = {
                    f"best_{param_name}": param_value
                    for param_name, param_value in estimator.best_params_.items()
                }
                autologging_client.log_params(
                    run_id=mlflow.active_run().info.run_id,
                    params=best_params,
                )

            if hasattr(estimator, "cv_results_"):
                try:
                    # Fetch environment-specific tags (e.g., user and source) to ensure that lineage
                    # information is consistent with the parent run
                    child_tags = context_registry.resolve_tags()
                    child_tags.update({MLFLOW_AUTOLOGGING: flavor_name})
                    _create_child_runs_for_parameter_search(
                        autologging_client=autologging_client,
                        cv_estimator=estimator,
                        parent_run=mlflow.active_run(),
                        max_tuning_runs=max_tuning_runs,
                        child_tags=child_tags,
                        dataset=dataset,
                        best_estimator_params=best_estimator_params,
                        best_estimator_model_id=best_estimator_model_id,
                    )
                except Exception as e:
                    _logger.warning(
                        "Encountered exception during creation of child runs for parameter search."
                        f" Child runs may be missing. Exception: {e}"
                    )

                try:
                    # Persist the full cv_results_ table as a run artifact for inspection
                    cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
                    _log_parameter_search_results_as_artifact(
                        cv_results_df, mlflow.active_run().info.run_id
                    )
                except Exception as e:
                    _logger.warning(
                        f"Failed to log parameter search results as an artifact. Exception: {e}"
                    )
1790  
1791      def patched_fit(fit_impl, allow_children_patch, original, self, *args, **kwargs):
1792          """
1793          Autologging patch function to be applied to a sklearn model class that defines a `fit`
1794          method and inherits from `BaseEstimator` (thereby defining the `get_params()` method)
1795  
1796          Args:
1797              fit_impl: The patched fit function implementation, the function should be defined as
1798                  `fit_mlflow(original, self, *args, **kwargs)`, the `original` argument
1799                  refers to the original `EstimatorClass.fit` method, the `self` argument
1800                  refers to the estimator instance being patched, the `*args` and
1801                  `**kwargs` are arguments passed to the original fit method.
1802              allow_children_patch: Whether to allow children sklearn session logging or not.
1803              original: the original `EstimatorClass.fit` method to be patched.
1804              self: the estimator instance being patched.
1805              args: positional arguments to be passed to the original fit method.
1806              kwargs: keyword arguments to be passed to the original fit method.
1807          """
1808          should_log_post_training_metrics = (
1809              log_post_training_metrics
1810              and _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics()
1811          )
1812  
1813          with _SklearnTrainingSession(estimator=self, allow_children=allow_children_patch) as t:
1814              if t.should_log():
1815                  # In `fit_mlflow` call, it will also call metric API for computing training metrics
1816                  # so we need temporarily disable the post_training_metrics patching.
1817                  with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
1818                      result = fit_impl(original, self, *args, **kwargs)
1819                  if should_log_post_training_metrics:
1820                      _AUTOLOGGING_METRICS_MANAGER.register_model(
1821                          self, mlflow.active_run().info.run_id
1822                      )
1823                  return result
1824              else:
1825                  return original(self, *args, **kwargs)
1826  
1827      def patched_predict(original, self, *args, **kwargs):
1828          """
1829          In `patched_predict`, register the prediction result instance with the run id and
1830           eval dataset name. e.g.
1831          ```
1832          prediction_result = model_1.predict(eval_X)
1833          ```
1834          then we need register the following relationship into the `_AUTOLOGGING_METRICS_MANAGER`:
1835          id(prediction_result) --> (eval_dataset_name, run_id)
1836  
1837          Note: we cannot set additional attributes "eval_dataset_name" and "run_id" into
1838          the prediction_result object, because certain dataset type like numpy does not support
1839          additional attribute assignment.
1840          """
1841          run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self)
1842          if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id:
1843              # Avoid nested patch when nested inference calls happens.
1844              with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
1845                  predict_result = original(self, *args, **kwargs)
1846              eval_dataset = get_instance_method_first_arg_value(original, args, kwargs)
1847              eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset(
1848                  self, eval_dataset
1849              )
1850              _AUTOLOGGING_METRICS_MANAGER.register_prediction_result(
1851                  run_id,
1852                  eval_dataset_name,
1853                  predict_result,
1854                  model_id=_AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self),
1855              )
1856              if log_datasets:
1857                  try:
1858                      context_tags = context_registry.resolve_tags()
1859                      source = CodeDatasetSource(context_tags)
1860  
1861                      # log the dataset
1862                      if dataset := _create_dataset(eval_dataset, source):
1863                          tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="eval")]
1864                          dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags)
1865  
1866                          # log the dataset
1867                          client = mlflow.MlflowClient()
1868                          client.log_inputs(run_id=run_id, datasets=[dataset_input])
1869                  except Exception as e:
1870                      _logger.warning(
1871                          "Failed to log evaluation dataset information to "
1872                          "MLflow Tracking. Reason: %s",
1873                          e,
1874                      )
1875              return predict_result
1876          else:
1877              return original(self, *args, **kwargs)
1878  
1879      def patched_metric_api(original, *args, **kwargs):
1880          if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics():
1881              # one metric api may call another metric api,
1882              # to avoid this, call disable_log_post_training_metrics to avoid nested patch
1883              with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
1884                  metric = original(*args, **kwargs)
1885  
1886              if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(metric):
1887                  metric_name = original.__name__
1888                  call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command(
1889                      None, original, *args, **kwargs
1890                  )
1891  
1892                  (run_id, dataset_name, model_id) = (
1893                      _AUTOLOGGING_METRICS_MANAGER.get_info_for_metric_api_call(args, kwargs)
1894                  )
1895                  if run_id and dataset_name:
1896                      metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call(
1897                          run_id, metric_name, dataset_name, call_command
1898                      )
1899                      _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric(
1900                          run_id, metric_key, metric, model_id=model_id
1901                      )
1902  
1903              return metric
1904          else:
1905              return original(*args, **kwargs)
1906  
    # We need to patch the `model.score` method because some `model.score()`
    # implementations don't call the metric APIs in `sklearn.metrics`, e.g.
    # https://github.com/scikit-learn/scikit-learn/blob/82df48934eba1df9a1ed3be98aaace8eada59e6e/sklearn/covariance/_empirical_covariance.py#L220
1911      def patched_model_score(original, self, *args, **kwargs):
1912          run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self)
1913          if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id:
1914              # `model.score` may call metric APIs internally, in order to prevent nested metric call
1915              # being logged, temporarily disable post_training_metrics patching.
1916              with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
1917                  score_value = original(self, *args, **kwargs)
1918  
1919              if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(score_value):
1920                  metric_name = f"{self.__class__.__name__}_score"
1921                  call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command(
1922                      self, original, *args, **kwargs
1923                  )
1924  
1925                  eval_dataset = get_instance_method_first_arg_value(original, args, kwargs)
1926                  eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset(
1927                      self, eval_dataset
1928                  )
1929                  metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call(
1930                      run_id, metric_name, eval_dataset_name, call_command
1931                  )
1932                  model_id = _AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self)
1933                  _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric(
1934                      run_id, metric_key, score_value, model_id=model_id
1935                  )
1936  
1937              return score_value
1938          else:
1939              return original(self, *args, **kwargs)
1940  
    def _apply_sklearn_descriptor_unbound_method_call_fix():
        """Hot-patch ``_IffHasAttrDescriptor.__get__`` on scikit-learn <= 0.24.2.

        On those versions the descriptor produced by scikit-learn's
        metaestimator machinery does not support unbound method calls (which
        autologging's patching requires); see
        https://github.com/scikit-learn/scikit-learn/issues/20614.
        No-op on newer scikit-learn versions.
        """
        import sklearn

        if Version(sklearn.__version__) <= Version("0.24.2"):
            import sklearn.utils.metaestimators

            # Defensive: skip the patch if the private descriptor class is absent.
            if not hasattr(sklearn.utils.metaestimators, "_IffHasAttrDescriptor"):
                return

            def patched_IffHasAttrDescriptor__get__(self, obj, type=None):
                """
                For sklearn version <= 0.24.2, `_IffHasAttrDescriptor.__get__` method does not
                support unbound method call.
                See https://github.com/scikit-learn/scikit-learn/issues/20614
                This patched function is for hot patch.
                """

                # raise an AttributeError if the attribute is not present on the object
                if obj is not None:
                    # delegate only on instances, not the classes.
                    # this is to allow access to the docstrings.
                    for delegate_name in self.delegate_names:
                        try:
                            delegate = sklearn.utils.metaestimators.attrgetter(delegate_name)(obj)
                        except AttributeError:
                            continue
                        else:
                            getattr(delegate, self.attribute_name)
                            break
                    else:
                        # No delegate provided the attribute; re-run the lookup on the
                        # last delegate so the original AttributeError propagates.
                        sklearn.utils.metaestimators.attrgetter(self.delegate_names[-1])(obj)

                    def out(*args, **kwargs):
                        return self.fn(obj, *args, **kwargs)

                else:
                    # This makes it possible to use the decorated method as an unbound method,
                    # for instance when monkeypatching.
                    def out(*args, **kwargs):
                        return self.fn(*args, **kwargs)

                # update the docstring of the returned function
                functools.update_wrapper(out, self.fn)
                return out

            # Carry over the original __get__'s metadata onto the replacement.
            update_wrapper_extended(
                patched_IffHasAttrDescriptor__get__,
                sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__,
            )

            sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__ = (
                patched_IffHasAttrDescriptor__get__
            )
1994  
    _apply_sklearn_descriptor_unbound_method_call_fix()

    # Select which estimator classes to patch and which `fit` implementation to
    # use, based on the flavor being autologged. The XGBoost and LightGBM
    # sklearn-API estimators share a fit implementation and additionally allow
    # patching on subclasses (`allow_children_patch=True`).
    if flavor_name == mlflow.xgboost.FLAVOR_NAME:
        estimators_to_patch = _gen_xgboost_sklearn_estimators_to_patch()
        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
        allow_children_patch = True
    elif flavor_name == mlflow.lightgbm.FLAVOR_NAME:
        estimators_to_patch = _gen_lightgbm_sklearn_estimators_to_patch()
        patched_fit_impl = fit_mlflow_xgboost_and_lightgbm
        allow_children_patch = True
    else:
        estimators_to_patch = _gen_estimators_to_patch()
        patched_fit_impl = fit_mlflow
        allow_children_patch = False

    for class_def in estimators_to_patch:
        # Patch fitting methods (these manage an MLflow run around the call)
        for func_name in ["fit", "fit_transform", "fit_predict"]:
            _patch_estimator_method_if_available(
                flavor_name,
                class_def,
                func_name,
                functools.partial(patched_fit, patched_fit_impl, allow_children_patch),
                manage_run=True,
                extra_tags=extra_tags,
            )

        # Patch inference methods
        for func_name in ["predict", "predict_proba", "transform", "predict_log_proba"]:
            _patch_estimator_method_if_available(
                flavor_name,
                class_def,
                func_name,
                patched_predict,
                manage_run=False,
            )

        # Patch scoring methods
        _patch_estimator_method_if_available(
            flavor_name,
            class_def,
            "score",
            patched_model_score,
            manage_run=False,
            extra_tags=extra_tags,
        )

    if log_post_training_metrics:
        # Patch the sklearn metric APIs so values computed after training are
        # logged as post-training metrics.
        for metric_name in _get_metric_name_list():
            safe_patch(
                flavor_name, sklearn.metrics, metric_name, patched_metric_api, manage_run=False
            )

        # `sklearn.metrics.SCORERS` was removed in scikit-learn 1.3
        if hasattr(sklearn.metrics, "get_scorer_names"):
            for scoring in sklearn.metrics.get_scorer_names():
                scorer = sklearn.metrics.get_scorer(scoring)
                safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False)
        else:
            for scorer in sklearn.metrics.SCORERS.values():
                safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False)
2056  
2057      def patched_fn_with_autolog_disabled(original, *args, **kwargs):
2058          with disable_autologging():
2059              return original(*args, **kwargs)
2060  
    # Patch the `sklearn.model_selection` APIs listed in
    # `_apis_autologging_disabled` so that autologging is suppressed while they
    # run (preventing their internal calls from being autologged).
    for disable_autolog_func_name in _apis_autologging_disabled:
        safe_patch(
            flavor_name,
            sklearn.model_selection,
            disable_autolog_func_name,
            patched_fn_with_autolog_disabled,
            manage_run=False,
        )
2069  
2070      def _create_dataset(X, source, y=None, dataset_name=None):
2071          # create a dataset
2072          from scipy.sparse import issparse
2073  
2074          if isinstance(X, pd.DataFrame):
2075              dataset = from_pandas(df=X, source=source)
2076          elif issparse(X):
2077              arr_X = X.toarray()
2078              if y is not None:
2079                  dataset = from_numpy(
2080                      features=arr_X,
2081                      targets=y.toarray() if issparse(y) else y,
2082                      source=source,
2083                      name=dataset_name,
2084                  )
2085              else:
2086                  dataset = from_numpy(features=arr_X, source=source, name=dataset_name)
2087          elif isinstance(X, np.ndarray):
2088              if y is not None:
2089                  dataset = from_numpy(features=X, targets=y, source=source, name=dataset_name)
2090              else:
2091                  dataset = from_numpy(features=X, source=source, name=dataset_name)
2092          elif is_polars_dataframe(X):
2093              from mlflow.data.polars_dataset import from_polars
2094  
2095              dataset = from_polars(df=X, source=source, name=dataset_name)
2096          else:
2097              _logger.warning("Unrecognized dataset type %s. Dataset logging skipped.", type(X))
2098              return None
2099          return dataset