train.py
"""
Train a simple Keras DL model on the dataset used in MLflow tutorial (wine-quality.csv).

Dataset is split into train (~ 0.56), validation(~ 0.19) and test (0.25).
Validation data is used to select the best hyperparameters, test set performance is evaluated only
at epochs which improved performance on the validation dataset. The model with best validation set
performance is logged with MLflow.
"""

import math
import warnings

import click
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense, Lambda
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD

import mlflow
from mlflow.models import infer_signature


def eval_and_log_metrics(prefix, actual, pred, epoch):
    """
    Compute the RMSE between ``actual`` and ``pred`` and log it with MLflow.

    The metric is logged under the key ``"<prefix>_rmse"`` at step ``epoch``.
    Returns the computed RMSE.
    """
    rmse = np.sqrt(mean_squared_error(actual, pred))
    mlflow.log_metric(f"{prefix}_rmse", rmse, step=epoch)
    return rmse


def get_standardize_f(train):
    """
    Return a function that standardizes its input using column statistics of ``train``.

    The mean and standard deviation are computed once along axis 0 and captured by the
    returned callable, which maps ``x`` to ``(x - mean) / std``.
    """
    mu = np.mean(train, axis=0)
    std = np.std(train, axis=0)
    return lambda x: (x - mu) / std


class MlflowCheckpoint(Callback):
    """
    Example of Keras MLflow logger.
    Logs training metrics and final model with MLflow.

    We log metrics provided by Keras during training and keep track of the best model (best loss
    on validation dataset). Every improvement of the best model is also evaluated on the test set.

    At the end of the training, log the best model with MLflow.

    The callback is also a context manager: leaving the ``with`` block logs the best model.
    """

    def __init__(self, test_x, test_y, loss="rmse"):
        # Initialise the Keras Callback base state before adding our own attributes.
        super().__init__()
        self._test_x = test_x
        self._test_y = test_y
        # Metric keys used when logging with MLflow.
        self.train_loss = f"train_{loss}"
        self.val_loss = f"val_{loss}"
        self.test_loss = f"test_{loss}"
        self._best_train_loss = math.inf
        self._best_val_loss = math.inf
        self._best_model = None
        # Step at which the final "best" metrics are logged on exit.
        self._next_step = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Log the best model at the end of the training run.

        Raises if no model was stored and the ``with`` body finished without an error. When the
        body itself failed before any model was stored, that original exception is left to
        propagate instead of being buried under a second one.
        """
        if self._best_model is None:
            if exc_type is not None:
                # Training failed before the first epoch completed; surface the real error.
                return False
            raise RuntimeError("Failed to build any model")
        mlflow.log_metric(self.train_loss, self._best_train_loss, step=self._next_step)
        mlflow.log_metric(self.val_loss, self._best_val_loss, step=self._next_step)
        predictions = self._best_model.predict(self._test_x)
        signature = infer_signature(self._test_x, predictions)
        mlflow.tensorflow.log_model(self._best_model, name="model", signature=signature)
        # Never suppress an exception raised inside the ``with`` body.
        return False

    def on_epoch_end(self, epoch, logs=None):
        """
        Log Keras metrics with MLflow. If model improved on the validation data, evaluate it on
        a test set and store it as the best model.
        """
        if not logs:
            return
        self._next_step = epoch + 1
        train_loss = logs["loss"]
        val_loss = logs["val_loss"]
        # NOTE(review): these are the raw Keras loss values. With the "mean_squared_error" loss
        # compiled in run() they are MSE, although the default metric keys end in "_rmse" and the
        # test metric below is a true RMSE -- confirm whether this is intended.
        mlflow.log_metrics({self.train_loss: train_loss, self.val_loss: val_loss}, step=epoch)

        if val_loss < self._best_val_loss:
            # The result improved in the validation set.
            # Snapshot the model and also evaluate and log on test set.
            self._best_train_loss = train_loss
            self._best_val_loss = val_loss
            # Clone + copy the weights so later epochs cannot mutate the stored snapshot.
            self._best_model = keras.models.clone_model(self.model)
            self._best_model.set_weights([x.copy() for x in self.model.get_weights()])
            preds = self._best_model.predict(self._test_x)
            eval_and_log_metrics("test", self._test_y, preds, epoch)


def _split_features_target(frame):
    """Return ``(x, y)`` float32 arrays: every column except "quality", and "quality" itself."""
    x = frame.drop(["quality"], axis=1).astype("float32").values
    y = frame[["quality"]].astype("float32").values
    return x, y


def _build_model(train_x, learning_rate, momentum):
    """Build and compile the dense regression network, standardizing inputs with train stats."""
    n_features = train_x.shape[1]
    model = Sequential()
    model.add(Lambda(get_standardize_f(train_x)))
    model.add(
        Dense(
            n_features,
            activation="relu",
            kernel_initializer="normal",
            input_shape=(n_features,),
        )
    )
    model.add(Dense(16, activation="relu", kernel_initializer="normal"))
    model.add(Dense(16, activation="relu", kernel_initializer="normal"))
    model.add(Dense(1, kernel_initializer="normal", activation="linear"))
    model.compile(
        loss="mean_squared_error",
        # ``learning_rate`` is the supported keyword; the legacy ``lr`` alias is rejected by
        # current Keras optimizers.
        optimizer=SGD(learning_rate=learning_rate, momentum=momentum),
        metrics=[],
    )
    return model


@click.command(
    help="Trains an Keras model on wine-quality dataset. "
    "The input is expected in csv format. "
    "The model and its metrics are logged with mlflow."
)
@click.option("--epochs", type=click.INT, default=100, help="Maximum number of epochs to evaluate.")
@click.option(
    "--batch-size", type=click.INT, default=16, help="Batch size passed to the learning algo."
)
@click.option("--learning-rate", type=click.FLOAT, default=1e-2, help="Learning rate.")
@click.option("--momentum", type=click.FLOAT, default=0.9, help="SGD momentum.")
@click.option("--seed", type=click.INT, default=97531, help="Seed for the random generator.")
@click.argument("training_data")
def run(training_data, epochs, batch_size, learning_rate, momentum, seed):
    """
    Train the model on ``training_data`` (a ";"-separated csv) inside a new MLflow run.

    With ``epochs == 0`` no network is trained; a constant-mean "null model" is scored on each
    split instead. Otherwise the network is fit and the best model is logged on exit of the
    ``MlflowCheckpoint`` context.
    """
    warnings.filterwarnings("ignore")
    data = pd.read_csv(training_data, sep=";")
    # Split the data into training and test sets. (0.75, 0.25) split.
    train, test = train_test_split(data, random_state=seed)
    train, valid = train_test_split(train, random_state=seed)
    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x, train_y = _split_features_target(train)
    valid_x, valid_y = _split_features_target(valid)
    test_x, test_y = _split_features_target(test)

    with mlflow.start_run():
        if epochs == 0:  # score null model
            # NOTE(review): each split is scored against its own mean rather than the training
            # mean -- confirm this is the intended baseline.
            for prefix, target in (("train", train_y), ("val", valid_y), ("test", test_y)):
                eval_and_log_metrics(
                    prefix, target, np.ones(len(target)) * np.mean(target), epoch=-1
                )
        else:
            with MlflowCheckpoint(test_x, test_y) as mlflow_logger:
                model = _build_model(train_x, learning_rate, momentum)
                model.fit(
                    train_x,
                    train_y,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(valid_x, valid_y),
                    callbacks=[mlflow_logger],
                )


if __name__ == "__main__":
    run()