# mlflow/paddle/_paddle_autolog.py
  1  import paddle
  2  
  3  import mlflow
  4  from mlflow.tracking.fluent import _initialize_logged_model
  5  from mlflow.utils.autologging_utils import (
  6      BatchMetricsLogger,
  7      ExceptionSafeAbstractClass,
  8      MlflowAutologgingQueueingClient,
  9      get_autologging_config,
 10  )
 11  
 12  
 13  class __MlflowPaddleCallback(paddle.callbacks.Callback, metaclass=ExceptionSafeAbstractClass):
 14      """Callback for auto-logging metrics and parameters."""
 15  
 16      def __init__(self, client, metrics_logger, run_id, log_models, log_every_n_epoch):
 17          super().__init__()
 18          self.early_stopping = False
 19          self.client = client
 20          self.metrics_logger = metrics_logger
 21          self.run_id = run_id
 22          self.log_models = log_models
 23          self.log_every_n_epoch = log_every_n_epoch
 24          self.epoch = 0
 25  
 26      def _log_metrics(self, logs, current_epoch):
 27          metrics = {
 28              key: (metric[0] if isinstance(metric, list) else metric) for key, metric in logs.items()
 29          }
 30          self.metrics_logger.record_metrics(metrics, current_epoch)
 31  
 32      def on_epoch_end(self, epoch, logs=None):
 33          if self.model is not None and epoch % self.log_every_n_epoch == 0:
 34              self._log_metrics(logs, epoch)
 35              self.epoch = epoch
 36  
 37      def on_train_begin(self, logs=None):
 38          params = {
 39              "optimizer_name": self.model._optimizer.__class__.__name__,
 40              "learning_rate": self.model._optimizer._learning_rate,
 41          }
 42          self.client.log_params(self.run_id, params)
 43          self.client.flush(synchronous=True)
 44  
 45      def on_train_end(self, logs=None):
 46          self.metrics_logger.flush()
 47          self.client.flush(synchronous=True)
 48  
 49      def on_eval_end(self, logs=None):
 50          eval_logs = {
 51              "eval_" + key: (metric[0] if isinstance(metric, list) else metric)
 52              for key, metric in logs.items()
 53          }
 54          self._log_metrics(eval_logs, self.epoch)
 55  
 56  
 57  def _log_early_stop_params(early_stop_callback, client, run_id):
 58      """
 59      Logs early stopping configuration parameters to MLflow.
 60  
 61      Args:
 62          early_stop_callback: The early stopping callback instance used during training.
 63          client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
 64          run_id: The ID of the MLflow Run to which to log configuration parameters.
 65      """
 66      client.log_params(
 67          run_id,
 68          {
 69              p: getattr(early_stop_callback, p)
 70              for p in ["monitor", "patience", "min_delta", "baseline"]
 71              if hasattr(early_stop_callback, p)
 72          },
 73      )
 74  
 75  
 76  def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None):
 77      """
 78      Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.
 79  
 80      Args:
 81          early_stop_callback: The early stopping callback instance used during training.
 82          client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
 83          run_id: The ID of the MLflow Run to which to log configuration parameters.
 84          model_id: The ID of the model metrics will be associated with.
 85      """
 86      if early_stop_callback.stopped_epoch == 0:
 87          return
 88  
 89      metrics = {
 90          "stopped_epoch": early_stop_callback.stopped_epoch,
 91          "best_value": early_stop_callback.best_value,
 92      }
 93      client.log_metrics(run_id, metrics, model_id=model_id)
 94  
 95  
def patched_fit(original, self, *args, **kwargs):
    """Patched replacement for ``paddle.Model.fit`` used by MLflow autologging.

    Wraps the original ``fit`` call so that optimizer parameters, per-epoch
    metrics, early-stopping information, a model summary, and (optionally) the
    trained model are logged to the currently active MLflow run.

    Args:
        original: The original, unpatched ``fit`` method.
        self: The ``paddle.Model`` instance on which ``fit`` was invoked.
        *args: Positional arguments forwarded unchanged to the original ``fit``.
        **kwargs: Keyword arguments forwarded to the original ``fit``; the
            ``callbacks`` entry is augmented with an MLflow logging callback.

    Returns:
        Whatever the original ``fit`` call returns.
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    # Queueing client batches param/metric writes; flushed explicitly below.
    client = MlflowAutologgingQueueingClient(tracking_uri)
    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_every_n_epoch", 1)

    model_id = None
    if log_models:
        # Create the LoggedModel up front so metrics recorded during training
        # can be associated with it via the metrics logger.
        model_id = _initialize_logged_model("model", flavor=mlflow.paddle.FLAVOR_NAME).model_id
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

    early_stop_callback = None
    mlflow_callback = __MlflowPaddleCallback(
        client, metrics_logger, run_id, log_models, log_every_n_epoch
    )
    if "callbacks" in kwargs:
        callbacks = kwargs["callbacks"]
        # If the user supplied an EarlyStopping callback, log its configuration
        # now so it is recorded even if training is interrupted.
        for callback in callbacks:
            if isinstance(callback, paddle.callbacks.EarlyStopping):
                early_stop_callback = callback
                _log_early_stop_params(early_stop_callback, client, run_id)
                break
        kwargs["callbacks"].append(mlflow_callback)
    else:
        kwargs["callbacks"] = [mlflow_callback]
    # Asynchronous flush so queued params don't delay the start of training.
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

    # Persist a human-readable layer/parameter summary as a run artifact.
    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.paddle.FLAVOR_NAME, "registered_model_name", None
        )
        mlflow.paddle.log_model(
            self, name="model", registered_model_name=registered_model_name, model_id=model_id
        )

    # Final synchronous flush to guarantee all queued data is persisted.
    client.flush(synchronous=True)

    return result