_lightning_autolog.py
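# Typical usage, for orientation (a minimal sketch, not part of this module): autologging
# is enabled via `mlflow.pytorch.autolog()`, which patches `Trainer.fit` with the
# `patched_fit` function defined below. `MyLightningModule` and `train_loader` are
# placeholders for a user-defined LightningModule and dataloader.
#
#   import mlflow
#   import pytorch_lightning as pl
#
#   mlflow.pytorch.autolog()
#   trainer = pl.Trainer(max_epochs=5)
#   with mlflow.start_run():
#       trainer.fit(MyLightningModule(), train_loader)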
import functools
import logging
import os
import tempfile
import warnings

import torch
from packaging.version import Version

import mlflow.pytorch
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
from mlflow.models import infer_signature
from mlflow.tracking.fluent import _initialize_logged_model
from mlflow.utils import gorilla
from mlflow.utils.autologging_utils import (
    BatchMetricsLogger,
    ExceptionSafeAbstractClass,
    MlflowAutologgingQueueingClient,
    disable_autologging,
    get_autologging_config,
)
from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase

logging.basicConfig(level=logging.ERROR)
MIN_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["minimum"])
MAX_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["maximum"])

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

# The following are the downsides of using PyTorch Lightning's built-in MlflowLogger.
# 1. MlflowLogger doesn't provide a mechanism to store an entire model into mlflow;
#    only the model checkpoint is saved.
# 2. The `mlflow.pytorch` library is used to store the model into mlflow, and it expects
#    an `mlflow` object to be instantiated. With MlflowLogger, run management is completely
#    controlled by that class, so the mlflow object would need to be reinstantiated by
#    setting the tracking URI, experiment_id, and run_id, which may lead to a race condition.
# TODO: Replace __MlflowPLCallback with PyTorch Lightning's built-in MlflowLogger
#       once the above-mentioned issues have been addressed.

_logger = logging.getLogger(__name__)

_pl_version = Version(pl.__version__)
if _pl_version < Version("1.5.0"):
    from pytorch_lightning.core.memory import ModelSummary
else:
    from pytorch_lightning.utilities.model_summary import ModelSummary


def _get_optimizer_name(optimizer):
    """
    In pytorch-lightning 1.1.0, `LightningOptimizer` was introduced:
    https://github.com/PyTorchLightning/pytorch-lightning/pull/4658

    If a user sets `enable_pl_optimizer` to True when instantiating a `Trainer` object,
    each optimizer will be wrapped by `LightningOptimizer`:
    https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.html
    #pytorch_lightning.trainer.trainer.Trainer.params.enable_pl_optimizer
    """
    if Version(pl.__version__) < Version("1.1.0"):
        return optimizer.__class__.__name__
    else:
        from pytorch_lightning.core.optimizer import LightningOptimizer

        return (
            optimizer._optimizer.__class__.__name__
            if isinstance(optimizer, LightningOptimizer)
            else optimizer.__class__.__name__
        )


_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV = "_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR"
_INPUT_OUTPUT_TENSORS_FILENAME = "input_output_tensors.pkl"


class __MlflowPLCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
    """
    Callback for auto-logging metrics and parameters.
    """
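
    # NB: This callback is attached automatically by `patched_fit` (defined later in this
    # module) when autologging is enabled; users do not need to pass it to
    # `Trainer(callbacks=[...])` themselves.
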
81 """ 82 83 def __init__( 84 self, 85 client, 86 metrics_logger, 87 run_id, 88 log_models, 89 log_every_n_epoch, 90 log_every_n_step, 91 log_model_signatures, 92 ): 93 if log_every_n_step and _pl_version < Version("1.1.0"): 94 raise MlflowException( 95 "log_every_n_step is only supported for PyTorch-Lightning >= 1.1.0" 96 ) 97 self.early_stopping = False 98 self.client = client 99 self.metrics_logger = metrics_logger 100 self.run_id = run_id 101 self.log_models = log_models 102 self.log_every_n_epoch = log_every_n_epoch 103 self.log_every_n_step = log_every_n_step 104 self._global_steps_per_training_step = 1 105 # Sets for tracking which metrics are logged on steps and which are logged on epochs 106 self._step_metrics = set() 107 self._epoch_metrics = set() 108 self.log_model_signatures = log_model_signatures 109 self._model_forward_patch = None 110 self._first_batch_checked = False 111 112 def _log_metrics(self, trainer, step, metric_items): 113 # pytorch-lightning runs a few steps of validation in the beginning of training 114 # as a sanity check to catch bugs without having to wait for the training routine 115 # to complete. During this check, we should skip logging metrics. 116 # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps 117 sanity_checking = ( 118 # `running_sanity_check` has been renamed to `sanity_checking`: 119 # https://github.com/PyTorchLightning/pytorch-lightning/pull/9209 120 trainer.sanity_checking 121 if Version(pl.__version__) > Version("1.4.5") 122 else trainer.running_sanity_check 123 ) 124 if sanity_checking: 125 return 126 127 # Cast metric value as float before passing into logger. 128 metrics = {x[0]: float(x[1]) for x in metric_items} 129 self.metrics_logger.record_metrics(metrics, step) 130 131 def _log_epoch_metrics(self, trainer, pl_module): 132 # `trainer.callback_metrics` contains both training and validation metrics 133 # and includes metrics logged on steps and epochs. 134 # If we have logged any metrics on a step basis in mlflow, we exclude these from the 135 # epoch level metrics to prevent mixing epoch and step based values. 136 metric_items = [ 137 (name, val) 138 for (name, val) in trainer.callback_metrics.items() 139 if name not in self._step_metrics 140 ] 141 # Record which metrics are logged on epochs, so we don't try to log these on steps 142 self._epoch_metrics.update(name for (name, _) in metric_items) 143 if (pl_module.current_epoch + 1) % self.log_every_n_epoch == 0: 144 self._log_metrics(trainer, pl_module.current_epoch, metric_items) 145 146 _pl_version = Version(pl.__version__) 147 148 # In pytorch-lightning >= 1.4.0, validation is run inside the training epoch and 149 # `trainer.callback_metrics` contains both training and validation metrics of the 150 # current training epoch when `on_train_epoch_end` is called: 151 # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357 152 if _pl_version >= Version("1.4.0dev"): 153 154 @rank_zero_only 155 def on_train_epoch_end(self, trainer, pl_module, *args): 156 self._log_epoch_metrics(trainer, pl_module) 157 158 # In pytorch-lightning >= 1.2.0, logging metrics in `on_epoch_end` results in duplicate 159 # metrics records because `on_epoch_end` is called after both train and validation 160 # epochs (related PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/5986) 161 # As a workaround, use `on_train_epoch_end` and `on_validation_epoch_end` instead 162 # in pytorch-lightning >= 1.2.0. 
    elif _pl_version >= Version("1.2.0"):
        # NB: Override `on_train_epoch_end` with an additional `*args` parameter for
        # compatibility with versions of pytorch-lightning <= 1.2.0, which required an
        # `outputs` argument that was not used and is no longer defined in
        # pytorch-lightning >= 1.3.0

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """
            Log loss and other metric values after each training epoch.

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
                args: additional positional arguments
            """
            # If the validation loop is enabled (meaning `validation_step` is overridden),
            # log metrics in `on_validation_epoch_end` to avoid logging the same metric
            # records twice.
            if not trainer.enable_validation:
                self._log_epoch_metrics(trainer, pl_module)

        @rank_zero_only
        def on_validation_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metric values after each validation epoch.

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    else:

        @rank_zero_only
        def on_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metric values after each epoch.

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args):
        """
        Log metric values after each step.

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
            args: additional positional arguments
        """
        if not self.log_every_n_step:
            return
        # When logging at the end of a batch step, we only want to log metrics that are logged
        # on steps. For forked metrics (metrics logged on both steps and epochs), we exclude
        # the metric with the non-forked name (e.g. "loss" when we have "loss", "loss_step"
        # and "loss_epoch") so that it is only logged on epochs. We also record which metrics
        # we've logged per step, so we can later exclude these from metrics logged on epochs.
        metrics = _get_step_metrics(trainer)
        metric_items = [
            (name, val)
            for (name, val) in metrics.items()
            if (name not in self._epoch_metrics) and (f"{name}_step" not in metrics.keys())
        ]
        self._step_metrics.update(name for (name, _) in metric_items)
        step = trainer.global_step
        if ((step // self._global_steps_per_training_step) + 1) % self.log_every_n_step == 0:
            self._log_metrics(trainer, step, metric_items)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """
        Logs optimizer-related parameters when training begins.

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "training"})

        params = {"epochs": trainer.max_epochs}

        # TODO: For logging optimizer params, the following scenarios are to be revisited.
        # 1. Currently, only the first optimizer's details are logged.
        #    The code should be enhanced to log params when multiple optimizers are used.
        # 2. mlflow.log_params is used to store the optimizer's default values in mlflow.
        #    The keys in the defaults dictionary are too short, e.g. `lr` vs. `learning_rate`.
        #    An efficient mapping technique needs to be introduced to rename the optimizer
        #    parameters based on the keys in the defaults dictionary.
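        # For illustration only (values are PyTorch's own defaults, not something this
        # module sets): `torch.optim.Adam(...).defaults` is roughly
        # {"lr": 0.001, "betas": (0.9, 0.999), "eps": 1e-08, "weight_decay": 0, ...},
        # and those short keys are what get logged below.
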
        if hasattr(trainer, "optimizers"):
            # Lightning >= 1.6.0 increments the global step every time an optimizer is stepped.
            # We assume every optimizer will be stepped in each training step.
            if _pl_version >= Version("1.6.0"):
                self._global_steps_per_training_step = len(trainer.optimizers)
            optimizer = trainer.optimizers[0]
            params["optimizer_name"] = _get_optimizer_name(optimizer)

            if hasattr(optimizer, "defaults"):
                params.update(optimizer.defaults)

        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

        if self.log_models and self.log_model_signatures:
            # Set up a `model.forward` patch in order to capture
            # the first batch input (for inferring the model signature).

            # Note:
            # 1. The `model.forward` patch can't be set up in the patched `Trainer.fit`
            #    method, because when training with a parallel strategy, `model.forward`
            #    is called in spawned training workers (subprocesses), and a patch applied
            #    in the parent process does not take effect in the subprocesses.
            #
            # 2. We can't use `Callback.on_train_batch_start` to capture
            #    the first batch input, because the `batch` argument in
            #    `Callback.on_train_batch_start` contains input and target,
            #    and the lightning callback interface does not restrict the
            #    data format of the batch argument, so we have no way to
            #    extract the `model.forward` input from the batch argument
            #    (the extraction logic is defined in `model.training_step`).
            lightning_module = trainer.strategy.lightning_module
            original_model_forward = lightning_module.forward

            def patched_model_forward(*inputs, **kwargs):
                result = original_model_forward(*inputs, **kwargs)
                if not self._first_batch_checked:
                    try:
                        # Model signature only supports an input schema of one Tensor
                        if (
                            len(inputs) == 1
                            and isinstance(inputs[0], torch.Tensor)
                            and isinstance(result, torch.Tensor)
                        ):
                            tempdir = os.environ.get(_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV)
                            assert tempdir is not None, (
                                "_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR environment variable "
                                "is missing."
                            )
                            torch.save(
                                (inputs[0], result),
                                os.path.join(tempdir, _INPUT_OUTPUT_TENSORS_FILENAME),
                            )
                    except Exception:
                        pass
                    self._first_batch_checked = True

                return result

            patch = gorilla.Patch(
                lightning_module,
                "forward",
                patched_model_forward,
                gorilla.Settings(allow_hit=True, store_hit=True),
            )
            gorilla.apply(patch)
            self._model_forward_patch = patch

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """
        Flushes any metrics and parameters still queued for logging when training ends.

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        # manually flush any remaining metadata from training
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_test_end(self, trainer, pl_module):
        """
        Logs accuracy and other relevant metrics when testing ends.

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "testing"})
        self.client.flush(synchronous=True)

        self.metrics_logger.record_metrics(
            {key: float(value) for key, value in trainer.callback_metrics.items()}
        )
        self.metrics_logger.flush()


class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase):
    """Callback for auto-logging pytorch-lightning model checkpoints to MLflow.
    This callback implementation only supports pytorch-lightning >= 1.6.0.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `model_checkpoint_save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored, and the previous checkpointed model is overwritten.
        mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc. are added to the checkpoint too.
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using an integer, the callback
            saves the model at the end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        import mlflow
        from mlflow.pytorch import MlflowModelCheckpointCallback
        from pytorch_lightning import Trainer

        mlflow.pytorch.autolog(checkpoint=True)

        model = MyLightningModuleNet()  # A custom PyTorch Lightning model
        train_loader = create_train_dataset_loader()

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback()

        trainer = Trainer(callbacks=[mlflow_checkpoint_callback])

        with mlflow.start_run() as run:
            trainer.fit(model, train_loader)
    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        super().__init__(
            checkpoint_file_suffix=".pth",
            monitor=monitor,
            mode=mode,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            save_freq=save_freq,
        )
        self.trainer = None

    def save_checkpoint(self, filepath: str):
        # Note: the `trainer.save_checkpoint` implementation invokes
        # `self.strategy.barrier("Trainer.save_checkpoint")`. In DDP training,
        # this callback is only invoked in the rank 0 process, so that `barrier`
        # invocation would deadlock. We therefore implement `save_checkpoint` here
        # instead of calling `trainer.save_checkpoint`.
        checkpoint = self.trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
        self.trainer.strategy.save_checkpoint(checkpoint, filepath)

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.trainer = trainer

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs,
        batch,
        batch_idx,
    ) -> None:
        if isinstance(self.save_freq, int) and (
            trainer.global_step > 0 and trainer.global_step % self.save_freq == 0
        ):
            self.check_and_save_checkpoint_if_needed(
                current_epoch=trainer.current_epoch,
                global_step=trainer.global_step,
                metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
            )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.save_freq == "epoch":
            self.check_and_save_checkpoint_if_needed(
                current_epoch=trainer.current_epoch,
                global_step=trainer.global_step,
                metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
            )


# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
# update on demand. Prior to this, the metrics from the current step were not available to
# callbacks immediately, so the view of metrics was off by one step.
# To avoid this problem, we access the metrics via the logger_connector for older versions.
if _pl_version >= Version("1.4.0"):

    def _get_step_metrics(trainer):
        return trainer.callback_metrics

else:

    def _get_step_metrics(trainer):
        return trainer.logger_connector.cached_results.get_latest_batch_log_metrics()


def _log_early_stop_params(early_stop_callback, client, run_id):
    """
    Logs early stopping configuration parameters to MLflow.

    Args:
        early_stop_callback: The early stopping callback instance used during training.
        client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
        run_id: The ID of the MLflow Run to which to log configuration parameters.
    """
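    # Illustrative example (exact values depend on how the user configured EarlyStopping):
    # with PyTorch Lightning's defaults this logs something like
    # {"monitor": "val_loss", "mode": "min", "patience": 3, "min_delta": 0.0, "stopped_epoch": 0}.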
486 """ 487 client.log_params( 488 run_id, 489 { 490 p: getattr(early_stop_callback, p) 491 for p in ["monitor", "mode", "patience", "min_delta", "stopped_epoch"] 492 if hasattr(early_stop_callback, p) 493 }, 494 ) 495 496 497 def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None): 498 """ 499 Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow. 500 501 Args: 502 early_stop_callback: The early stopping callback instance used during training. 503 client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging. 504 run_id: The ID of the MLflow Run to which to log configuration parameters. 505 model_id: The ID of the LoggedModel to which the metrics are associated. 506 """ 507 if early_stop_callback.stopped_epoch == 0: 508 return 509 510 metrics = { 511 "stopped_epoch": early_stop_callback.stopped_epoch, 512 "restored_epoch": early_stop_callback.stopped_epoch - max(1, early_stop_callback.patience), 513 } 514 515 if hasattr(early_stop_callback, "best_score"): 516 metrics["best_score"] = float(early_stop_callback.best_score) 517 518 if hasattr(early_stop_callback, "wait_count"): 519 metrics["wait_count"] = early_stop_callback.wait_count 520 521 client.log_metrics(run_id, metrics, model_id=model_id) 522 523 524 def patched_fit(original, self, *args, **kwargs): 525 """ 526 A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the 527 following parameters, metrics and artifacts: 528 529 - Training epochs 530 - Optimizer parameters 531 - `EarlyStoppingCallback`_ parameters 532 - Metrics stored in `trainer.callback_metrics` 533 - Model checkpoints 534 - Trained model 535 536 .. _EarlyStoppingCallback: 537 https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html 538 """ 539 from mlflow.pytorch import _is_forecasting_model 540 541 if not MIN_REQ_VERSION <= _pl_version <= MAX_REQ_VERSION: 542 warnings.warn( 543 "Autologging is known to be compatible with pytorch-lightning versions between " 544 f"{MIN_REQ_VERSION} and {MAX_REQ_VERSION} and may not succeed with packages " 545 "outside this range." 546 ) 547 548 model = args[0] if len(args) > 0 else kwargs["model"] 549 if _is_forecasting_model(model): 550 # The forecasting model predict method calls tensor board writer's add_hparams 551 # method, which triggers pytorch autologging. The patch is for disabling it. 
        original_predict = model.predict

        @functools.wraps(original_predict)
        def patched_predict(*args, **kwargs):
            with disable_autologging():
                return original_predict(*args, **kwargs)

        model.predict = patched_predict

    with disable_autologging():
        run_id = mlflow.active_run().info.run_id
        tracking_uri = mlflow.get_tracking_uri()
        client = MlflowAutologgingQueueingClient(tracking_uri)

        log_model_signatures = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_model_signatures", True
        )
        log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_models", True)
        model_id = None
        if log_models:
            model_id = _initialize_logged_model(
                name="model", flavor=mlflow.pytorch.FLAVOR_NAME
            ).model_id
        metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

        log_every_n_epoch = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_epoch", 1
        )
        log_every_n_step = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_step", None
        )

        early_stop_callback = None
        for callback in self.callbacks:
            if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)

        if not any(isinstance(callback, __MlflowPLCallback) for callback in self.callbacks):
            self.callbacks += [
                __MlflowPLCallback(
                    client,
                    metrics_logger,
                    run_id,
                    log_models,
                    log_every_n_epoch,
                    log_every_n_step,
                    log_model_signatures,
                )
            ]

        model_checkpoint = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "checkpoint", True)
        if model_checkpoint:
            # MlflowModelCheckpointCallback only supports pytorch-lightning >= 1.6.0
            if _pl_version >= Version("1.6.0"):
                checkpoint_monitor = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
                )
                checkpoint_mode = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_mode", "min"
                )
                checkpoint_save_best_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_best_only", True
                )
                checkpoint_save_weights_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_weights_only", False
                )
                checkpoint_save_freq = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
                )

                if not any(
                    isinstance(callback, MlflowModelCheckpointCallback)
                    for callback in self.callbacks
                ):
                    self.callbacks += [
                        MlflowModelCheckpointCallback(
                            monitor=checkpoint_monitor,
                            mode=checkpoint_mode,
                            save_best_only=checkpoint_save_best_only,
                            save_weights_only=checkpoint_save_weights_only,
                            save_freq=checkpoint_save_freq,
                        )
                    ]
            else:
                warnings.warn(
                    "Automatic model checkpointing is disabled because this feature only "
                    "supports pytorch-lightning >= 1.6.0."
                )

        client.flush(synchronous=False)

    with tempfile.TemporaryDirectory() as tempdir:
        os.environ[_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV] = tempdir

        try:
            result = original(self, *args, **kwargs)
        finally:
            for callback in self.callbacks:
                if isinstance(callback, __MlflowPLCallback) and callback._model_forward_patch:
                    gorilla.revert(callback._model_forward_patch)

        model_signature = None
        input_output_tensors_file = os.path.join(tempdir, _INPUT_OUTPUT_TENSORS_FILENAME)
        if os.path.exists(input_output_tensors_file):
            input_tensor, output_tensor = torch.load(input_output_tensors_file)
            try:
                input_example = input_tensor.cpu().numpy()
                with torch.no_grad():
                    output_example = output_tensor.cpu().numpy()
                model_signature = infer_signature(
                    input_example,
                    output_example,
                )
            except Exception as e:
                _logger.warning(
                    "Failed to infer the model signature; skipping signature logging. "
                    "You will need to manually log the model with a signature after "
                    f"training. Root cause: {e!r}.",
                    exc_info=True,
                )

        if early_stop_callback is not None:
            _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

        if Version(pl.__version__) < Version("1.4.0"):
            summary = str(ModelSummary(self.model, mode="full"))
        else:
            summary = str(ModelSummary(self.model, max_depth=-1))

        summary_file = os.path.join(tempdir, "model_summary.txt")
        with open(summary_file, "w") as f:
            f.write(summary)

        mlflow.log_artifact(local_path=summary_file)

        if log_models:
            registered_model_name = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None
            )
            mlflow.pytorch.log_model(
                self.model,
                name="model",
                registered_model_name=registered_model_name,
                model_id=model_id,
                signature=model_signature,
            )

            if early_stop_callback is not None and self.checkpoint_callback.best_model_path:
                mlflow.log_artifact(
                    local_path=self.checkpoint_callback.best_model_path,
                    artifact_path="restored_model_checkpoint",
                )

        client.flush(synchronous=True)

    return result
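

# A minimal configuration sketch (not part of this module): the values read above via
# `get_autologging_config` correspond to `mlflow.pytorch.autolog` arguments, e.g.:
#
#   mlflow.pytorch.autolog(
#       log_every_n_epoch=1,
#       log_every_n_step=50,
#       log_models=True,
#       log_model_signatures=True,
#       checkpoint=True,
#       checkpoint_monitor="val_loss",
#       checkpoint_save_best_only=True,
#   )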