train.py
"""
Example of image classification with MLflow, using Keras to classify flowers from photos.
The data comes from ``http://download.tensorflow.org/example_images/flower_photos.tgz`` and
is downloaded automatically when this project runs if it is not already present locally.
"""
  6  
  7  import math
  8  import os
  9  import tarfile
 10  
 11  import click
 12  import keras
 13  import numpy as np
 14  import tensorflow as tf
 15  from image_pyfunc import decode_and_resize_image, log_model
 16  from keras.applications import vgg16
 17  from keras.callbacks import Callback
 18  from keras.layers import Dense, Flatten, Input, Lambda
 19  from keras.models import Model
 20  from keras.utils import np_utils
 21  from sklearn.model_selection import train_test_split
 22  
 23  import mlflow
 24  from mlflow.models import infer_signature
 25  
 26  
 27  def download_input():
 28      import requests
 29  
 30      url = "http://download.tensorflow.org/example_images/flower_photos.tgz"
 31      print("downloading '{}' into '{}'".format(url, os.path.abspath("flower_photos.tgz")))
 32      r = requests.get(url)
 33      with open("flower_photos.tgz", "wb") as f:
 34          f.write(r.content)
 35  
 36      print("decompressing flower_photos.tgz to '{}'".format(os.path.abspath("flower_photos")))
 37      with tarfile.open("flower_photos.tgz") as tar:
 38          tar.extractall(path="./")
 39  
 40  
 41  @click.command(
 42      help="Trains an Keras model on flower_photos dataset. "
 43      "The input is expected as a directory tree with pictures for each category in a "
 44      "folder named by the category. "
 45      "The model and its metrics are logged with mlflow."
 46  )
 47  @click.option("--epochs", type=click.INT, default=1, help="Maximum number of epochs to evaluate.")
 48  @click.option(
 49      "--batch-size", type=click.INT, default=16, help="Batch size passed to the learning algo."
 50  )
 51  @click.option("--image-width", type=click.INT, default=224, help="Input image width in pixels.")
 52  @click.option("--image-height", type=click.INT, default=224, help="Input image height in pixels.")
 53  @click.option("--seed", type=click.INT, default=97531, help="Seed for the random generator.")
 54  @click.option("--training-data", type=click.STRING, default="./flower_photos")
 55  @click.option("--test-ratio", type=click.FLOAT, default=0.2)
 56  def run(training_data, test_ratio, epochs, batch_size, image_width, image_height, seed):
 57      image_files = []
 58      labels = []
 59      domain = {}
 60      print("Training model with the following parameters:")
 61      for param, value in locals().items():
 62          print("  ", param, "=", value)
 63  
 64      if training_data == "./flower_photos" and not os.path.exists(training_data):
 65          print("Input data not found, attempting to download the data from the web.")
 66          download_input()
 67  
 68      for dirname, _, files in os.walk(training_data):
 69          for filename in files:
 70              if filename.endswith("jpg"):
 71                  image_files.append(os.path.join(dirname, filename))
 72                  clazz = os.path.basename(dirname)
 73                  if clazz not in domain:
 74                      domain[clazz] = len(domain)
 75                  labels.append(domain[clazz])
 76  
 77      train(
 78          image_files,
 79          labels,
 80          domain,
 81          epochs=epochs,
 82          test_ratio=test_ratio,
 83          batch_size=batch_size,
 84          image_width=image_width,
 85          image_height=image_height,
 86          seed=seed,
 87      )
 88  
 89  
 90  class MlflowLogger(Callback):
 91      """
 92      Keras callback for logging metrics and final model with MLflow.
 93  
 94      Metrics are logged after every epoch. The logger keeps track of the best model based on the
 95      validation metric. At the end of the training, the best model is logged with MLflow.
 96      """
 97  
 98      def __init__(self, model, x_train, y_train, x_valid, y_valid, **kwargs):
 99          self._model = model
100          self._best_val_loss = math.inf
101          self._train = (x_train, y_train)
102          self._valid = (x_valid, y_valid)
103          self._pyfunc_params = kwargs
104          self._best_weights = None
105  
106      def on_epoch_end(self, epoch, logs=None):
107          """
108          Log Keras metrics with MLflow. Update the best model if the model improved on the validation
109          data.
110          """
111          if not logs:
112              return
113          for name, value in logs.items():
114              name = "valid_" + name[4:] if name.startswith("val_") else "train_" + name
115              mlflow.log_metric(name, value)
116          val_loss = logs["val_loss"]
117          if val_loss < self._best_val_loss:
118              # Save the "best" weights
119              self._best_val_loss = val_loss
120              self._best_weights = [x.copy() for x in self._model.get_weights()]
121  
122      def on_train_end(self, *args, **kwargs):
123          """
124          Log the best model with MLflow and evaluate it on the train and validation data so that the
125          metrics stored with MLflow reflect the logged model.
126          """
127          self._model.set_weights(self._best_weights)
128          x, y = self._train
129          train_res = self._model.evaluate(x=x, y=y)
130          for name, value in zip(self._model.metrics_names, train_res):
131              mlflow.log_metric(f"train_{name}", value)
132          x, y = self._valid
133          valid_res = self._model.evaluate(x=x, y=y)
134          for name, value in zip(self._model.metrics_names, valid_res):
135              mlflow.log_metric(f"valid_{name}", value)
136          signature = infer_signature(x, y)
137          log_model(keras_model=self._model, signature=signature, **self._pyfunc_params)
138  
139  
140  def _imagenet_preprocess_tf(x):
141      return (x / 127.5) - 1
142  
143  
def _create_model(input_shape, classes):
    """Build a VGG16 backbone (random weights, no top) with a fresh dense
    classification head producing a softmax over ``classes`` categories.
    """
    input_tensor = Input(input_shape)
    # Scale raw pixel values into [-1, 1] before the convolutional stack.
    scaled = Lambda(_imagenet_preprocess_tf)(input_tensor)
    backbone = vgg16.VGG16(
        classes=classes, input_tensor=scaled, weights=None, include_top=False
    )

    head = Flatten(name="flatten")(backbone.output)
    head = Dense(4096, activation="relu", name="fc1")(head)
    head = Dense(4096, activation="relu", name="fc2")(head)
    head = Dense(classes, activation="softmax", name="predictions")(head)
    return Model(inputs=backbone.input, outputs=head)
157  
158  
def train(
    image_files,
    labels,
    domain,
    image_width=224,
    image_height=224,
    epochs=1,
    batch_size=16,
    test_ratio=0.2,
    seed=None,
):
    """
    Train VGG16 model on provided image files. This will create a new MLflow run and log all
    parameters, metrics and the resulting model with MLflow. The resulting model is an instance
    of KerasImageClassifierPyfunc - a custom python function model that embeds all necessary
    preprocessing together with the VGG16 Keras model. The resulting model can be applied
    directly to image base64 encoded image data.

    Args:
        image_files: List of image files to be used for training.
        labels: List of labels for the image files.
        domain: Dictionary representing the domain of the response.
            Provides mapping label-name -> label-id.
        image_width: Width of the input image in pixels.
        image_height: Height of the input image in pixels.
        epochs: Number of epochs to train the model for.
        batch_size: Batch size used during training.
        test_ratio: Fraction of dataset to be used for validation. This data will not be used
            during training.
        seed: Random seed. Used e.g. when splitting the dataset into train / validation.

    Raises:
        ValueError: If ``labels`` does not cover every class in ``domain``.
    """
    # Validate with a real exception: `assert` is stripped under `python -O`.
    if len(set(labels)) != len(domain):
        raise ValueError(
            "labels must contain every class in the domain "
            f"({len(set(labels))} distinct labels vs {len(domain)} classes)"
        )

    input_shape = (image_width, image_height, 3)

    with mlflow.start_run():
        mlflow.log_param("epochs", str(epochs))
        mlflow.log_param("batch_size", str(batch_size))
        mlflow.log_param("validation_ratio", str(test_ratio))
        # `if seed:` would skip logging a perfectly valid seed of 0.
        if seed is not None:
            mlflow.log_param("seed", str(seed))

        def _read_image(filename):
            # Return raw bytes; decoding/resizing happens in decode_and_resize_image.
            with open(filename, "rb") as f:
                return f.read()

        with tf.Graph().as_default() as g:
            with tf.compat.v1.Session(graph=g).as_default():
                dims = input_shape[:2]
                x = np.array([decode_and_resize_image(_read_image(x), dims) for x in image_files])
                y = np_utils.to_categorical(np.array(labels), num_classes=len(domain))
                train_size = 1 - test_ratio
                x_train, x_valid, y_train, y_valid = train_test_split(
                    x, y, random_state=seed, train_size=train_size
                )
                model = _create_model(input_shape=input_shape, classes=len(domain))
                model.compile(
                    optimizer=keras.optimizers.SGD(decay=1e-5, nesterov=True, momentum=0.9),
                    loss=keras.losses.categorical_crossentropy,
                    metrics=["accuracy"],
                )
                # Order class names by their integer ids so the pyfunc's output
                # index maps back to the correct label name.
                sorted_domain = sorted(domain.keys(), key=lambda x: domain[x])
                model.fit(
                    x=x_train,
                    y=y_train,
                    validation_data=(x_valid, y_valid),
                    epochs=epochs,
                    batch_size=batch_size,
                    callbacks=[
                        MlflowLogger(
                            model=model,
                            x_train=x_train,
                            y_train=y_train,
                            x_valid=x_valid,
                            y_valid=y_valid,
                            artifact_path="model",
                            domain=sorted_domain,
                            image_dims=input_shape,
                        )
                    ],
                )
241  
242  
# Script entry point: dispatch to the click command-line interface.
if __name__ == "__main__":
    run()