import functools
import logging
import os
import tempfile
import warnings

import torch
from packaging.version import Version

import mlflow.pytorch
from mlflow.exceptions import MlflowException
from mlflow.ml_package_versions import _ML_PACKAGE_VERSIONS
from mlflow.models import infer_signature
from mlflow.tracking.fluent import _initialize_logged_model
from mlflow.utils import gorilla
from mlflow.utils.autologging_utils import (
    BatchMetricsLogger,
    ExceptionSafeAbstractClass,
    MlflowAutologgingQueueingClient,
    disable_autologging,
    get_autologging_config,
)
from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase

logging.basicConfig(level=logging.ERROR)
MIN_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["minimum"])
MAX_REQ_VERSION = Version(_ML_PACKAGE_VERSIONS["pytorch-lightning"]["autologging"]["maximum"])

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

# The following are the downsides of using PyTorch Lightning's built-in MlflowLogger:
# 1. MlflowLogger doesn't provide a mechanism to store an entire model into mlflow;
#    only model checkpoints are saved.
# 2. Storing the model into mlflow requires the `mlflow.pytorch` library, which expects
#    an `mlflow` object to be instantiated. With MlflowLogger, run management is
#    completely controlled by the class, so the mlflow object would need to be
#    reinstantiated by setting the tracking uri, experiment_id, and run_id, which may
#    lead to a race condition.
# TODO: Replace __MlflowPLCallback with PyTorch Lightning's built-in MlflowLogger
# once the above-mentioned issues have been addressed.

_logger = logging.getLogger(__name__)

_pl_version = Version(pl.__version__)
if _pl_version < Version("1.5.0"):
    from pytorch_lightning.core.memory import ModelSummary
else:
    from pytorch_lightning.utilities.model_summary import ModelSummary


def _get_optimizer_name(optimizer):
    """
    In pytorch-lightning 1.1.0, `LightningOptimizer` was introduced:
    https://github.com/PyTorchLightning/pytorch-lightning/pull/4658

    If a user sets `enable_pl_optimizer` to True when instantiating a `Trainer` object,
    each optimizer will be wrapped by `LightningOptimizer`:
    https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.html
    #pytorch_lightning.trainer.trainer.Trainer.params.enable_pl_optimizer
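
    A minimal illustrative sketch of the expected behavior (the `torch.optim.SGD`
    optimizer below is just an example, not something this module depends on):

    .. code-block:: python
        :caption: Example

        import torch
        from pytorch_lightning.core.optimizer import LightningOptimizer

        optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
        assert _get_optimizer_name(optimizer) == "SGD"
        # A `LightningOptimizer` wrapper resolves to the underlying class name.
        assert _get_optimizer_name(LightningOptimizer(optimizer)) == "SGD"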
    """
    if Version(pl.__version__) < Version("1.1.0"):
        return optimizer.__class__.__name__
    else:
        from pytorch_lightning.core.optimizer import LightningOptimizer

        return (
            optimizer._optimizer.__class__.__name__
            if isinstance(optimizer, LightningOptimizer)
            else optimizer.__class__.__name__
        )


_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV = "_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR"
_INPUT_OUTPUT_TENSORS_FILENAME = "input_output_tensors.pkl"


class __MlflowPLCallback(pl.Callback, metaclass=ExceptionSafeAbstractClass):
    """
    Callback for auto-logging metrics and parameters.
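
    This callback is attached automatically by `patched_fit` when
    `mlflow.pytorch.autolog()` is enabled; it is not intended to be added to a
    `Trainer` manually.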
    """

    def __init__(
        self,
        client,
        metrics_logger,
        run_id,
        log_models,
        log_every_n_epoch,
        log_every_n_step,
        log_model_signatures,
    ):
        if log_every_n_step and _pl_version < Version("1.1.0"):
            raise MlflowException(
                "log_every_n_step is only supported for PyTorch-Lightning >= 1.1.0"
            )
        self.early_stopping = False
        self.client = client
        self.metrics_logger = metrics_logger
        self.run_id = run_id
        self.log_models = log_models
        self.log_every_n_epoch = log_every_n_epoch
        self.log_every_n_step = log_every_n_step
        self._global_steps_per_training_step = 1
        # Sets for tracking which metrics are logged on steps and which are logged on epochs
        self._step_metrics = set()
        self._epoch_metrics = set()
        self.log_model_signatures = log_model_signatures
        self._model_forward_patch = None
        self._first_batch_checked = False

    def _log_metrics(self, trainer, step, metric_items):
        # pytorch-lightning runs a few steps of validation at the beginning of training
        # as a sanity check to catch bugs without having to wait for the training routine
        # to complete. During this check, we should skip logging metrics.
        # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#num-sanity-val-steps
        sanity_checking = (
            # `running_sanity_check` has been renamed to `sanity_checking`:
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/9209
            trainer.sanity_checking
            if Version(pl.__version__) > Version("1.4.5")
            else trainer.running_sanity_check
        )
        if sanity_checking:
            return

        # Cast metric values to float before passing them into the logger.
        metrics = {x[0]: float(x[1]) for x in metric_items}
        self.metrics_logger.record_metrics(metrics, step)

    def _log_epoch_metrics(self, trainer, pl_module):
        # `trainer.callback_metrics` contains both training and validation metrics
        # and includes metrics logged on steps and epochs.
        # If we have logged any metrics on a step basis in mlflow, we exclude these from the
        # epoch-level metrics to prevent mixing epoch- and step-based values.
        metric_items = [
            (name, val)
            for (name, val) in trainer.callback_metrics.items()
            if name not in self._step_metrics
        ]
        # Record which metrics are logged on epochs, so we don't try to log these on steps
        self._epoch_metrics.update(name for (name, _) in metric_items)
        if (pl_module.current_epoch + 1) % self.log_every_n_epoch == 0:
            self._log_metrics(trainer, pl_module.current_epoch, metric_items)

    _pl_version = Version(pl.__version__)

    # In pytorch-lightning >= 1.4.0, validation is run inside the training epoch and
    # `trainer.callback_metrics` contains both training and validation metrics of the
    # current training epoch when `on_train_epoch_end` is called:
    # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
    if _pl_version >= Version("1.4.0dev"):

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            self._log_epoch_metrics(trainer, pl_module)

    # In pytorch-lightning >= 1.2.0, logging metrics in `on_epoch_end` results in duplicate
    # metric records because `on_epoch_end` is called after both train and validation
    # epochs (related PR: https://github.com/PyTorchLightning/pytorch-lightning/pull/5986)
    # As a workaround, use `on_train_epoch_end` and `on_validation_epoch_end` instead
    # in pytorch-lightning >= 1.2.0.
    elif _pl_version >= Version("1.2.0"):
        # NB: Override `on_train_epoch_end` with an additional `*args` parameter for
        # compatibility with versions of pytorch-lightning < 1.3.0, which required an
        # `outputs` argument that was not used and is no longer defined in
        # pytorch-lightning >= 1.3.0

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module, *args):
            """
            Log loss and other metric values after each train epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
                args: additional positional arguments
            """
            # If the validation loop is enabled (meaning `validation_step` is overridden),
            # log metrics in `on_validation_epoch_end` to avoid logging the same metric
            # records twice
            if not trainer.enable_validation:
                self._log_epoch_metrics(trainer, pl_module)

        @rank_zero_only
        def on_validation_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metric values after each validation epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    else:

        @rank_zero_only
        def on_epoch_end(self, trainer, pl_module):
            """
            Log loss and other metric values after each epoch

            Args:
                trainer: pytorch lightning trainer instance
                pl_module: pytorch lightning base module
            """
            self._log_epoch_metrics(trainer, pl_module)

    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, *args):
        """
        Log metric values after each step

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
            args: additional positional arguments
        """
        if not self.log_every_n_step:
            return
        # When logging at the end of a batch step, we only want to log metrics that are logged
        # on steps. For forked metrics (metrics logged on both steps and epochs), we exclude the
        # metric with the non-forked name (e.g. "loss" when we have "loss", "loss_step" and
        # "loss_epoch") so that it is only logged on epochs. We also record which metrics
        # we've logged per step, so we can later exclude these from metrics logged on epochs.
        metrics = _get_step_metrics(trainer)
        metric_items = [
            (name, val)
            for (name, val) in metrics.items()
            if (name not in self._epoch_metrics) and (f"{name}_step" not in metrics)
        ]
        self._step_metrics.update(name for (name, _) in metric_items)
        step = trainer.global_step
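        # NB: In pytorch-lightning >= 1.6.0, `global_step` advances once per optimizer step
        # (see `on_train_start` below), so with multiple optimizers it advances several
        # times per training batch; dividing by `_global_steps_per_training_step` recovers
        # the per-batch training-step count for the `log_every_n_step` check.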
        if ((step // self._global_steps_per_training_step) + 1) % self.log_every_n_step == 0:
            self._log_metrics(trainer, step, metric_items)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        """
        Logs optimizer-related parameters when training begins

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "training"})

        params = {"epochs": trainer.max_epochs}

        # TODO: For logging optimizer params, the following scenarios are to be revisited:
        # 1. Currently, only the first optimizer's details are logged.
        #    The code should be enhanced to log params when multiple optimizers are used.
        # 2. mlflow.log_params is used to store optimizer default values into mlflow.
        #    The keys in the defaults dictionary are too short, e.g. `lr` for `learning_rate`.
        #    An efficient mapping technique needs to be introduced
        #    to rename the optimizer parameters based on the keys in the defaults dictionary.

        if hasattr(trainer, "optimizers"):
            # Lightning >= 1.6.0 increments the global step every time an optimizer is stepped.
            # We assume every optimizer will be stepped in each training step.
            if _pl_version >= Version("1.6.0"):
                self._global_steps_per_training_step = len(trainer.optimizers)
            optimizer = trainer.optimizers[0]
            params["optimizer_name"] = _get_optimizer_name(optimizer)

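            # `optimizer.defaults` holds the constructor defaults of the optimizer's
            # hyperparameters; e.g., `torch.optim.Adam` exposes keys such as `lr`, `betas`,
            # `eps`, and `weight_decay`. These short keys are logged as-is (see the TODO above).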
            if hasattr(optimizer, "defaults"):
                params.update(optimizer.defaults)

        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

        if self.log_models and self.log_model_signatures:
            # Set up the `model.forward` patch in order to capture
            # the first batch input (for inferring the model signature).

            # Note:
            #  1. The `model.forward` patch can't be set up in the
            #  patched `Trainer.fit` method, because in training with a
            #  parallel strategy, `model.forward` is called in spawned
            #  training workers (subprocesses), and a patch applied in the
            #  parent process does not work in subprocesses.
            #
            #  2. We can't use `Callback.on_train_batch_start` to capture
            #  the first batch input, because the `batch` argument in
            #  `Callback.on_train_batch_start` contains input and target,
            #  and the lightning callback interface does not restrict the
            #  data format of the batch argument, so we have no way to
            #  extract the `model.forward` input from the batch argument
            #  (the extracting logic is defined in `model.training_step`).
            lightning_module = trainer.strategy.lightning_module
            original_model_forward = lightning_module.forward

            def patched_model_forward(*inputs, **kwargs):
                result = original_model_forward(*inputs, **kwargs)
                if not self._first_batch_checked:
                    try:
                        # Model signature only supports an input schema of one Tensor
                        if (
                            len(inputs) == 1
                            and isinstance(inputs[0], torch.Tensor)
                            and isinstance(result, torch.Tensor)
                        ):
                            tempdir = os.environ.get(_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV)
                            assert tempdir is not None, (
                                "_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR environment variable "
                                "is missing."
                            )
                            torch.save(
                                (inputs[0], result),
                                os.path.join(tempdir, _INPUT_OUTPUT_TENSORS_FILENAME),
                            )
                    except Exception:
                        pass
                    self._first_batch_checked = True

                return result

            patch = gorilla.Patch(
                lightning_module,
                "forward",
                patched_model_forward,
                gorilla.Settings(allow_hit=True, store_hit=True),
            )
            gorilla.apply(patch)
            self._model_forward_patch = patch

    @rank_zero_only
    def on_train_end(self, trainer, pl_module):
        """
        Flushes any remaining queued metrics and parameters when training ends

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        # manually flush any remaining metadata from training
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    @rank_zero_only
    def on_test_end(self, trainer, pl_module):
        """
        Logs accuracy and other relevant metrics when testing ends

        Args:
            trainer: pytorch lightning trainer instance
            pl_module: pytorch lightning base module
        """
        self.client.set_tags(self.run_id, {"Mode": "testing"})
        self.client.flush(synchronous=True)

        self.metrics_logger.record_metrics(
            {key: float(value) for key, value in trainer.callback_metrics.items()}
        )
        self.metrics_logger.flush()


class MlflowModelCheckpointCallback(pl.Callback, MlflowModelCheckpointCallbackBase):
    """Callback for auto-logging pytorch-lightning model checkpoints to MLflow.
    This callback implementation only supports pytorch-lightning >= 1.6.0.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `model_checkpoint_save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored, and the previous checkpoint is overwritten.
        mode: One of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc. are added to the checkpoint too.
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using an integer, the callback
            saves the model at the end of every `save_freq` batches. Note that if the
            saving isn't aligned to epochs, the monitored metric may potentially be less
            reliable (it could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        import mlflow
        from mlflow.pytorch import MlflowModelCheckpointCallback
        from pytorch_lightning import Trainer

        mlflow.pytorch.autolog(checkpoint=True)

        model = MyLightningModuleNet()  # A custom PyTorch Lightning model
        train_loader = create_train_dataset_loader()

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback()

        trainer = Trainer(callbacks=[mlflow_checkpoint_callback])

        with mlflow.start_run() as run:
            trainer.fit(model, train_loader)

    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        super().__init__(
            checkpoint_file_suffix=".pth",
            monitor=monitor,
            mode=mode,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            save_freq=save_freq,
        )
        self.trainer = None

    def save_checkpoint(self, filepath: str):
        # Note: the `trainer.save_checkpoint` implementation invokes
        # `self.strategy.barrier("Trainer.save_checkpoint")`. In DDP training,
        # this callback is only invoked in the rank 0 process, so the `barrier`
        # invocation would cause a deadlock; we therefore implement
        # `save_checkpoint` ourselves instead of calling `trainer.save_checkpoint`.
        checkpoint = self.trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
        self.trainer.strategy.save_checkpoint(checkpoint, filepath)

    @rank_zero_only
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.trainer = trainer

    @rank_zero_only
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs,
        batch,
        batch_idx,
    ) -> None:
        if isinstance(self.save_freq, int) and (
            trainer.global_step > 0 and trainer.global_step % self.save_freq == 0
        ):
            self.check_and_save_checkpoint_if_needed(
                current_epoch=trainer.current_epoch,
                global_step=trainer.global_step,
                metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
            )

    @rank_zero_only
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.save_freq == "epoch":
            self.check_and_save_checkpoint_if_needed(
                current_epoch=trainer.current_epoch,
                global_step=trainer.global_step,
                metric_dict={k: float(v) for k, v in trainer.callback_metrics.items()},
            )


# PyTorch-Lightning refactored the LoggerConnector class in version 1.4.0 and made metrics
# update on demand. Prior to this, the metrics from the current step were not available to
# callbacks immediately, so the view of metrics was off by one step.
# To avoid this problem, we access the metrics via the logger_connector for older versions.
if _pl_version >= Version("1.4.0"):

    def _get_step_metrics(trainer):
        return trainer.callback_metrics

else:

    def _get_step_metrics(trainer):
        return trainer.logger_connector.cached_results.get_latest_batch_log_metrics()


def _log_early_stop_params(early_stop_callback, client, run_id):
    """
    Logs early stopping configuration parameters to MLflow.

    Args:
        early_stop_callback: The early stopping callback instance used during training.
        client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
        run_id: The ID of the MLflow Run to which to log configuration parameters.
    """
    client.log_params(
        run_id,
        {
            p: getattr(early_stop_callback, p)
            for p in ["monitor", "mode", "patience", "min_delta", "stopped_epoch"]
            if hasattr(early_stop_callback, p)
        },
    )


def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None):
    """
    Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.

    Args:
        early_stop_callback: The early stopping callback instance used during training.
        client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
        run_id: The ID of the MLflow Run to which to log metrics.
        model_id: The ID of the LoggedModel with which the metrics are associated.
    """
    if early_stop_callback.stopped_epoch == 0:
        return

    metrics = {
        "stopped_epoch": early_stop_callback.stopped_epoch,
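        # `restored_epoch` estimates the epoch whose weights early stopping restores:
        # the monitored metric last improved `patience` epochs before `stopped_epoch`.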
        "restored_epoch": early_stop_callback.stopped_epoch - max(1, early_stop_callback.patience),
    }

    if hasattr(early_stop_callback, "best_score"):
        metrics["best_score"] = float(early_stop_callback.best_score)

    if hasattr(early_stop_callback, "wait_count"):
        metrics["wait_count"] = early_stop_callback.wait_count

    client.log_metrics(run_id, metrics, model_id=model_id)


def patched_fit(original, self, *args, **kwargs):
    """
    A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the
    following parameters, metrics and artifacts:

    - Training epochs
    - Optimizer parameters
    - `EarlyStoppingCallback`_ parameters
    - Metrics stored in `trainer.callback_metrics`
    - Model checkpoints
    - Trained model

    .. _EarlyStoppingCallback:
        https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html
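
    This patch takes effect when autologging is enabled. A minimal sketch of typical
    usage (`MyLightningModule` and `train_loader` are illustrative placeholders, not
    part of this module):

    .. code-block:: python
        :caption: Example

        import mlflow
        from pytorch_lightning import Trainer

        mlflow.pytorch.autolog()

        model = MyLightningModule()
        trainer = Trainer(max_epochs=5)

        with mlflow.start_run():
            # `Trainer.fit` is patched, so params, metrics, and the model are logged.
            trainer.fit(model, train_loader)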
    """
    from mlflow.pytorch import _is_forecasting_model

    if not MIN_REQ_VERSION <= _pl_version <= MAX_REQ_VERSION:
        warnings.warn(
            "Autologging is known to be compatible with pytorch-lightning versions between "
            f"{MIN_REQ_VERSION} and {MAX_REQ_VERSION} and may not succeed with packages "
            "outside this range."
        )

    model = args[0] if len(args) > 0 else kwargs["model"]
    if _is_forecasting_model(model):
        # The forecasting model's predict method calls the TensorBoard writer's `add_hparams`
        # method, which triggers pytorch autologging. This patch disables that behavior.
        original_predict = model.predict

        @functools.wraps(original_predict)
        def patched_predict(*args, **kwargs):
            with disable_autologging():
                return original_predict(*args, **kwargs)

        model.predict = patched_predict

    with disable_autologging():
        run_id = mlflow.active_run().info.run_id
        tracking_uri = mlflow.get_tracking_uri()
        client = MlflowAutologgingQueueingClient(tracking_uri)

        log_model_signatures = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_model_signatures", True
        )
        log_models = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "log_models", True)
        model_id = None
        if log_models:
            model_id = _initialize_logged_model(
                name="model", flavor=mlflow.pytorch.FLAVOR_NAME
            ).model_id
        metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

        log_every_n_epoch = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_epoch", 1
        )
        log_every_n_step = get_autologging_config(
            mlflow.pytorch.FLAVOR_NAME, "log_every_n_step", None
        )

        early_stop_callback = None
        for callback in self.callbacks:
            if isinstance(callback, pl.callbacks.early_stopping.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)

        if not any(isinstance(callback, __MlflowPLCallback) for callback in self.callbacks):
            self.callbacks += [
                __MlflowPLCallback(
                    client,
                    metrics_logger,
                    run_id,
                    log_models,
                    log_every_n_epoch,
                    log_every_n_step,
                    log_model_signatures,
                )
            ]

        model_checkpoint = get_autologging_config(mlflow.pytorch.FLAVOR_NAME, "checkpoint", True)
        if model_checkpoint:
            # MlflowModelCheckpointCallback only supports pytorch-lightning >= 1.6.0
            if _pl_version >= Version("1.6.0"):
                checkpoint_monitor = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_monitor", "val_loss"
                )
                checkpoint_mode = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_mode", "min"
                )
                checkpoint_save_best_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_best_only", True
                )
                checkpoint_save_weights_only = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_weights_only", False
                )
                checkpoint_save_freq = get_autologging_config(
                    mlflow.pytorch.FLAVOR_NAME, "checkpoint_save_freq", "epoch"
                )

                if not any(
                    isinstance(callback, MlflowModelCheckpointCallback)
                    for callback in self.callbacks
                ):
                    self.callbacks += [
                        MlflowModelCheckpointCallback(
                            monitor=checkpoint_monitor,
                            mode=checkpoint_mode,
                            save_best_only=checkpoint_save_best_only,
                            save_weights_only=checkpoint_save_weights_only,
                            save_freq=checkpoint_save_freq,
                        )
                    ]
            else:
                warnings.warn(
                    "Automatic model checkpointing is disabled because this feature only "
                    "supports pytorch-lightning >= 1.6.0."
                )

        client.flush(synchronous=False)

        with tempfile.TemporaryDirectory() as tempdir:
            os.environ[_MLFLOW_LIGHTNING_AUTOLOGGING_TMP_DIR_ENV] = tempdir

            try:
                result = original(self, *args, **kwargs)
            finally:
                for callback in self.callbacks:
                    if isinstance(callback, __MlflowPLCallback) and callback._model_forward_patch:
                        gorilla.revert(callback._model_forward_patch)

            model_signature = None
            input_output_tensors_file = os.path.join(tempdir, _INPUT_OUTPUT_TENSORS_FILENAME)
            if os.path.exists(input_output_tensors_file):
                input_tensor, output_tensor = torch.load(input_output_tensors_file)
                try:
                    input_example = input_tensor.cpu().numpy()
                    # `detach` is needed because the captured output tensor may still
                    # require grad, in which case `.numpy()` raises an error.
                    output_example = output_tensor.detach().cpu().numpy()
                    model_signature = infer_signature(
                        input_example,
                        output_example,
                    )
                except Exception as e:
                    _logger.warning(
                        "Inferring the model signature failed, skipping signature logging. "
                        "You need to manually log the model with a provided signature after "
                        f"training. Root cause: {e!r}.",
                        exc_info=True,
                    )

            if early_stop_callback is not None:
                _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

            if Version(pl.__version__) < Version("1.4.0"):
                summary = str(ModelSummary(self.model, mode="full"))
            else:
                summary = str(ModelSummary(self.model, max_depth=-1))

            summary_file = os.path.join(tempdir, "model_summary.txt")
            with open(summary_file, "w") as f:
                f.write(summary)

            mlflow.log_artifact(local_path=summary_file)

        if log_models:
            registered_model_name = get_autologging_config(
                mlflow.pytorch.FLAVOR_NAME, "registered_model_name", None
            )
            mlflow.pytorch.log_model(
                self.model,
                name="model",
                registered_model_name=registered_model_name,
                model_id=model_id,
                signature=model_signature,
            )

            if early_stop_callback is not None and self.checkpoint_callback.best_model_path:
                mlflow.log_artifact(
                    local_path=self.checkpoint_callback.best_model_path,
                    artifact_path="restored_model_checkpoint",
                )

    client.flush(synchronous=True)

    return result