__init__.py
1 """ 2 The ``mlflow.sklearn`` module provides an API for logging and loading scikit-learn models. This 3 module exports scikit-learn models with the following flavors: 4 5 Python (native) `pickle <https://scikit-learn.org/stable/modules/model_persistence.html>`_ format 6 This is the main flavor that can be loaded back into scikit-learn. 7 8 :py:mod:`mlflow.pyfunc` 9 Produced for use by generic pyfunc-based deployment tools and batch inference. 10 NOTE: The `mlflow.pyfunc` flavor is only added for scikit-learn models that define `predict()`, 11 since `predict()` is required for pyfunc model inference. 12 """ 13 14 import functools 15 import inspect 16 import logging 17 import os 18 import pickle 19 import shutil 20 import weakref 21 from collections import OrderedDict, defaultdict 22 from copy import deepcopy 23 from typing import Any 24 25 import numpy as np 26 import yaml 27 from packaging.version import Version 28 29 import mlflow 30 from mlflow import pyfunc 31 from mlflow.data.code_dataset_source import CodeDatasetSource 32 from mlflow.data.numpy_dataset import from_numpy 33 from mlflow.data.pandas_dataset import from_pandas 34 from mlflow.entities.dataset_input import DatasetInput 35 from mlflow.entities.input_tag import InputTag 36 from mlflow.environment_variables import MLFLOW_ALLOW_PICKLE_DESERIALIZATION 37 from mlflow.exceptions import MlflowException 38 from mlflow.models import Model, ModelInputExample, ModelSignature 39 from mlflow.models.model import MLMODEL_FILE_NAME 40 from mlflow.models.signature import _infer_signature_from_input_example 41 from mlflow.models.utils import _save_example 42 from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE 43 from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS 44 from mlflow.tracking.artifact_utils import _download_artifact_from_uri 45 from mlflow.tracking.client import MlflowClient 46 from mlflow.utils import _inspect_original_var_name, gorilla 47 from 
mlflow.utils.autologging_utils import ( 48 INPUT_EXAMPLE_SAMPLE_ROWS, 49 MlflowAutologgingQueueingClient, 50 _get_new_training_session_class, 51 autologging_integration, 52 disable_autologging, 53 get_autologging_config, 54 get_instance_method_first_arg_value, 55 resolve_input_example_and_signature, 56 safe_patch, 57 update_wrapper_extended, 58 ) 59 from mlflow.utils.data_utils import is_polars_dataframe 60 from mlflow.utils.databricks_utils import ( 61 is_in_databricks_model_serving_environment, 62 is_in_databricks_runtime, 63 ) 64 from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring 65 from mlflow.utils.environment import ( 66 _CONDA_ENV_FILE_NAME, 67 _CONSTRAINTS_FILE_NAME, 68 _PYTHON_ENV_FILE_NAME, 69 _REQUIREMENTS_FILE_NAME, 70 _mlflow_conda_env, 71 _process_conda_env, 72 _process_pip_requirements, 73 _PythonEnv, 74 _validate_env_arguments, 75 ) 76 from mlflow.utils.file_utils import get_total_file_size, write_to 77 from mlflow.utils.mlflow_tags import ( 78 MLFLOW_AUTOLOGGING, 79 MLFLOW_DATASET_CONTEXT, 80 ) 81 from mlflow.utils.model_utils import ( 82 _add_code_from_conf_to_system_path, 83 _copy_extra_files, 84 _get_flavor_configuration, 85 _validate_and_copy_code_paths, 86 _validate_and_prepare_target_save_path, 87 ) 88 from mlflow.utils.requirements_utils import _get_pinned_requirement 89 90 FLAVOR_NAME = "sklearn" 91 92 SERIALIZATION_FORMAT_SKOPS = "skops" 93 SERIALIZATION_FORMAT_PICKLE = "pickle" 94 SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle" 95 96 SUPPORTED_SERIALIZATION_FORMATS = [ 97 SERIALIZATION_FORMAT_SKOPS, 98 SERIALIZATION_FORMAT_PICKLE, 99 SERIALIZATION_FORMAT_CLOUDPICKLE, 100 ] 101 102 _logger = logging.getLogger(__name__) 103 _SklearnTrainingSession = _get_new_training_session_class() 104 105 _PICKLE_MODEL_DATA_SUBPATH = "model.pkl" 106 _SKOPS_MODEL_DATA_SUBPATH = "model.skops" 107 108 109 def _gen_estimators_to_patch(): 110 from mlflow.sklearn.utils import ( 111 _all_estimators, 112 
_get_meta_estimators_for_autologging, 113 ) 114 115 _, estimators_to_patch = zip(*_all_estimators()) 116 # Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected 117 # for patching if they are not already included in the output of `all_estimators()` 118 estimators_to_patch = set(estimators_to_patch).union( 119 set(_get_meta_estimators_for_autologging()) 120 ) 121 # Exclude certain preprocessing & feature manipulation estimators from patching. These 122 # estimators represent data manipulation routines (e.g., normalization, label encoding) 123 # rather than ML algorithms. Accordingly, we should not create MLflow runs and log 124 # parameters / metrics for these routines, unless they are captured as part of an ML pipeline 125 # (via `sklearn.pipeline.Pipeline`) 126 excluded_module_names = [ 127 "sklearn.preprocessing", 128 "sklearn.impute", 129 "sklearn.feature_extraction", 130 "sklearn.feature_selection", 131 ] 132 133 excluded_class_names = [ 134 "sklearn.compose._column_transformer.ColumnTransformer", 135 ] 136 137 return [ 138 estimator 139 for estimator in estimators_to_patch 140 if not any( 141 estimator.__module__.startswith(excluded_module_name) 142 or (estimator.__module__ + "." + estimator.__name__) in excluded_class_names 143 for excluded_module_name in excluded_module_names 144 ) 145 ] 146 147 148 def get_default_pip_requirements(include_cloudpickle=False, include_skops=False): 149 """ 150 Returns: 151 A list of default pip requirements for MLflow Models produced by this flavor. 152 Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment 153 that, at minimum, contains these requirements. 
154 """ 155 pip_deps = [_get_pinned_requirement("scikit-learn", module="sklearn")] 156 if include_cloudpickle: 157 pip_deps += [_get_pinned_requirement("cloudpickle")] 158 if include_skops: 159 pip_deps += [_get_pinned_requirement("skops")] 160 161 return pip_deps 162 163 164 def get_default_conda_env(include_cloudpickle=False, include_skops=False): 165 """ 166 Returns: 167 The default Conda environment for MLflow Models produced by calls to 168 :func:`save_model()` and :func:`log_model()`. 169 """ 170 return _mlflow_conda_env( 171 additional_pip_deps=get_default_pip_requirements(include_cloudpickle, include_skops) 172 ) 173 174 175 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) 176 def save_model( 177 sk_model, 178 path, 179 conda_env=None, 180 code_paths=None, 181 mlflow_model=None, 182 serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, 183 signature: ModelSignature = None, 184 input_example: ModelInputExample = None, 185 pip_requirements=None, 186 extra_pip_requirements=None, 187 pyfunc_predict_fn="predict", 188 metadata=None, 189 skops_trusted_types=None, 190 extra_files=None, 191 ): 192 """ 193 Save a scikit-learn model to a path on the local file system. Produces a MLflow Model 194 containing the following flavors: 195 196 - :py:mod:`mlflow.sklearn` 197 - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models 198 that define `predict()`, since `predict()` is required for pyfunc model inference. 199 200 Args: 201 sk_model: scikit-learn model to be saved. 202 path: Local path where the model is to be saved. 203 conda_env: {{ conda_env }} 204 code_paths: {{ code_paths }} 205 mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to. 206 serialization_format: The format in which to serialize the model. This should be one of 207 the formats "skops", "cloudpickle" or "pickle". 208 The "skops" format guarantees safe deserialization. 
209 The "cloudpickle" format, provides better cross-system compatibility by identifying and 210 packaging code dependencies with the serialized model, but requires exercising 211 caution because these formats rely on Python's object serialization mechanism, 212 which can execute arbitrary code during deserialization. 213 214 signature: {{ signature }} 215 input_example: {{ input_example }} 216 pip_requirements: {{ pip_requirements }} 217 extra_pip_requirements: {{ extra_pip_requirements }} 218 pyfunc_predict_fn: The name of the prediction function to use for inference with the 219 pyfunc representation of the resulting MLflow Model. Current supported functions 220 are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``, 221 ``"predict_joint_log_proba"``, and ``"score"``. 222 metadata: {{ metadata }} 223 skops_trusted_types: A list of trusted types when loading model that is saved as 224 the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format. 225 extra_files: {{ extra_files }} 226 227 .. code-block:: python 228 :caption: Example 229 230 import mlflow.sklearn 231 from sklearn.datasets import load_iris 232 from sklearn import tree 233 234 iris = load_iris() 235 sk_model = tree.DecisionTreeClassifier() 236 sk_model = sk_model.fit(iris.data, iris.target) 237 238 # Save the model in cloudpickle format 239 # set path to location for persistence 240 sk_path_dir_1 = ... 241 mlflow.sklearn.save_model( 242 sk_model, 243 sk_path_dir_1, 244 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE, 245 ) 246 247 # save the model in pickle format 248 # set path to location for persistence 249 sk_path_dir_2 = ... 
250 mlflow.sklearn.save_model( 251 sk_model, 252 sk_path_dir_2, 253 serialization_format=mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE, 254 ) 255 """ 256 import sklearn 257 258 _validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements) 259 260 if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS: 261 raise MlflowException( 262 message=( 263 f"Unrecognized serialization format: {serialization_format}. Please specify one" 264 f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}." 265 ), 266 error_code=INVALID_PARAMETER_VALUE, 267 ) 268 269 if serialization_format != SERIALIZATION_FORMAT_SKOPS and not is_in_databricks_runtime(): 270 _logger.warning( 271 "Saving scikit-learn models in the pickle or cloudpickle format requires exercising " 272 "caution because these formats rely on Python's object serialization mechanism, " 273 "which can execute arbitrary code during deserialization. " 274 "The recommended safe alternative is the 'skops' format. " 275 "For more information, see: https://scikit-learn.org/stable/model_persistence.html", 276 ) 277 278 _validate_and_prepare_target_save_path(path) 279 code_path_subdir = _validate_and_copy_code_paths(code_paths, path) 280 281 if mlflow_model is None: 282 mlflow_model = Model() 283 saved_example = _save_example(mlflow_model, input_example, path) 284 285 if signature is None and saved_example is not None: 286 wrapped_model = _SklearnModelWrapper(sk_model) 287 signature = _infer_signature_from_input_example(saved_example, wrapped_model) 288 elif signature is False: 289 signature = None 290 291 if signature is not None: 292 mlflow_model.signature = signature 293 if metadata is not None: 294 mlflow_model.metadata = metadata 295 296 if serialization_format == SERIALIZATION_FORMAT_SKOPS: 297 model_data_subpath = _SKOPS_MODEL_DATA_SUBPATH 298 else: 299 model_data_subpath = _PICKLE_MODEL_DATA_SUBPATH 300 model_data_path = os.path.join(path, model_data_subpath) 301 _save_model( 302 
sk_model=sk_model, 303 output_path=model_data_path, 304 serialization_format=serialization_format, 305 skops_trusted_types=skops_trusted_types, 306 ) 307 308 extra_files_config = _copy_extra_files(extra_files, path) 309 310 # `PyFuncModel` only works for sklearn models that define a predict function 311 312 if hasattr(sk_model, pyfunc_predict_fn): 313 pyfunc.add_to_model( 314 mlflow_model, 315 loader_module="mlflow.sklearn", 316 model_path=model_data_subpath, 317 conda_env=_CONDA_ENV_FILE_NAME, 318 python_env=_PYTHON_ENV_FILE_NAME, 319 code=code_path_subdir, 320 predict_fn=pyfunc_predict_fn, 321 ) 322 else: 323 _logger.warning( 324 f"Model was missing function: {pyfunc_predict_fn}. Not logging python_function flavor!" 325 ) 326 327 mlflow_model.add_flavor( 328 FLAVOR_NAME, 329 pickled_model=model_data_subpath, 330 sklearn_version=sklearn.__version__, 331 serialization_format=serialization_format, 332 code=code_path_subdir, 333 skops_trusted_types=skops_trusted_types, 334 **extra_files_config, 335 ) 336 if size := get_total_file_size(path): 337 mlflow_model.model_size_bytes = size 338 mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME)) 339 340 if conda_env is None: 341 if pip_requirements is None: 342 include_cloudpickle = serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE 343 include_skops = serialization_format == SERIALIZATION_FORMAT_SKOPS 344 default_reqs = get_default_pip_requirements( 345 include_cloudpickle=include_cloudpickle, 346 include_skops=include_skops, 347 ) 348 # To ensure `_load_pyfunc` can successfully load the model during the dependency 349 # inference, `mlflow_model.save` must be called beforehand to save an MLmodel file. 
350 inferred_reqs = mlflow.models.infer_pip_requirements( 351 model_data_path, 352 FLAVOR_NAME, 353 fallback=default_reqs, 354 ) 355 default_reqs = sorted(set(inferred_reqs).union(default_reqs)) 356 else: 357 default_reqs = None 358 conda_env, pip_requirements, pip_constraints = _process_pip_requirements( 359 default_reqs, 360 pip_requirements, 361 extra_pip_requirements, 362 ) 363 else: 364 conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env) 365 366 with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f: 367 yaml.safe_dump(conda_env, stream=f, default_flow_style=False) 368 369 # Save `constraints.txt` if necessary 370 if pip_constraints: 371 write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints)) 372 373 # Save `requirements.txt` 374 write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements)) 375 376 _PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME)) 377 378 379 @format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name="scikit-learn")) 380 def log_model( 381 sk_model, 382 artifact_path: str | None = None, 383 conda_env=None, 384 code_paths=None, 385 serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE, 386 registered_model_name=None, 387 signature: ModelSignature = None, 388 input_example: ModelInputExample = None, 389 await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS, 390 pip_requirements=None, 391 extra_pip_requirements=None, 392 pyfunc_predict_fn="predict", 393 metadata=None, 394 extra_files=None, 395 params: dict[str, Any] | None = None, 396 tags: dict[str, Any] | None = None, 397 model_type: str | None = None, 398 step: int = 0, 399 model_id: str | None = None, 400 name: str | None = None, 401 skops_trusted_types: list[str] | None = None, 402 **kwargs, 403 ): 404 """ 405 Log a scikit-learn model as an MLflow artifact for the current run. 
Produces an MLflow Model 406 containing the following flavors: 407 408 - :py:mod:`mlflow.sklearn` 409 - :py:mod:`mlflow.pyfunc`. NOTE: This flavor is only included for scikit-learn models 410 that define `predict()`, since `predict()` is required for pyfunc model inference. 411 412 Args: 413 sk_model: scikit-learn model to be saved. 414 artifact_path: Deprecated. Use `name` instead. 415 conda_env: {{ conda_env }} 416 code_paths: {{ code_paths }} 417 serialization_format: The format in which to serialize the model. This should be one of 418 the formats "skops", "cloudpickle" or "pickle". 419 The "skops" format guarantees safe deserialization. 420 The "cloudpickle" format, provides better cross-system compatibility by identifying and 421 packaging code dependencies with the serialized model, but requires exercising 422 caution because these formats rely on Python's object serialization mechanism, 423 which can execute arbitrary code during deserialization. 424 registered_model_name: If given, create a model version under 425 ``registered_model_name``, also creating a registered model if one 426 with the given name does not exist. 427 signature: {{ signature }} 428 input_example: {{ input_example }} 429 await_registration_for: Number of seconds to wait for the model version to finish 430 being created and is in ``READY`` status. By default, the function 431 waits for five minutes. Specify 0 or None to skip waiting. 432 pip_requirements: {{ pip_requirements }} 433 extra_pip_requirements: {{ extra_pip_requirements }} 434 pyfunc_predict_fn: The name of the prediction function to use for inference with the 435 pyfunc representation of the resulting MLflow Model. Current supported functions 436 are: ``"predict"``, ``"predict_proba"``, ``"predict_log_proba"``, 437 ``"predict_joint_log_proba"``, and ``"score"``. 
438 metadata: {{ metadata }} 439 extra_files: {{ extra_files }} 440 params: {{ params }} 441 tags: {{ tags }} 442 model_type: {{ model_type }} 443 step: {{ step }} 444 model_id: {{ model_id }} 445 name: {{ name }} 446 skops_trusted_types: A list of trusted types when loading model that is saved as 447 the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format. 448 kwargs: Extra arguments to pass to :py:func:`mlflow.models.Model.log`. 449 450 Returns: 451 A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the 452 metadata of the logged model. 453 454 .. code-block:: python 455 :caption: Example 456 457 import mlflow 458 import mlflow.sklearn 459 from mlflow.models import infer_signature 460 from sklearn.datasets import load_iris 461 from sklearn import tree 462 463 with mlflow.start_run(): 464 # load dataset and train model 465 iris = load_iris() 466 sk_model = tree.DecisionTreeClassifier() 467 sk_model = sk_model.fit(iris.data, iris.target) 468 469 # log model params 470 mlflow.log_param("criterion", sk_model.criterion) 471 mlflow.log_param("splitter", sk_model.splitter) 472 signature = infer_signature(iris.data, sk_model.predict(iris.data)) 473 474 # log model 475 mlflow.sklearn.log_model(sk_model, name="sk_models", signature=signature) 476 477 """ 478 return Model.log( 479 artifact_path=artifact_path, 480 name=name, 481 flavor=mlflow.sklearn, 482 sk_model=sk_model, 483 conda_env=conda_env, 484 code_paths=code_paths, 485 serialization_format=serialization_format, 486 registered_model_name=registered_model_name, 487 signature=signature, 488 input_example=input_example, 489 await_registration_for=await_registration_for, 490 pip_requirements=pip_requirements, 491 extra_pip_requirements=extra_pip_requirements, 492 pyfunc_predict_fn=pyfunc_predict_fn, 493 metadata=metadata, 494 extra_files=extra_files, 495 params=params, 496 tags=tags, 497 model_type=model_type, 498 step=step, 499 model_id=model_id, 500 skops_trusted_types=skops_trusted_types, 
501 **kwargs, 502 ) 503 504 505 def _load_model_from_local_file(path, serialization_format, skops_trusted_types=None): 506 """Load a scikit-learn model saved as an MLflow artifact on the local file system. 507 508 Args: 509 path: Local filesystem path to the MLflow Model saved with the ``sklearn`` flavor 510 serialization_format: The format in which the model was serialized. This should be one of 511 the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or 512 ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``. 513 """ 514 # TODO: we could validate the scikit-learn version here 515 if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS: 516 raise MlflowException( 517 message=( 518 f"Unrecognized serialization format: {serialization_format}. Please specify one" 519 f" of the following supported formats: {SUPPORTED_SERIALIZATION_FORMATS}." 520 ), 521 error_code=INVALID_PARAMETER_VALUE, 522 ) 523 524 if serialization_format != SERIALIZATION_FORMAT_SKOPS: 525 if ( 526 not MLFLOW_ALLOW_PICKLE_DESERIALIZATION.get() 527 and not is_in_databricks_runtime() 528 and not is_in_databricks_model_serving_environment() 529 ): 530 raise MlflowException( 531 "Deserializing model using pickle is disallowed, but this model is saved " 532 "in pickle format. To address this issue, you need to set environment variable " 533 "'MLFLOW_ALLOW_PICKLE_DESERIALIZATION' to 'true', or save the model in " 534 "'skops' format." 
535 ) 536 537 if serialization_format == SERIALIZATION_FORMAT_SKOPS: 538 import skops.io 539 540 return skops.io.load(path, trusted=skops_trusted_types) 541 else: 542 with open(path, "rb") as f: 543 # Models serialized with Cloudpickle cannot necessarily be deserialized using Pickle; 544 # That's why we check the serialization format of the model before deserializing 545 if serialization_format == SERIALIZATION_FORMAT_PICKLE: 546 return pickle.load(f) 547 elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE: 548 import cloudpickle 549 550 return cloudpickle.load(f) 551 552 553 def _load_pyfunc(path): 554 """ 555 Load PyFunc implementation. Called by ``pyfunc.load_model``. 556 557 Args: 558 path: Local filesystem path to the MLflow Model with the ``sklearn`` flavor. 559 """ 560 # When ``path`` is a file, it refers directly to a serialized scikit-learn model 561 # object (e.g., model.pkl or model.skops). The MLmodel file in the parent directory 562 # contains the serialization format and other flavor configuration. 563 model_dir = os.path.dirname(path) if os.path.isfile(path) else path 564 565 try: 566 sklearn_flavor_conf = _get_flavor_configuration( 567 model_path=model_dir, flavor_name=FLAVOR_NAME 568 ) 569 serialization_format = sklearn_flavor_conf.get( 570 "serialization_format", SERIALIZATION_FORMAT_PICKLE 571 ) 572 skops_trusted_types = sklearn_flavor_conf.get("skops_trusted_types", None) 573 except MlflowException: 574 _logger.warning( 575 "Could not find scikit-learn flavor configuration during model loading process." 576 " Assuming 'pickle' serialization format." 
577 ) 578 serialization_format = SERIALIZATION_FORMAT_PICKLE 579 skops_trusted_types = None 580 581 if not os.path.isfile(path): 582 pyfunc_flavor_conf = _get_flavor_configuration( 583 model_path=path, flavor_name=pyfunc.FLAVOR_NAME 584 ) 585 path = os.path.join(path, pyfunc_flavor_conf["model_path"]) 586 587 return _SklearnModelWrapper( 588 _load_model_from_local_file( 589 path=path, 590 serialization_format=serialization_format, 591 skops_trusted_types=skops_trusted_types, 592 ) 593 ) 594 595 596 class _SklearnModelWrapper: 597 _SUPPORTED_CUSTOM_PREDICT_FN = [ 598 "predict_proba", 599 "predict_log_proba", 600 "predict_joint_log_proba", 601 "score", 602 ] 603 604 def __init__(self, sklearn_model): 605 self.sklearn_model = sklearn_model 606 607 # Patch the model with custom predict functions that can be specified 608 # via `pyfunc_predict_fn` argument when saving or logging. 609 for predict_fn in self._SUPPORTED_CUSTOM_PREDICT_FN: 610 if fn := getattr(self.sklearn_model, predict_fn, None): 611 setattr(self, predict_fn, fn) 612 613 def get_raw_model(self): 614 """ 615 Returns the underlying scikit-learn model. 616 """ 617 return self.sklearn_model 618 619 def predict( 620 self, 621 data, 622 params: dict[str, Any] | None = None, 623 ): 624 """ 625 Args: 626 data: Model input data. 627 params: Additional parameters to pass to the model for inference. 628 629 Returns: 630 Model predictions. 
631 """ 632 return self.sklearn_model.predict(data) 633 634 635 class _SklearnCustomModelPicklingError(pickle.PicklingError): 636 """ 637 Exception for describing error raised during pickling custom sklearn estimator 638 """ 639 640 def __init__(self, sk_model, original_exception): 641 """ 642 Args: 643 sk_model: The custom sklearn model to be pickled 644 original_exception: The original exception raised 645 """ 646 super().__init__( 647 f"Pickling custom sklearn model {sk_model.__class__.__name__} failed " 648 f"when saving model: {original_exception}" 649 ) 650 self.original_exception = original_exception 651 652 653 def _dump_model(pickle_lib, sk_model, out): 654 try: 655 # Using python's default protocol to optimize compatibility. 656 # Otherwise cloudpickle uses latest protocol leading to incompatibilities. 657 # See https://github.com/mlflow/mlflow/issues/5419 658 pickle_lib.dump(sk_model, out, protocol=pickle.DEFAULT_PROTOCOL) 659 except (pickle.PicklingError, TypeError, AttributeError) as e: 660 if sk_model.__class__ not in _gen_estimators_to_patch(): 661 raise _SklearnCustomModelPicklingError(sk_model, e) 662 else: 663 raise 664 665 666 def _save_model(sk_model, output_path, serialization_format, skops_trusted_types): 667 """ 668 Args: 669 sk_model: The scikit-learn model to serialize. 670 output_path: The file path to which to write the serialized model. 671 serialization_format: The format in which to serialize the model. This should be one of 672 the following: ``mlflow.sklearn.SERIALIZATION_FORMAT_PICKLE`` or 673 ``mlflow.sklearn.SERIALIZATION_FORMAT_CLOUDPICKLE``. 674 skops_trusted_types: A list of trusted types when loading model that is saved as 675 the ``mlflow.sklearn.SERIALIZATION_FORMAT_SKOPS`` format. 
676 """ 677 if serialization_format == SERIALIZATION_FORMAT_SKOPS: 678 import skops.io 679 from skops.io.exceptions import UntrustedTypesFoundException 680 681 try: 682 skops.io.dump(sk_model, output_path) 683 skops.io.load(output_path, trusted=skops_trusted_types) 684 except UntrustedTypesFoundException as e: 685 shutil.rmtree(output_path, ignore_errors=True) 686 raise MlflowException( 687 "The saved sklearn model references untrusted types. " 688 "If you are sure loading these types is safe, " 689 "set the 'skops_trusted_types' parameter when calling 'log_model' or 'save_model' " 690 "to the list of trusted types. " 691 f"Root error: {e!s}" 692 ) 693 except Exception as e: 694 shutil.rmtree(output_path, ignore_errors=True) 695 raise MlflowException( 696 "The sklearn model could not be serialized in the skops serialization format. " 697 "skops does not support custom functions or classes that are not defined at the " 698 "top level. To work around this limitation, you can set the serialization_format " 699 "'cloudpickle', while exercising caution due to the possible arbitrary " 700 "code during model deserialization using CloudPickle." 701 ) from e 702 return 703 704 with open(output_path, "wb") as out: 705 if serialization_format == SERIALIZATION_FORMAT_PICKLE: 706 _dump_model(pickle, sk_model, out) 707 elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE: 708 import cloudpickle 709 710 _dump_model(cloudpickle, sk_model, out) 711 else: 712 raise MlflowException( 713 message=f"Unrecognized serialization format: {serialization_format}", 714 error_code=INTERNAL_ERROR, 715 ) 716 717 718 def load_model(model_uri, dst_path=None): 719 """ 720 Load a scikit-learn model from a local file or a run. 
721 722 Args: 723 model_uri: The location, in URI format, of the MLflow model, for example: 724 725 - ``/Users/me/path/to/local/model`` 726 - ``relative/path/to/local/model`` 727 - ``s3://my_bucket/path/to/model`` 728 - ``runs:/<mlflow_run_id>/run-relative/path/to/model`` 729 - ``models:/<model_name>/<model_version>`` 730 - ``models:/<model_name>/<stage>`` 731 732 For more information about supported URI schemes, see 733 `Referencing Artifacts <https://www.mlflow.org/docs/latest/concepts.html# 734 artifact-locations>`_. 735 dst_path: The local filesystem path to which to download the model artifact. 736 This directory must already exist. If unspecified, a local output 737 path will be created. 738 739 Returns: 740 A scikit-learn model. 741 742 .. code-block:: python 743 :caption: Example 744 745 import mlflow.sklearn 746 747 sk_model = mlflow.sklearn.load_model("runs:/96771d893a5e46159d9f3b49bf9013e2/sk_models") 748 749 # use Pandas DataFrame to make predictions 750 pandas_df = ... 751 predictions = sk_model.predict(pandas_df) 752 """ 753 local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path) 754 flavor_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME) 755 _add_code_from_conf_to_system_path(local_model_path, flavor_conf) 756 sklearn_model_artifacts_path = os.path.join(local_model_path, flavor_conf["pickled_model"]) 757 serialization_format = flavor_conf.get("serialization_format", SERIALIZATION_FORMAT_PICKLE) 758 skops_trusted_types = flavor_conf.get("skops_trusted_types", None) 759 return _load_model_from_local_file( 760 path=sklearn_model_artifacts_path, 761 serialization_format=serialization_format, 762 skops_trusted_types=skops_trusted_types, 763 ) 764 765 766 # The `_apis_autologging_disabled` contains APIs which is incompatible with autologging, 767 # when user call these APIs, autolog is temporarily disabled. 
768 _apis_autologging_disabled = [ 769 "cross_validate", 770 "cross_val_predict", 771 "cross_val_score", 772 "learning_curve", 773 "permutation_test_score", 774 "validation_curve", 775 ] 776 777 778 class _AutologgingMetricsManager: 779 """ 780 This class is designed for holding information which is used by autologging metrics 781 It will hold information of: 782 (1) a map of "prediction result object id" to a tuple of dataset name(the dataset is 783 the one which generate the prediction result) and run_id. 784 Note: We need this map instead of setting the run_id into the "prediction result object" 785 because the object maybe a numpy array which does not support additional attribute 786 assignment. 787 (2) _log_post_training_metrics_enabled flag, in the following method scope: 788 `model.fit` and `model.score`, in order to avoid nested/duplicated autologging metric, when 789 run into these scopes, we need temporarily disable the metric autologging. 790 (3) _eval_dataset_info_map, it is a double level map: 791 `_eval_dataset_info_map[run_id][eval_dataset_var_name]` will get a list, each 792 element in the list is an id of "eval_dataset" instance. 793 This data structure is used for: 794 * generating unique dataset name key when autologging metric. For each eval dataset object, 795 if they have the same eval_dataset_var_name, but object ids are different, 796 then they will be assigned different name (via appending index to the 797 eval_dataset_var_name) when autologging. 798 (4) _metric_api_call_info, it is a double level map: 799 `_metric_api_call_info[run_id][metric_name]` will get a list of tuples, each tuple is: 800 (logged_metric_key, metric_call_command_string) 801 each call command string is like `metric_fn(arg1, arg2, ...)` 802 This data structure is used for: 803 * storing the call arguments dict for each metric call, we need log them into metric_info 804 artifact file. 805 806 Note: this class is not thread-safe. 
    Design rule for this class:
    Because this class instance is a global instance, in order to prevent memory leaks, it
    should only hold IDs and other small object references. Its internal data structures
    should avoid referencing user dataset variables or model variables.
    """

    def __init__(self):
        # id(prediction_result) -> (eval_dataset_name, run_id, model_id); entries are
        # removed by a `weakref.finalize` callback when the prediction result is GCed.
        self._pred_result_id_mapping = {}
        # run_id -> eval_dataset_name -> list of registered dataset object ids
        self._eval_dataset_info_map = defaultdict(lambda: defaultdict(list))
        # run_id -> metric_name -> list of (metric_key, call_command) tuples
        self._metric_api_call_info = defaultdict(lambda: defaultdict(list))
        # Global on/off switch toggled by `disable_log_post_training_metrics` scopes
        self._log_post_training_metrics_enabled = True
        # run_id -> whether the "metric_info.json" artifact needs to be regenerated
        self._metric_info_artifact_need_update = defaultdict(lambda: False)
        # id(model) -> MLflow logged-model id recorded via `record_model_id`
        self._model_id_mapping = {}

    def should_log_post_training_metrics(self):
        """
        Check whether we should run patching code for autologging post training metrics.
        This check should surround the whole patched code, due to the safe-guard check.
        See the following note.

        Note: It includes checking `_SklearnTrainingSession.is_active()`. This is a safe guard
        for the meta-estimator (e.g. GridSearchCV) case:
        when running GridSearchCV.fit, the nested `estimator.fit` will be called in parallel,
        but the _autolog_training_status is a global status without a thread-safe lock
        protecting it. This safe guard prevents code from running into this case.
        """
        return not _SklearnTrainingSession.is_active() and self._log_post_training_metrics_enabled

    def disable_log_post_training_metrics(self):
        """
        Return a context manager that temporarily disables post-training metric autologging
        and restores the previous enabled/disabled state on exit.
        """

        class LogPostTrainingMetricsDisabledScope:
            def __enter__(inner_self):
                # `inner_self` is the scope object; `self` is the outer metrics manager.
                inner_self.old_status = self._log_post_training_metrics_enabled
                self._log_post_training_metrics_enabled = False

            def __exit__(inner_self, exc_type, exc_val, exc_tb):
                self._log_post_training_metrics_enabled = inner_self.old_status

        return LogPostTrainingMetricsDisabledScope()

    @staticmethod
    def get_run_id_for_model(model):
        """Return the run_id previously attached to `model` by `register_model`, or None."""
        return getattr(model, "_mlflow_run_id", None)

    @staticmethod
    def is_metric_value_loggable(metric_value):
        """
        Check whether the specified `metric_value` is a numeric value which can be logged
        as an MLflow metric.
        """
        # bool is a subclass of int, so exclude it explicitly.
        return isinstance(metric_value, (int, float, np.number)) and not isinstance(
            metric_value, bool
        )

    def register_model(self, model, run_id):
        """
        In `patched_fit`, we need to register the model with the run_id used in `patched_fit`,
        so that in subsequent metric autologging, the metric will be logged into the registered
        run_id.
        """
        model._mlflow_run_id = run_id

    def record_model_id(self, model, model_id):
        """
        Record the id(model) -> model_id mapping so that we can log metrics to the
        model later.
        """
        self._model_id_mapping[id(model)] = model_id

    def get_model_id_for_model(self, model) -> str | None:
        """Return the logged-model id recorded for `model`, or None if not recorded."""
        return self._model_id_mapping.get(id(model))

    @staticmethod
    def gen_name_with_index(name, index):
        """
        Return `name` unchanged for index 0, otherwise `name` suffixed with `-{index + 1}`
        (so repeated names become name, name-2, name-3, ...).
        """
        assert index >= 0
        if index == 0:
            return name
        else:
            # Use '-' as the separator between name and index.
            # '-' is not a valid character in a Python variable name,
            # so it prevents name conflicts after appending the index.
            return f"{name}-{index + 1}"

    def register_prediction_input_dataset(self, model, eval_dataset):
        """
        Register prediction input dataset into eval_dataset_info_map, it will do:
        1. inspect eval dataset var name.
        2. check whether eval_dataset_info_map already registered this eval dataset.
           will check by object id.
        3. register eval dataset with id.
        4. return eval dataset name with index.

        Note: this method includes inspecting the argument variable name,
        so it should be called directly from the "patched method", to ensure it captures
        the correct argument variable name.
        """
        eval_dataset_name = _inspect_original_var_name(
            eval_dataset, fallback_name="unknown_dataset"
        )
        eval_dataset_id = id(eval_dataset)

        run_id = self.get_run_id_for_model(model)
        registered_dataset_list = self._eval_dataset_info_map[run_id][eval_dataset_name]

        # Linear scan by object id; the for/else assigns the next index if no match found.
        for i, id_i in enumerate(registered_dataset_list):
            if eval_dataset_id == id_i:
                index = i
                break
        else:
            index = len(registered_dataset_list)

        if index == len(registered_dataset_list):
            # register new eval dataset
            registered_dataset_list.append(eval_dataset_id)

        return self.gen_name_with_index(eval_dataset_name, index)

    def register_prediction_result(self, run_id, eval_dataset_name, predict_result, model_id=None):
        """
        Register the relationship
        id(prediction_result) --> (eval_dataset_name, run_id, model_id)
        into map `_pred_result_id_mapping`
        """
        value = (eval_dataset_name, run_id, model_id)
        prediction_result_id = id(predict_result)
        self._pred_result_id_mapping[prediction_result_id] = value

        def clean_id(id_):
            _AUTOLOGGING_METRICS_MANAGER._pred_result_id_mapping.pop(id_, None)

        # When the `predict_result` object being GCed, its ID may be reused, so register a finalizer
        # to clear the ID from the dict for preventing wrong ID mapping.
        weakref.finalize(predict_result, clean_id, prediction_result_id)

    @staticmethod
    def gen_metric_call_command(self_obj, metric_fn, *call_pos_args, **call_kwargs):
        """
        Generate metric function call command string like `metric_fn(arg1, arg2, ...)`
        Note: this method includes inspecting the argument variable name,
        so it should be called directly from the "patched method", to ensure it captures
        the correct argument variable name.

        Args:
            self_obj: If the metric_fn is a method of an instance (e.g. `model.score`),
                the `self_obj` represents the instance.
            metric_fn: metric function.
            call_pos_args: the positional arguments of the metric function call. If `metric_fn`
                is instance method, then the `call_pos_args` should exclude the first `self`
                argument.
            call_kwargs: the keyword arguments of the metric function call.
        """

        arg_list = []

        def arg_to_str(arg):
            if arg is None or np.isscalar(arg):
                if isinstance(arg, str) and len(arg) > 32:
                    # truncate too long string
                    return repr(arg[:32] + "...")
                return repr(arg)
            else:
                # dataset arguments or other non-scalar type argument
                return _inspect_original_var_name(arg, fallback_name=f"<{arg.__class__.__name__}>")

        param_sig = inspect.signature(metric_fn).parameters
        arg_names = list(param_sig.keys())

        if self_obj is not None:
            # If metric_fn is a method of an instance, e.g. `model.score`,
            # then the first argument is `self` which we need to exclude.
            arg_names.pop(0)

        if self_obj is not None:
            call_fn_name = f"{self_obj.__class__.__name__}.{metric_fn.__name__}"
        else:
            call_fn_name = metric_fn.__name__

        # Attach param signature key for positional param values
        for arg_name, arg in zip(arg_names, call_pos_args):
            arg_list.append(f"{arg_name}={arg_to_str(arg)}")

        for arg_name, arg in call_kwargs.items():
            arg_list.append(f"{arg_name}={arg_to_str(arg)}")

        arg_list_str = ", ".join(arg_list)

        return f"{call_fn_name}({arg_list_str})"

    def register_metric_api_call(self, run_id, metric_name, dataset_name, call_command):
        """
        This method will do:
        (1) Generate and return metric key, format is:
            {metric_name}[-{call_index}]_{eval_dataset_name}
            metric_name is generated by metric function name, if multiple calls on the same
            metric API happen, the following calls will be assigned with an increasing
            "call index".
        (2) Register the metric key with the "call command" information into
            `_AUTOLOGGING_METRICS_MANAGER`. See doc of `gen_metric_call_command` method for
            details of "call command".
        """

        call_cmd_list = self._metric_api_call_info[run_id][metric_name]

        index = len(call_cmd_list)
        metric_name_with_index = self.gen_name_with_index(metric_name, index)
        metric_key = f"{metric_name_with_index}_{dataset_name}"

        call_cmd_list.append((metric_key, call_command))

        # Set the flag to true, representing that the metric info in this run needs update.
        # Later when `log_eval_metric` is called, it will generate a new metric_info artifact
        # and overwrite the old artifact.
        self._metric_info_artifact_need_update[run_id] = True
        return metric_key

    def get_info_for_metric_api_call(self, call_pos_args, call_kwargs):
        """
        Given a metric api call (include the called metric function, and call arguments)
        Register the call information (arguments dict) into the `metric_api_call_arg_dict_list_map`
        and return a tuple of (run_id, eval_dataset_name, model_id)
        """
        call_arg_list = list(call_pos_args) + list(call_kwargs.values())

        dataset_id_list = self._pred_result_id_mapping.keys()

        # Note: for some metric APIs the arguments are not like `y_true`, `y_pred`
        # e.g.
        # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score
        # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_score.html#sklearn.metrics.silhouette_score
        for arg in call_arg_list:
            if arg is not None and not np.isscalar(arg) and id(arg) in dataset_id_list:
                dataset_name, run_id, model_id = self._pred_result_id_mapping[id(arg)]
                break
        else:
            # No argument matches a registered prediction result: nothing to log.
            return None, None, None

        return run_id, dataset_name, model_id

    def log_post_training_metric(self, run_id, key, value, model_id=None):
        """
        Log the metric into the specified mlflow run,
        and it will also update the metric_info artifact if needed.
        If model_id is not None, metrics are logged into the model as well.
        """
        # Note: in the case of logging the same metric key multiple times,
        # the newer value will overwrite the old value
        client = MlflowClient()
        client.log_metric(run_id=run_id, key=key, value=value, model_id=model_id)
        if self._metric_info_artifact_need_update[run_id]:
            call_commands_list = []
            for v in self._metric_api_call_info[run_id].values():
                call_commands_list.extend(v)

            call_commands_list.sort(key=lambda x: x[0])
            dict_to_log = OrderedDict(call_commands_list)
            client.log_dict(run_id=run_id, dictionary=dict_to_log, artifact_file="metric_info.json")
            self._metric_info_artifact_need_update[run_id] = False


# The global `_AutologgingMetricsManager` instance which holds information used in
# post-training metric autologging. See doc of class `_AutologgingMetricsManager` for details.
_AUTOLOGGING_METRICS_MANAGER = _AutologgingMetricsManager()


# Callables in `sklearn.metrics` that are not metric functions and must not be patched.
_metric_api_excluding_list = ["check_scoring", "get_scorer", "make_scorer", "get_scorer_names"]


def _get_metric_name_list():
    """
    Return metric function name list in `sklearn.metrics` module
    """
    from sklearn import metrics

    metric_list = []
    for metric_method_name in metrics.__all__:
        # excludes plot_* methods
        # exclude classes (e.g. metrics.ConfusionMatrixDisplay)
        metric_method = getattr(metrics, metric_method_name)
        if (
            metric_method_name not in _metric_api_excluding_list
            and not inspect.isclass(metric_method)
            and callable(metric_method)
            and not metric_method_name.startswith("plot_")
        ):
            metric_list.append(metric_method_name)
    return metric_list


def _patch_estimator_method_if_available(
    flavor_name, class_def, func_name, patched_fn, manage_run, extra_tags=None
):
    """
    Patch `class_def.func_name` with `patched_fn` via `safe_patch`, handling both plain
    methods/properties and scikit-learn delegated methods. Silently skips attributes that
    are absent or of an unsupported type.
    """
    if not hasattr(class_def, func_name):
        return

    original = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=False
    )
    # Retrieve raw attribute while bypassing the descriptor protocol
    raw_original_obj = gorilla.get_original_attribute(
        class_def, func_name, bypass_descriptor_protocol=True
    )
    if raw_original_obj == original and (callable(original) or isinstance(original, property)):
        # normal method or property decorated method
        safe_patch(
            flavor_name,
            class_def,
            func_name,
            patched_fn,
            manage_run=manage_run,
            extra_tags=extra_tags,
        )
    elif hasattr(raw_original_obj, "delegate_names") or hasattr(raw_original_obj, "check"):
        # sklearn delegated method: patch the underlying `fn` attribute of the delegator
        safe_patch(
            flavor_name,
            raw_original_obj,
            "fn",
            patched_fn,
            manage_run=manage_run,
            extra_tags=extra_tags,
        )
    else:
        # unsupported method type. skip patching
        pass
@autologging_integration(FLAVOR_NAME)
def autolog(
    log_input_examples=False,
    log_model_signatures=True,
    log_models=True,
    log_datasets=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    max_tuning_runs=5,
    log_post_training_metrics=True,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    registered_model_name=None,
    pos_label=None,
    extra_tags=None,
):
    """
    Enables (or disables) and configures autologging for scikit-learn estimators.

    **When is autologging performed?**
    Autologging is performed when you call:

    - ``estimator.fit()``
    - ``estimator.fit_predict()``
    - ``estimator.fit_transform()``

    **Logged information**
    **Parameters**
    - Parameters obtained by ``estimator.get_params(deep=True)``. Note that ``get_params``
      is called with ``deep=True``. This means when you fit a meta estimator that chains
      a series of estimators, the parameters of these child estimators are also logged.

    **Training metrics**
    - A training score obtained by ``estimator.score``. Note that the training score is
      computed using parameters given to ``fit()``.
    - Common metrics for classifier:

      - `precision score`_

      .. _precision score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_score.html

      - `recall score`_

      .. _recall score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.recall_score.html

      - `f1 score`_

      .. _f1 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html

      - `accuracy score`_

      .. _accuracy score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html

      If the classifier has method ``predict_proba``, we additionally log:

      - `log loss`_

      .. _log loss:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html

      - `roc auc score`_

      .. _roc auc score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html

    - Common metrics for regressor:

      - `mean squared error`_

      .. _mean squared error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html

      - root mean squared error

      - `mean absolute error`_

      .. _mean absolute error:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html

      - `r2 score`_

      .. _r2 score:
          https://scikit-learn.org/stable/modules/generated/sklearn.metrics.r2_score.html

    .. _post training metrics:

    **Post training metrics**
    When users call metric APIs after model training, MLflow tries to capture the metric API
    results and log them as MLflow metrics to the Run associated with the model. The following
    types of scikit-learn metric APIs are supported:

    - model.score
    - metric APIs defined in the `sklearn.metrics` module

    For post training metrics autologging, the metric key format is:
    "{metric_name}[-{call_index}]_{dataset_name}"

    - If the metric function is from `sklearn.metrics`, the MLflow "metric_name" is the
      metric function name. If the metric function is `model.score`, then "metric_name" is
      "{model_class_name}_score".
    - If multiple calls are made to the same scikit-learn metric API, each subsequent call
      adds a "call_index" (starting from 2) to the metric key.
    - MLflow uses the prediction input dataset variable name as the "dataset_name" in the
      metric key. The "prediction input dataset variable" refers to the variable which was
      used as the first argument of the associated `model.predict` or `model.score` call.
      Note: MLflow captures the "prediction input dataset" instance in the outermost call
      frame and fetches the variable name in the outermost call frame. If the "prediction
      input dataset" instance is an intermediate expression without a defined variable
      name, the dataset name is set to "unknown_dataset". If multiple "prediction input
      dataset" instances have the same variable name, then subsequent ones will append an
      index (starting from 2) to the inspected dataset name.

    **Limitations**
    - MLflow can only map the original prediction result object returned by a model
      prediction API (including predict / predict_proba / predict_log_proba / transform,
      but excluding fit_predict / fit_transform.) to an MLflow run.
      MLflow cannot find run information
      for other objects derived from a given prediction result (e.g. by copying or selecting
      a subset of the prediction result). scikit-learn metric APIs invoked on derived objects
      do not log metrics to MLflow.
    - Autologging must be enabled before scikit-learn metric APIs are imported from
      `sklearn.metrics`. Metric APIs imported before autologging is enabled do not log
      metrics to MLflow runs.
    - If a user defines a scorer which is not based on metric APIs in `sklearn.metrics`,
      then post training metric autologging for the scorer is invalid.

    **Tags**
    - An estimator class name (e.g. "LinearRegression").
    - A fully qualified estimator class name
      (e.g. "sklearn.linear_model._base.LinearRegression").

    **Artifacts**
    - An MLflow Model with the :py:mod:`mlflow.sklearn` flavor containing a fitted estimator
      (logged by :py:func:`mlflow.sklearn.log_model()`). The Model also contains the
      :py:mod:`mlflow.pyfunc` flavor when the scikit-learn estimator defines `predict()`.
    - For post training metrics API calls, a "metric_info.json" artifact is logged. This is a
      JSON object whose keys are MLflow post training metric names
      (see "Post training metrics" section for the key format) and whose values are the
      corresponding metric call commands that produced the metrics, e.g.
      ``accuracy_score(y_true=test_iris_y, y_pred=pred_iris_y, normalize=False)``.

    **How does autologging work for meta estimators?**
    When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
    ``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
    ``fit()`` calls.

    **Parameter search**
    In addition to recording the information discussed above, autologging for parameter
    search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
    with metrics for each set of explored parameters, as well as artifacts and parameters
    for the best model (if available).

    **Supported estimators**
    - All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
    - `Pipeline`_
    - Parameter search estimators (`GridSearchCV`_ and `RandomizedSearchCV`_)

    .. _sklearn.utils.all_estimators:
        https://scikit-learn.org/stable/modules/generated/sklearn.utils.all_estimators.html

    .. _Pipeline:
        https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html

    .. _GridSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html

    .. _RandomizedSearchCV:
        https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

    **Example**

    `See more examples <https://github.com/mlflow/mlflow/blob/master/examples/sklearn_autolog>`_

    .. code-block:: python

        from pprint import pprint
        import numpy as np
        from sklearn.linear_model import LinearRegression
        import mlflow
        from mlflow import MlflowClient


        def fetch_logged_data(run_id):
            client = MlflowClient()
            data = client.get_run(run_id).data
            tags = {k: v for k, v in data.tags.items() if not k.startswith("mlflow.")}
            artifacts = [f.path for f in client.list_artifacts(run_id, "model")]
            return data.params, data.metrics, tags, artifacts


        # enable autologging
        mlflow.sklearn.autolog()

        # prepare training data
        X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
        y = np.dot(X, np.array([1, 2])) + 3

        # train a model
        model = LinearRegression()
        with mlflow.start_run() as run:
            model.fit(X, y)

        # fetch logged data
        params, metrics, tags, artifacts = fetch_logged_data(run.info.run_id)

        pprint(params)
        # {'copy_X': 'True',
        #  'fit_intercept': 'True',
        #  'n_jobs': 'None',
        #  'normalize': 'False'}

        pprint(metrics)
        # {'training_score': 1.0,
        #  'training_mean_absolute_error': 2.220446049250313e-16,
        #  'training_mean_squared_error': 1.9721522630525295e-31,
        #  'training_r2_score': 1.0,
        #  'training_root_mean_squared_error': 4.440892098500626e-16}

        pprint(tags)
        # {'estimator_class': 'sklearn.linear_model._base.LinearRegression',
        #  'estimator_name': 'LinearRegression'}

        pprint(artifacts)
        # ['model/MLmodel', 'model/conda.yaml', 'model/model.pkl']

    Args:
        log_input_examples: If ``True``, input examples from training datasets are collected and
            logged along with scikit-learn model artifacts during training. If
            ``False``, input examples are not logged.
            Note: Input examples are MLflow model attributes
            and are only collected if ``log_models`` is also ``True``.
        log_model_signatures: If ``True``,
            :py:class:`ModelSignatures <mlflow.models.ModelSignature>`
            describing model inputs and outputs are collected and logged along
            with scikit-learn model artifacts during training. If ``False``,
            signatures are not logged.
            Note: Model signatures are MLflow model attributes
            and are only collected if ``log_models`` is also ``True``.
        log_models: If ``True``, trained models are logged as MLflow model artifacts.
            If ``False``, trained models are not logged.
            Input examples and model signatures, which are attributes of MLflow models,
            are also omitted when ``log_models`` is ``False``.
        log_datasets: If ``True``, train and validation dataset information is logged to MLflow
            Tracking if applicable. If ``False``, dataset information is not logged.
        disable: If ``True``, disables the scikit-learn autologging integration. If ``False``,
            enables the scikit-learn autologging integration.
        exclusive: If ``True``, autologged content is not logged to user-created fluent runs.
            If ``False``, autologged content is logged to the active fluent run,
            which may be user-created.
        disable_for_unsupported_versions: If ``True``, disable autologging for versions of
            scikit-learn that have not been tested against this version of the MLflow
            client or are incompatible.
        silent: If ``True``, suppress all event logs and warnings from MLflow during scikit-learn
            autologging. If ``False``, show all events and warnings during scikit-learn
            autologging.
        max_tuning_runs: The maximum number of child MLflow runs created for hyperparameter
            search estimators. To create child runs for the best `k` results from
            the search, set `max_tuning_runs` to `k`. The default value is to track
            the best 5 search parameter sets. If `max_tuning_runs=None`, then
            a child run is created for each search parameter set. Note: The best k
            results are based on ordering in `rank_test_score`. In the case of
            multi-metric evaluation with a custom scorer, the first scorer's
            `rank_test_score_<scorer_name>` will be used to select the best k
            results. To change the metric used for selecting the best k results, change
            the ordering of the dict passed as the `scoring` parameter for the estimator.
        log_post_training_metrics: If ``True``, post training metrics are logged. Defaults to
            ``True``. See the `post training metrics`_ section for more
            details.
        serialization_format: The format in which to serialize the model. This should be one of
            the following: "pickle", "cloudpickle" or "skops".
        registered_model_name: If given, each time a model is trained, it is registered as a
            new model version of the registered model with this name.
            The registered model is created if it does not already exist.
        pos_label: If given, used as the positive label to compute binary classification
            training metrics such as precision, recall, f1, etc. This parameter should
            only be set for binary classification model. If used for multi-label model,
            the training metrics calculation will fail and the training metrics won't
            be logged. If used for regression model, the parameter will be ignored.
        extra_tags: A dictionary of extra tags to set on each managed run created by autologging.
    """
    # `registered_model_name` is not forwarded here: `_autolog` reads it from the
    # autologging config stored by the `@autologging_integration` decorator.
    _autolog(
        flavor_name=FLAVOR_NAME,
        log_input_examples=log_input_examples,
        log_model_signatures=log_model_signatures,
        log_models=log_models,
        log_datasets=log_datasets,
        disable=disable,
        exclusive=exclusive,
        disable_for_unsupported_versions=disable_for_unsupported_versions,
        silent=silent,
        max_tuning_runs=max_tuning_runs,
        log_post_training_metrics=log_post_training_metrics,
        serialization_format=serialization_format,
        pos_label=pos_label,
        extra_tags=extra_tags,
    )


def _autolog(
    flavor_name=FLAVOR_NAME,
    log_input_examples=False,
    log_model_signatures=True,
    log_models=True,
    log_datasets=True,
    disable=False,
    exclusive=False,
    disable_for_unsupported_versions=False,
    silent=False,
    max_tuning_runs=5,
    log_post_training_metrics=True,
    serialization_format=SERIALIZATION_FORMAT_CLOUDPICKLE,
    pos_label=None,
    extra_tags=None,
):
    """
    Internal autologging function for scikit-learn models.

    Args:
        flavor_name: A string value. Enable a ``mlflow.sklearn`` autologging routine
            for a flavor. By default it enables autologging for original
            scikit-learn models, as ``mlflow.sklearn.autolog()`` does. If
            the argument is `xgboost`, autologging for XGBoost scikit-learn
            models is enabled.
    """
    import pandas as pd
    import sklearn.metrics
    import sklearn.model_selection

    from mlflow.models import infer_signature
    from mlflow.sklearn.utils import (
        _TRAINING_PREFIX,
        _create_child_runs_for_parameter_search,
        _gen_lightgbm_sklearn_estimators_to_patch,
        _gen_xgboost_sklearn_estimators_to_patch,
        _get_estimator_info_tags,
        _get_X_y_and_sample_weight,
        _is_parameter_search_estimator,
        _log_estimator_content,
        _log_parameter_search_results_as_artifact,
    )
    from mlflow.tracking.context import registry as context_registry

    # Validate eagerly so a bad value fails at autolog() time, not at fit() time.
    if max_tuning_runs is not None and max_tuning_runs < 0:
        raise MlflowException(
            message=f"`max_tuning_runs` must be non-negative, instead got {max_tuning_runs}.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    def fit_mlflow_xgboost_and_lightgbm(original, self, *args, **kwargs):
        """
        Autologging function for XGBoost and LightGBM scikit-learn models
        """
        # Obtain a copy of a model input example from the training dataset prior to model training
        # for subsequent use during model logging, ensuring that the input example and inferred
        # model signature do not include any mutations from model training
        input_example_exc = None
        try:
            input_example = deepcopy(
                _get_X_y_and_sample_weight(self.fit, args, kwargs)[0][:INPUT_EXAMPLE_SAMPLE_ROWS]
            )
        except Exception as e:
            # Defer the failure: only raise if the input example is actually needed later.
            input_example_exc = e

        def get_input_example():
            if input_example_exc is not None:
                raise input_example_exc
            else:
                return input_example

        # parameter, metric, and non-model artifact logging are done in
        # `train()` in `mlflow.xgboost.autolog()` and `mlflow.lightgbm.autolog()`
        fit_output = original(self, *args, **kwargs)
        # log models after training
        if log_models:
            input_example, signature = resolve_input_example_and_signature(
                get_input_example,
                lambda input_example: infer_signature(
                    input_example,
                    # Copy the input example so that it is not mutated by the call to
                    # predict() prior to signature inference
                    self.predict(deepcopy(input_example)),
                ),
                log_input_examples,
                log_model_signatures,
                _logger,
            )
            log_model_func = (
                mlflow.xgboost.log_model
                if flavor_name == mlflow.xgboost.FLAVOR_NAME
                else mlflow.lightgbm.log_model
            )
            registered_model_name = get_autologging_config(
                flavor_name, "registered_model_name", None
            )
            if flavor_name == mlflow.xgboost.FLAVOR_NAME:
                # XGBoost additionally supports a configurable on-disk model format
                model_format = get_autologging_config(flavor_name, "model_format", "ubj")
                model_info = log_model_func(
                    self,
                    "model",
                    signature=signature,
                    input_example=input_example,
                    registered_model_name=registered_model_name,
                    model_format=model_format,
                )
            else:
                model_info = log_model_func(
                    self,
                    "model",
                    signature=signature,
                    input_example=input_example,
                    registered_model_name=registered_model_name,
                )
            _AUTOLOGGING_METRICS_MANAGER.record_model_id(self, model_info.model_id)
        return fit_output

    def fit_mlflow(original, self, *args, **kwargs):
        """
        Autologging function that performs model training by executing the training method
        referred to by `func_name` on the instance of `clazz` referred to by `self` & records
        MLflow parameters, metrics, tags, and artifacts to a corresponding MLflow Run.
        """
        # Obtain a copy of the training dataset prior to model training for subsequent
        # use during model logging & input example extraction, ensuring that we don't
        # attempt to infer input examples on data that was mutated during training
        (X, y_true, sample_weight) = _get_X_y_and_sample_weight(self.fit, args, kwargs)
        autologging_client = MlflowAutologgingQueueingClient()
        _log_pretraining_metadata(autologging_client, self, X, y_true)
        # Flush params asynchronously so the fit itself is not blocked by logging;
        # completion is awaited after training finishes.
        params_logging_future = autologging_client.flush(synchronous=False)
        fit_output = original(self, *args, **kwargs)
        _log_posttraining_metadata(autologging_client, self, X, y_true, sample_weight)
        autologging_client.flush(synchronous=True)
        params_logging_future.await_completion()
        return fit_output

    def _log_pretraining_metadata(autologging_client, estimator, X, y):
        """
        Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        Args:
            autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
                efficiently logging run data to MLflow Tracking.
            estimator: The scikit-learn estimator for which to log metadata.
            X: The training dataset samples passed to the ``estimator.fit()`` function.
            y: The training dataset labels passed to the ``estimator.fit()`` function.
        """
        # Deep parameter logging includes parameters from children of a given
        # estimator. For some meta estimators (e.g., pipelines), recording
        # these parameters is desirable. For parameter search estimators,
        # however, child estimators act as seeds for the parameter search
        # process; accordingly, we avoid logging initial, untuned parameters
        # for these seed estimators.
        should_log_params_deeply = not _is_parameter_search_estimator(estimator)
        run_id = mlflow.active_run().info.run_id
        # NOTE(review): `mlflow.active_run()` is re-queried below instead of reusing
        # `run_id` — appears equivalent since the active run cannot change here; confirm.
        autologging_client.log_params(
            run_id=mlflow.active_run().info.run_id,
            params=estimator.get_params(deep=should_log_params_deeply),
        )
        autologging_client.set_tags(
            run_id=run_id,
            tags=_get_estimator_info_tags(estimator),
        )

        if log_datasets:
            # Dataset logging is best-effort: failures are downgraded to a warning.
            try:
                context_tags = context_registry.resolve_tags()
                source = CodeDatasetSource(context_tags)

                if dataset := _create_dataset(X, source, y):
                    tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="train")]
                    dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags)

                    autologging_client.log_inputs(
                        run_id=mlflow.active_run().info.run_id, datasets=[dataset_input]
                    )
            except Exception as e:
                _logger.warning(
                    "Failed to log training dataset information to MLflow Tracking. Reason: %s", e
                )

    def _log_posttraining_metadata(autologging_client, estimator, X, y, sample_weight):
        """
        Records metadata for a scikit-learn estimator after training has completed.
        This is intended to be invoked within a patched scikit-learn training routine
        (e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
        MLflow run that can be referenced via the fluent Tracking API.

        Args:
            autologging_client: An instance of `MlflowAutologgingQueueingClient` used for
                efficiently logging run data to MLflow Tracking.
            estimator: The scikit-learn estimator for which to log metadata.
            X: The training dataset samples passed to the ``estimator.fit()`` function.
            y: The training dataset labels passed to the ``estimator.fit()`` function.
            sample_weight: Sample weights passed to the ``estimator.fit()`` function.
        """
1630 """ 1631 # Fetch an input example using the first several rows of the array-like 1632 # training data supplied to the training routine (e.g., `fit()`). Copy the 1633 # example to avoid mutation during subsequent metric computations 1634 input_example_exc = None 1635 try: 1636 input_example = deepcopy(X[:INPUT_EXAMPLE_SAMPLE_ROWS]) 1637 except Exception as e: 1638 input_example_exc = e 1639 1640 def get_input_example(): 1641 if input_example_exc is not None: 1642 raise input_example_exc 1643 else: 1644 return input_example 1645 1646 def infer_model_signature(input_example): 1647 if hasattr(estimator, "predict"): 1648 # Copy the input example so that it is not mutated by the call to 1649 # predict() prior to signature inference 1650 model_output = estimator.predict(deepcopy(input_example)) 1651 elif hasattr(estimator, "transform"): 1652 model_output = estimator.transform(deepcopy(input_example)) 1653 else: 1654 raise Exception( 1655 "the trained model does not have a `predict` or `transform` " 1656 "function, which is required in order to infer the signature" 1657 ) 1658 1659 return infer_signature(input_example, model_output) 1660 1661 def _log_model_with_except_handling(*args, **kwargs): 1662 try: 1663 return log_model(*args, **kwargs) 1664 except _SklearnCustomModelPicklingError as e: 1665 _logger.warning(str(e)) 1666 1667 model_id = None 1668 if log_models: 1669 # Will only resolve `input_example` and `signature` if `log_models` is `True`. 
1670 input_example, signature = resolve_input_example_and_signature( 1671 get_input_example, 1672 infer_model_signature, 1673 log_input_examples, 1674 log_model_signatures, 1675 _logger, 1676 ) 1677 registered_model_name = get_autologging_config( 1678 FLAVOR_NAME, "registered_model_name", None 1679 ) 1680 should_log_params_deeply = not _is_parameter_search_estimator(estimator) 1681 params = estimator.get_params(deep=should_log_params_deeply) 1682 if hasattr(estimator, "best_params_"): 1683 params |= { 1684 f"best_{param_name}": param_value 1685 for param_name, param_value in estimator.best_params_.items() 1686 } 1687 if logged_model := _log_model_with_except_handling( 1688 estimator, 1689 name="model", 1690 signature=signature, 1691 input_example=input_example, 1692 serialization_format=serialization_format, 1693 registered_model_name=registered_model_name, 1694 params=params, 1695 ): 1696 model_id = logged_model.model_id 1697 _AUTOLOGGING_METRICS_MANAGER.record_model_id(estimator, logged_model.model_id) 1698 1699 # log common metrics and artifacts for estimators (classifier, regressor) 1700 context_tags = context_registry.resolve_tags() 1701 source = CodeDatasetSource(context_tags) 1702 try: 1703 dataset = _create_dataset(X, source, y) 1704 except Exception: 1705 _logger.debug("Failed to create dataset for logging.", exc_info=True) 1706 dataset = None 1707 logged_metrics = _log_estimator_content( 1708 autologging_client=autologging_client, 1709 estimator=estimator, 1710 prefix=_TRAINING_PREFIX, 1711 run_id=mlflow.active_run().info.run_id, 1712 X=X, 1713 y_true=y, 1714 sample_weight=sample_weight, 1715 pos_label=pos_label, 1716 dataset=dataset, 1717 model_id=model_id, 1718 ) 1719 if y is None and not logged_metrics: 1720 _logger.warning( 1721 "Training metrics will not be recorded because training labels were not specified." 1722 " To automatically record training metrics, provide training labels as inputs to" 1723 " the model training function." 
1724 ) 1725 1726 best_estimator_model_id = None 1727 best_estimator_params = None 1728 if _is_parameter_search_estimator(estimator): 1729 if hasattr(estimator, "best_estimator_") and log_models: 1730 best_estimator_params = estimator.best_estimator_.get_params(deep=True) 1731 if model_info := _log_model_with_except_handling( 1732 estimator.best_estimator_, 1733 name="best_estimator", 1734 signature=signature, 1735 input_example=input_example, 1736 serialization_format=serialization_format, 1737 params=best_estimator_params, 1738 ): 1739 best_estimator_model_id = model_info.model_id 1740 1741 if hasattr(estimator, "best_score_"): 1742 autologging_client.log_metrics( 1743 run_id=mlflow.active_run().info.run_id, 1744 metrics={"best_cv_score": estimator.best_score_}, 1745 dataset=dataset, 1746 model_id=model_id, 1747 ) 1748 1749 if hasattr(estimator, "best_params_"): 1750 best_params = { 1751 f"best_{param_name}": param_value 1752 for param_name, param_value in estimator.best_params_.items() 1753 } 1754 autologging_client.log_params( 1755 run_id=mlflow.active_run().info.run_id, 1756 params=best_params, 1757 ) 1758 1759 if hasattr(estimator, "cv_results_"): 1760 try: 1761 # Fetch environment-specific tags (e.g., user and source) to ensure that lineage 1762 # information is consistent with the parent run 1763 child_tags = context_registry.resolve_tags() 1764 child_tags.update({MLFLOW_AUTOLOGGING: flavor_name}) 1765 _create_child_runs_for_parameter_search( 1766 autologging_client=autologging_client, 1767 cv_estimator=estimator, 1768 parent_run=mlflow.active_run(), 1769 max_tuning_runs=max_tuning_runs, 1770 child_tags=child_tags, 1771 dataset=dataset, 1772 best_estimator_params=best_estimator_params, 1773 best_estimator_model_id=best_estimator_model_id, 1774 ) 1775 except Exception as e: 1776 _logger.warning( 1777 "Encountered exception during creation of child runs for parameter search." 1778 f" Child runs may be missing. 
Exception: {e}" 1779 ) 1780 1781 try: 1782 cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_) 1783 _log_parameter_search_results_as_artifact( 1784 cv_results_df, mlflow.active_run().info.run_id 1785 ) 1786 except Exception as e: 1787 _logger.warning( 1788 f"Failed to log parameter search results as an artifact. Exception: {e}" 1789 ) 1790 1791 def patched_fit(fit_impl, allow_children_patch, original, self, *args, **kwargs): 1792 """ 1793 Autologging patch function to be applied to a sklearn model class that defines a `fit` 1794 method and inherits from `BaseEstimator` (thereby defining the `get_params()` method) 1795 1796 Args: 1797 fit_impl: The patched fit function implementation, the function should be defined as 1798 `fit_mlflow(original, self, *args, **kwargs)`, the `original` argument 1799 refers to the original `EstimatorClass.fit` method, the `self` argument 1800 refers to the estimator instance being patched, the `*args` and 1801 `**kwargs` are arguments passed to the original fit method. 1802 allow_children_patch: Whether to allow children sklearn session logging or not. 1803 original: the original `EstimatorClass.fit` method to be patched. 1804 self: the estimator instance being patched. 1805 args: positional arguments to be passed to the original fit method. 1806 kwargs: keyword arguments to be passed to the original fit method. 1807 """ 1808 should_log_post_training_metrics = ( 1809 log_post_training_metrics 1810 and _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() 1811 ) 1812 1813 with _SklearnTrainingSession(estimator=self, allow_children=allow_children_patch) as t: 1814 if t.should_log(): 1815 # In `fit_mlflow` call, it will also call metric API for computing training metrics 1816 # so we need temporarily disable the post_training_metrics patching. 
1817 with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics(): 1818 result = fit_impl(original, self, *args, **kwargs) 1819 if should_log_post_training_metrics: 1820 _AUTOLOGGING_METRICS_MANAGER.register_model( 1821 self, mlflow.active_run().info.run_id 1822 ) 1823 return result 1824 else: 1825 return original(self, *args, **kwargs) 1826 1827 def patched_predict(original, self, *args, **kwargs): 1828 """ 1829 In `patched_predict`, register the prediction result instance with the run id and 1830 eval dataset name. e.g. 1831 ``` 1832 prediction_result = model_1.predict(eval_X) 1833 ``` 1834 then we need register the following relationship into the `_AUTOLOGGING_METRICS_MANAGER`: 1835 id(prediction_result) --> (eval_dataset_name, run_id) 1836 1837 Note: we cannot set additional attributes "eval_dataset_name" and "run_id" into 1838 the prediction_result object, because certain dataset type like numpy does not support 1839 additional attribute assignment. 1840 """ 1841 run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self) 1842 if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id: 1843 # Avoid nested patch when nested inference calls happens. 
1844 with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics(): 1845 predict_result = original(self, *args, **kwargs) 1846 eval_dataset = get_instance_method_first_arg_value(original, args, kwargs) 1847 eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset( 1848 self, eval_dataset 1849 ) 1850 _AUTOLOGGING_METRICS_MANAGER.register_prediction_result( 1851 run_id, 1852 eval_dataset_name, 1853 predict_result, 1854 model_id=_AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self), 1855 ) 1856 if log_datasets: 1857 try: 1858 context_tags = context_registry.resolve_tags() 1859 source = CodeDatasetSource(context_tags) 1860 1861 # log the dataset 1862 if dataset := _create_dataset(eval_dataset, source): 1863 tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="eval")] 1864 dataset_input = DatasetInput(dataset=dataset._to_mlflow_entity(), tags=tags) 1865 1866 # log the dataset 1867 client = mlflow.MlflowClient() 1868 client.log_inputs(run_id=run_id, datasets=[dataset_input]) 1869 except Exception as e: 1870 _logger.warning( 1871 "Failed to log evaluation dataset information to " 1872 "MLflow Tracking. 
Reason: %s", 1873 e, 1874 ) 1875 return predict_result 1876 else: 1877 return original(self, *args, **kwargs) 1878 1879 def patched_metric_api(original, *args, **kwargs): 1880 if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics(): 1881 # one metric api may call another metric api, 1882 # to avoid this, call disable_log_post_training_metrics to avoid nested patch 1883 with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics(): 1884 metric = original(*args, **kwargs) 1885 1886 if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(metric): 1887 metric_name = original.__name__ 1888 call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command( 1889 None, original, *args, **kwargs 1890 ) 1891 1892 (run_id, dataset_name, model_id) = ( 1893 _AUTOLOGGING_METRICS_MANAGER.get_info_for_metric_api_call(args, kwargs) 1894 ) 1895 if run_id and dataset_name: 1896 metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call( 1897 run_id, metric_name, dataset_name, call_command 1898 ) 1899 _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric( 1900 run_id, metric_key, metric, model_id=model_id 1901 ) 1902 1903 return metric 1904 else: 1905 return original(*args, **kwargs) 1906 1907 # we need patch model.score method because: 1908 # some model.score() implementation won't call metric APIs in `sklearn.metrics` 1909 # e.g. 1910 # https://github.com/scikit-learn/scikit-learn/blob/82df48934eba1df9a1ed3be98aaace8eada59e6e/sklearn/covariance/_empirical_covariance.py#L220 1911 def patched_model_score(original, self, *args, **kwargs): 1912 run_id = _AUTOLOGGING_METRICS_MANAGER.get_run_id_for_model(self) 1913 if _AUTOLOGGING_METRICS_MANAGER.should_log_post_training_metrics() and run_id: 1914 # `model.score` may call metric APIs internally, in order to prevent nested metric call 1915 # being logged, temporarily disable post_training_metrics patching. 
1916 with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics(): 1917 score_value = original(self, *args, **kwargs) 1918 1919 if _AUTOLOGGING_METRICS_MANAGER.is_metric_value_loggable(score_value): 1920 metric_name = f"{self.__class__.__name__}_score" 1921 call_command = _AUTOLOGGING_METRICS_MANAGER.gen_metric_call_command( 1922 self, original, *args, **kwargs 1923 ) 1924 1925 eval_dataset = get_instance_method_first_arg_value(original, args, kwargs) 1926 eval_dataset_name = _AUTOLOGGING_METRICS_MANAGER.register_prediction_input_dataset( 1927 self, eval_dataset 1928 ) 1929 metric_key = _AUTOLOGGING_METRICS_MANAGER.register_metric_api_call( 1930 run_id, metric_name, eval_dataset_name, call_command 1931 ) 1932 model_id = _AUTOLOGGING_METRICS_MANAGER.get_model_id_for_model(self) 1933 _AUTOLOGGING_METRICS_MANAGER.log_post_training_metric( 1934 run_id, metric_key, score_value, model_id=model_id 1935 ) 1936 1937 return score_value 1938 else: 1939 return original(self, *args, **kwargs) 1940 1941 def _apply_sklearn_descriptor_unbound_method_call_fix(): 1942 import sklearn 1943 1944 if Version(sklearn.__version__) <= Version("0.24.2"): 1945 import sklearn.utils.metaestimators 1946 1947 if not hasattr(sklearn.utils.metaestimators, "_IffHasAttrDescriptor"): 1948 return 1949 1950 def patched_IffHasAttrDescriptor__get__(self, obj, type=None): 1951 """ 1952 For sklearn version <= 0.24.2, `_IffHasAttrDescriptor.__get__` method does not 1953 support unbound method call. 1954 See https://github.com/scikit-learn/scikit-learn/issues/20614 1955 This patched function is for hot patch. 1956 """ 1957 1958 # raise an AttributeError if the attribute is not present on the object 1959 if obj is not None: 1960 # delegate only on instances, not the classes. 1961 # this is to allow access to the docstrings. 
1962 for delegate_name in self.delegate_names: 1963 try: 1964 delegate = sklearn.utils.metaestimators.attrgetter(delegate_name)(obj) 1965 except AttributeError: 1966 continue 1967 else: 1968 getattr(delegate, self.attribute_name) 1969 break 1970 else: 1971 sklearn.utils.metaestimators.attrgetter(self.delegate_names[-1])(obj) 1972 1973 def out(*args, **kwargs): 1974 return self.fn(obj, *args, **kwargs) 1975 1976 else: 1977 # This makes it possible to use the decorated method as an unbound method, 1978 # for instance when monkeypatching. 1979 def out(*args, **kwargs): 1980 return self.fn(*args, **kwargs) 1981 1982 # update the docstring of the returned function 1983 functools.update_wrapper(out, self.fn) 1984 return out 1985 1986 update_wrapper_extended( 1987 patched_IffHasAttrDescriptor__get__, 1988 sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__, 1989 ) 1990 1991 sklearn.utils.metaestimators._IffHasAttrDescriptor.__get__ = ( 1992 patched_IffHasAttrDescriptor__get__ 1993 ) 1994 1995 _apply_sklearn_descriptor_unbound_method_call_fix() 1996 1997 if flavor_name == mlflow.xgboost.FLAVOR_NAME: 1998 estimators_to_patch = _gen_xgboost_sklearn_estimators_to_patch() 1999 patched_fit_impl = fit_mlflow_xgboost_and_lightgbm 2000 allow_children_patch = True 2001 elif flavor_name == mlflow.lightgbm.FLAVOR_NAME: 2002 estimators_to_patch = _gen_lightgbm_sklearn_estimators_to_patch() 2003 patched_fit_impl = fit_mlflow_xgboost_and_lightgbm 2004 allow_children_patch = True 2005 else: 2006 estimators_to_patch = _gen_estimators_to_patch() 2007 patched_fit_impl = fit_mlflow 2008 allow_children_patch = False 2009 2010 for class_def in estimators_to_patch: 2011 # Patch fitting methods 2012 for func_name in ["fit", "fit_transform", "fit_predict"]: 2013 _patch_estimator_method_if_available( 2014 flavor_name, 2015 class_def, 2016 func_name, 2017 functools.partial(patched_fit, patched_fit_impl, allow_children_patch), 2018 manage_run=True, 2019 extra_tags=extra_tags, 2020 ) 2021 2022 
# Patch inference methods 2023 for func_name in ["predict", "predict_proba", "transform", "predict_log_proba"]: 2024 _patch_estimator_method_if_available( 2025 flavor_name, 2026 class_def, 2027 func_name, 2028 patched_predict, 2029 manage_run=False, 2030 ) 2031 2032 # Patch scoring methods 2033 _patch_estimator_method_if_available( 2034 flavor_name, 2035 class_def, 2036 "score", 2037 patched_model_score, 2038 manage_run=False, 2039 extra_tags=extra_tags, 2040 ) 2041 2042 if log_post_training_metrics: 2043 for metric_name in _get_metric_name_list(): 2044 safe_patch( 2045 flavor_name, sklearn.metrics, metric_name, patched_metric_api, manage_run=False 2046 ) 2047 2048 # `sklearn.metrics.SCORERS` was removed in scikit-learn 1.3 2049 if hasattr(sklearn.metrics, "get_scorer_names"): 2050 for scoring in sklearn.metrics.get_scorer_names(): 2051 scorer = sklearn.metrics.get_scorer(scoring) 2052 safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False) 2053 else: 2054 for scorer in sklearn.metrics.SCORERS.values(): 2055 safe_patch(flavor_name, scorer, "_score_func", patched_metric_api, manage_run=False) 2056 2057 def patched_fn_with_autolog_disabled(original, *args, **kwargs): 2058 with disable_autologging(): 2059 return original(*args, **kwargs) 2060 2061 for disable_autolog_func_name in _apis_autologging_disabled: 2062 safe_patch( 2063 flavor_name, 2064 sklearn.model_selection, 2065 disable_autolog_func_name, 2066 patched_fn_with_autolog_disabled, 2067 manage_run=False, 2068 ) 2069 2070 def _create_dataset(X, source, y=None, dataset_name=None): 2071 # create a dataset 2072 from scipy.sparse import issparse 2073 2074 if isinstance(X, pd.DataFrame): 2075 dataset = from_pandas(df=X, source=source) 2076 elif issparse(X): 2077 arr_X = X.toarray() 2078 if y is not None: 2079 dataset = from_numpy( 2080 features=arr_X, 2081 targets=y.toarray() if issparse(y) else y, 2082 source=source, 2083 name=dataset_name, 2084 ) 2085 else: 2086 dataset = 
from_numpy(features=arr_X, source=source, name=dataset_name) 2087 elif isinstance(X, np.ndarray): 2088 if y is not None: 2089 dataset = from_numpy(features=X, targets=y, source=source, name=dataset_name) 2090 else: 2091 dataset = from_numpy(features=X, source=source, name=dataset_name) 2092 elif is_polars_dataframe(X): 2093 from mlflow.data.polars_dataset import from_polars 2094 2095 dataset = from_polars(df=X, source=source, name=dataset_name) 2096 else: 2097 _logger.warning("Unrecognized dataset type %s. Dataset logging skipped.", type(X)) 2098 return None 2099 return dataset