# train.py
"""
Example of image classification with MLflow using Keras to classify flowers from photos. The data is
taken from ``http://download.tensorflow.org/example_images/flower_photos.tgz`` and may be
downloaded while running this project if it is missing.
"""

import math
import os
import tarfile

import click
import keras
import numpy as np
import tensorflow as tf
from image_pyfunc import decode_and_resize_image, log_model
from keras.applications import vgg16
from keras.callbacks import Callback
from keras.layers import Dense, Flatten, Input, Lambda
from keras.models import Model
from keras.utils import np_utils
from sklearn.model_selection import train_test_split

import mlflow
from mlflow.models import infer_signature


def download_input():
    """
    Download ``flower_photos.tgz`` into the current working directory and unpack it there.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
    """
    import requests

    url = "http://download.tensorflow.org/example_images/flower_photos.tgz"
    print("downloading '{}' into '{}'".format(url, os.path.abspath("flower_photos.tgz")))
    r = requests.get(url)
    # Fail here instead of writing an HTTP error page to disk and crashing later inside tarfile.
    r.raise_for_status()
    with open("flower_photos.tgz", "wb") as f:
        f.write(r.content)

    print("decompressing flower_photos.tgz to '{}'".format(os.path.abspath("flower_photos")))
    with tarfile.open("flower_photos.tgz") as tar:
        # The archive arrives over plain HTTP, so treat it as untrusted: use the "data"
        # extraction filter when this interpreter provides it (blocks path traversal etc.).
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path="./", filter="data")
        else:
            tar.extractall(path="./")


@click.command(
    help="Trains an Keras model on flower_photos dataset. "
    "The input is expected as a directory tree with pictures for each category in a "
    "folder named by the category. "
    "The model and its metrics are logged with mlflow."
)
@click.option("--epochs", type=click.INT, default=1, help="Maximum number of epochs to evaluate.")
@click.option(
    "--batch-size", type=click.INT, default=16, help="Batch size passed to the learning algo."
)
@click.option("--image-width", type=click.INT, default=224, help="Input image width in pixels.")
@click.option("--image-height", type=click.INT, default=224, help="Input image height in pixels.")
@click.option("--seed", type=click.INT, default=97531, help="Seed for the random generator.")
@click.option("--training-data", type=click.STRING, default="./flower_photos")
@click.option("--test-ratio", type=click.FLOAT, default=0.2)
def run(training_data, test_ratio, epochs, batch_size, image_width, image_height, seed):
    """
    Command line entry point: collect the image files and their labels, then call ``train``.

    Every ``*jpg`` file found under ``training_data`` is used; its label is the name of the
    directory that contains it. Label ids are assigned in order of first appearance.
    """
    # Snapshot the arguments before any other local exists so that only real parameters are
    # printed (reading locals() later would also list the working lists created below).
    params = dict(locals())
    print("Training model with the following parameters:")
    for param, value in params.items():
        print(" ", param, "=", value)

    if training_data == "./flower_photos" and not os.path.exists(training_data):
        print("Input data not found, attempting to download the data from the web.")
        download_input()

    image_files = []
    labels = []
    domain = {}
    for dirname, _, files in os.walk(training_data):
        for filename in files:
            if filename.endswith("jpg"):
                image_files.append(os.path.join(dirname, filename))
                clazz = os.path.basename(dirname)
                # A class seen for the first time gets the next free integer id.
                labels.append(domain.setdefault(clazz, len(domain)))

    train(
        image_files,
        labels,
        domain,
        epochs=epochs,
        test_ratio=test_ratio,
        batch_size=batch_size,
        image_width=image_width,
        image_height=image_height,
        seed=seed,
    )


class MlflowLogger(Callback):
    """
    Keras callback for logging metrics and final model with MLflow.

    Metrics are logged after every epoch. The logger keeps track of the best model based on the
    validation metric. At the end of the training, the best model is logged with MLflow.
    """

    def __init__(self, model, x_train, y_train, x_valid, y_valid, **kwargs):
        """
        Args:
            model: Keras model being trained; its weights are snapshotted and restored.
            x_train: Training inputs, re-evaluated at the end of training.
            y_train: Training targets.
            x_valid: Validation inputs, re-evaluated at the end of training.
            y_valid: Validation targets.
            kwargs: Extra keyword arguments forwarded unchanged to ``log_model``.
        """
        super().__init__()
        self._model = model
        self._best_val_loss = math.inf
        self._train = (x_train, y_train)
        self._valid = (x_valid, y_valid)
        self._pyfunc_params = kwargs
        # Stays None until some epoch reports a validation loss below the current best.
        self._best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        """
        Log Keras metrics with MLflow. Update the best model if the model improved on the validation
        data.
        """
        if not logs:
            return
        for name, value in logs.items():
            # Keras prefixes validation metrics with "val_"; everything else is a train metric.
            name = "valid_" + name[4:] if name.startswith("val_") else "train_" + name
            mlflow.log_metric(name, value)
        # Tolerate logs without a validation loss instead of raising KeyError mid-training.
        val_loss = logs.get("val_loss")
        if val_loss is not None and val_loss < self._best_val_loss:
            # Save the "best" weights
            self._best_val_loss = val_loss
            self._best_weights = [w.copy() for w in self._model.get_weights()]

    def on_train_end(self, *args, **kwargs):
        """
        Log the best model with MLflow and evaluate it on the train and validation data so that the
        metrics stored with MLflow reflect the logged model.
        """
        # No snapshot exists when no epoch improved on the initial infinite loss (zero epochs,
        # NaN validation loss); in that case keep the current weights rather than crash.
        if self._best_weights is not None:
            self._model.set_weights(self._best_weights)
        x, y = self._train
        train_res = self._model.evaluate(x=x, y=y)
        for name, value in zip(self._model.metrics_names, train_res):
            mlflow.log_metric(f"train_{name}", value)
        x, y = self._valid
        valid_res = self._model.evaluate(x=x, y=y)
        for name, value in zip(self._model.metrics_names, valid_res):
            mlflow.log_metric(f"valid_{name}", value)
        # The signature is inferred from the validation arrays.
        signature = infer_signature(x, y)
        log_model(keras_model=self._model, signature=signature, **self._pyfunc_params)


def _imagenet_preprocess_tf(x):
    """Rescale ``x`` linearly as ``x / 127.5 - 1`` (maps 0..255 onto -1..1)."""
    return (x / 127.5) - 1


def _create_model(input_shape, classes):
    """
    Build an untrained VGG16 network with the input rescaling layer built in.

    Args:
        input_shape: Shape passed to the Keras ``Input`` layer.
        classes: Number of units of the final softmax layer.

    Returns:
        A Keras ``Model`` mapping the raw input tensor to the class probabilities.
    """
    image = Input(input_shape)
    lambda_layer = Lambda(_imagenet_preprocess_tf)
    preprocessed_image = lambda_layer(image)
    # include_top=False: the dense classifier head is built explicitly below.
    model = vgg16.VGG16(
        classes=classes, input_tensor=preprocessed_image, weights=None, include_top=False
    )

    x = Flatten(name="flatten")(model.output)
    x = Dense(4096, activation="relu", name="fc1")(x)
    x = Dense(4096, activation="relu", name="fc2")(x)
    x = Dense(classes, activation="softmax", name="predictions")(x)
    return Model(inputs=model.input, outputs=x)


def train(
    image_files,
    labels,
    domain,
    image_width=224,
    image_height=224,
    epochs=1,
    batch_size=16,
    test_ratio=0.2,
    seed=None,
):
    """
    Train VGG16 model on provided image files. This will create a new MLflow run and log all
    parameters, metrics and the resulting model with MLflow. The resulting model is an instance
    of KerasImageClassifierPyfunc - a custom python function model that embeds all necessary
    preprocessing together with the VGG16 Keras model. The resulting model can be applied
    directly to base64 encoded image data.

    Args:
        image_files: List of image files to be used for training.
        labels: List of labels for the image files.
        domain: Dictionary representing the domain of the response.
            Provides mapping label-name -> label-id.
        image_width: Width of the input image in pixels.
        image_height: Height of the input image in pixels.
        epochs: Number of epochs to train the model for.
        batch_size: Batch size used during training.
        test_ratio: Fraction of dataset to be used for validation. This data will not be used
            during training.
        seed: Random seed. Used e.g. when splitting the dataset into train / validation.

    """
    assert len(set(labels)) == len(domain)

    # NOTE(review): width is placed before height here, while Keras image tensors are usually
    # (height, width, channels). Harmless for the square default; confirm against
    # decode_and_resize_image before relying on non-square sizes.
    input_shape = (image_width, image_height, 3)

    with mlflow.start_run():
        mlflow.log_param("epochs", str(epochs))
        mlflow.log_param("batch_size", str(batch_size))
        mlflow.log_param("validation_ratio", str(test_ratio))
        # Compare against None: 0 is a valid seed and must be logged as well.
        if seed is not None:
            mlflow.log_param("seed", str(seed))

        def _read_image(filename):
            with open(filename, "rb") as f:
                return f.read()

        with tf.Graph().as_default() as g:
            with tf.compat.v1.Session(graph=g).as_default():
                dims = input_shape[:2]
                x = np.array(
                    [decode_and_resize_image(_read_image(path), dims) for path in image_files]
                )
                y = np_utils.to_categorical(np.array(labels), num_classes=len(domain))
                train_size = 1 - test_ratio
                x_train, x_valid, y_train, y_valid = train_test_split(
                    x, y, random_state=seed, train_size=train_size
                )
                model = _create_model(input_shape=input_shape, classes=len(domain))
                model.compile(
                    optimizer=keras.optimizers.SGD(decay=1e-5, nesterov=True, momentum=0.9),
                    loss=keras.losses.categorical_crossentropy,
                    metrics=["accuracy"],
                )
                # Class names ordered by label id, so position i names the class with id i.
                sorted_domain = sorted(domain, key=domain.get)
                model.fit(
                    x=x_train,
                    y=y_train,
                    validation_data=(x_valid, y_valid),
                    epochs=epochs,
                    batch_size=batch_size,
                    callbacks=[
                        MlflowLogger(
                            model=model,
                            x_train=x_train,
                            y_train=y_train,
                            x_valid=x_valid,
                            y_valid=y_valid,
                            artifact_path="model",
                            domain=sorted_domain,
                            image_dims=input_shape,
                        )
                    ],
                )


if __name__ == "__main__":
    run()