_paddle_autolog.py
import paddle

import mlflow
from mlflow.tracking.fluent import _initialize_logged_model
from mlflow.utils.autologging_utils import (
    BatchMetricsLogger,
    ExceptionSafeAbstractClass,
    MlflowAutologgingQueueingClient,
    get_autologging_config,
)


def _unwrap_metric(metric):
    """Return the first element of a list-valued metric, or the value unchanged otherwise."""
    return metric[0] if isinstance(metric, list) else metric


class __MlflowPaddleCallback(paddle.callbacks.Callback, metaclass=ExceptionSafeAbstractClass):
    """Callback for auto-logging metrics and parameters.

    Args:
        client: ``MlflowAutologgingQueueingClient`` used to log params and flush queued data.
        metrics_logger: ``BatchMetricsLogger`` used to record metrics.
        run_id: ID of the MLflow Run being logged to.
        log_models: The autologging ``log_models`` setting (stored; not read by this class).
        log_every_n_epoch: Training metrics are recorded only for epochs divisible by this value.
    """

    def __init__(self, client, metrics_logger, run_id, log_models, log_every_n_epoch):
        super().__init__()
        self.early_stopping = False
        self.client = client
        self.metrics_logger = metrics_logger
        self.run_id = run_id
        self.log_models = log_models
        self.log_every_n_epoch = log_every_n_epoch
        # Most recent training epoch seen; used as the step for evaluation metrics.
        self.epoch = 0

    def _log_metrics(self, logs, current_epoch):
        """Record ``logs`` as metrics at step ``current_epoch``.

        List-valued entries are reduced to their first element. ``None`` or empty ``logs``
        records nothing.
        """
        if not logs:
            return
        metrics = {key: _unwrap_metric(metric) for key, metric in logs.items()}
        self.metrics_logger.record_metrics(metrics, current_epoch)

    def on_epoch_end(self, epoch, logs=None):
        # Update the epoch unconditionally. Otherwise, when log_every_n_epoch > 1, eval
        # metrics logged after a skipped epoch would reuse a stale step.
        self.epoch = epoch
        if self.model is not None and epoch % self.log_every_n_epoch == 0:
            self._log_metrics(logs, epoch)

    def on_train_begin(self, logs=None):
        optimizer = self.model._optimizer
        params = {
            "optimizer_name": optimizer.__class__.__name__,
            # NOTE(review): `_learning_rate` is a private Paddle attribute and may hold an
            # LR scheduler object rather than a float -- confirm how it is stringified.
            "learning_rate": optimizer._learning_rate,
        }
        self.client.log_params(self.run_id, params)
        self.client.flush(synchronous=True)

    def on_train_end(self, logs=None):
        # Push any batched metrics and queued client operations before fit() returns.
        self.metrics_logger.flush()
        self.client.flush(synchronous=True)

    def on_eval_end(self, logs=None):
        """Log evaluation results with an ``eval_`` prefix at the latest training epoch."""
        if not logs:
            return
        eval_logs = {"eval_" + key: metric for key, metric in logs.items()}
        self._log_metrics(eval_logs, self.epoch)


def _log_early_stop_params(early_stop_callback, client, run_id):
    """
    Logs early stopping configuration parameters to MLflow.

    Args:
        early_stop_callback: The early stopping callback instance used during training.
        client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
        run_id: The ID of the MLflow Run to which to log configuration parameters.
    """
    client.log_params(
        run_id,
        {
            p: getattr(early_stop_callback, p)
            for p in ["monitor", "patience", "min_delta", "baseline"]
            if hasattr(early_stop_callback, p)
        },
    )


def _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=None):
    """
    Logs early stopping behavior results (e.g. stopped epoch) as metrics to MLflow.

    Nothing is logged when ``stopped_epoch`` is 0.

    Args:
        early_stop_callback: The early stopping callback instance used during training.
        client: An `MlflowAutologgingQueueingClient` instance used for MLflow logging.
        run_id: The ID of the MLflow Run to which to log the metrics.
        model_id: The ID of the model metrics will be associated with.
    """
    if early_stop_callback.stopped_epoch == 0:
        return

    metrics = {
        "stopped_epoch": early_stop_callback.stopped_epoch,
        "best_value": early_stop_callback.best_value,
    }
    client.log_metrics(run_id, metrics, model_id=model_id)


def _as_callback_list(callbacks):
    """Return ``callbacks`` as a new list.

    ``None`` becomes ``[]``, a list or tuple is copied, and any other object is treated as a
    single callback and wrapped. Returning a copy avoids mutating the caller's list, which
    would otherwise accumulate stale MLflow callbacks across repeated ``fit`` calls.
    """
    if callbacks is None:
        return []
    if isinstance(callbacks, (list, tuple)):
        return list(callbacks)
    return [callbacks]


def patched_fit(original, self, *args, **kwargs):
    """Run ``original`` (Paddle ``Model.fit``) with MLflow autologging attached.

    Adds an MLflow callback to ``kwargs["callbacks"]`` and logs early-stopping params and
    metrics when a ``paddle.callbacks.EarlyStopping`` callback is present. After training,
    it logs the model summary as ``model_summary.txt`` and, if ``log_models`` is enabled,
    the model itself.

    Args:
        original: The unpatched ``fit`` function.
        self: The Paddle model being trained.
        *args: Positional arguments forwarded to ``original``.
        **kwargs: Keyword arguments forwarded to ``original``. ``callbacks`` is replaced
            by a new list that ends with the MLflow callback.

    Returns:
        Whatever ``original`` returns.
    """
    run_id = mlflow.active_run().info.run_id
    tracking_uri = mlflow.get_tracking_uri()
    client = MlflowAutologgingQueueingClient(tracking_uri)
    log_models = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_models", True)
    log_every_n_epoch = get_autologging_config(mlflow.paddle.FLAVOR_NAME, "log_every_n_epoch", 1)

    model_id = None
    if log_models:
        model_id = _initialize_logged_model("model", flavor=mlflow.paddle.FLAVOR_NAME).model_id
    metrics_logger = BatchMetricsLogger(run_id, tracking_uri, model_id=model_id)

    mlflow_callback = __MlflowPaddleCallback(
        client, metrics_logger, run_id, log_models, log_every_n_epoch
    )

    # Tolerate `callbacks=None`, tuples, and single callback objects, and never mutate the
    # caller's list.
    callbacks = _as_callback_list(kwargs.get("callbacks"))
    early_stop_callback = next(
        (cb for cb in callbacks if isinstance(cb, paddle.callbacks.EarlyStopping)), None
    )
    if early_stop_callback is not None:
        _log_early_stop_params(early_stop_callback, client, run_id)
    kwargs["callbacks"] = [*callbacks, mlflow_callback]
    client.flush(synchronous=False)

    result = original(self, *args, **kwargs)

    if early_stop_callback is not None:
        _log_early_stop_metrics(early_stop_callback, client, run_id, model_id=model_id)

    mlflow.log_text(str(self.summary()), "model_summary.txt")

    if log_models:
        registered_model_name = get_autologging_config(
            mlflow.paddle.FLAVOR_NAME, "registered_model_name", None
        )
        mlflow.paddle.log_model(
            self, name="model", registered_model_name=registered_model_name, model_id=model_id
        )

    client.flush(synchronous=True)

    return result