image_pyfunc.py
"""
Example of a custom python function implementing an image classifier with image preprocessing
embedded in the model.
"""

import base64
import importlib.metadata
import os
from io import BytesIO
from typing import Any

import keras
import numpy as np
import pandas as pd
import PIL
import tensorflow as tf
import yaml
from PIL import Image

import mlflow
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.file_utils import TempDir


def decode_and_resize_image(raw_bytes, size, mode=None):
    """
    Read, decode and resize raw image bytes (e.g. raw content of a jpeg file).

    Args:
        raw_bytes: Image bits, e.g. jpeg image.
        size: Requested output dimensions, passed unchanged to ``PIL.Image.Image.resize``
            (which interprets them as ``(width, height)``).
        mode: Optional PIL image mode (e.g. ``"RGB"``) to convert the decoded image to before
            resizing. ``None`` (the default) keeps the image's native mode.

    Returns:
        Multidimensional float32 numpy array representing the resized image.
    """
    image = Image.open(BytesIO(raw_bytes))
    # Only convert when needed so images already in the requested mode are untouched.
    if mode is not None and image.mode != mode:
        image = image.convert(mode)
    return np.asarray(image.resize(size), dtype=np.float32)


class KerasImageClassifierPyfunc:
    """
    Image classification model with embedded pre-processing.

    This class is essentially an MLflow custom python function wrapper around a Keras model.
    The wrapper provides image preprocessing so that the model can be applied to images directly.
    The input to the model is base64 encoded image binary data (e.g. contents of a jpeg file).
    The output is the predicted class label, predicted class id followed by probabilities for each
    class.

    The model declares current local versions of Keras, TensorFlow and pillow as dependencies in
    its conda environment file.
    """

    def __init__(self, graph, session, model, image_dims, domain):
        """
        Args:
            graph: TF1 ``tf.Graph`` the model was loaded into, or ``None`` when no explicit
                graph is needed (TensorFlow 2).
            session: TF1 session matching ``graph``, or ``None`` (see ``graph``).
            model: Loaded Keras model; its ``predict`` method is used for scoring.
            image_dims: Image dimensions the model expects. The first two entries are used as
                the resize target; a third entry equal to 3 makes inputs get converted to RGB.
            domain: Class labels, indexed by predicted class id.
        """
        self._graph = graph
        self._session = session
        self._model = model
        self._image_dims = image_dims
        self._domain = domain
        probs_names = [f"p({x})" for x in domain]
        self._column_names = ["predicted_label", "predicted_label_id"] + probs_names

    def predict(
        self,
        input,
        params: dict[str, Any] | None = None,
    ):
        """
        Generate predictions for the data.

        Args:
            input: pandas.DataFrame with one column containing images to be scored. The image
                column must contain base64 encoded binary content of the image files. The image
                format must be supported by PIL (e.g. jpeg or png).
            params: Additional parameters to pass to the model for inference. Currently unused.

        Returns:
            pandas.DataFrame indexed like ``input`` containing predictions with the following
            schema:
                predicted_label: string,
                predicted_label_id: int,
                p(<class 0>): float,
                ...,
                p(<class N>): float,
        """
        if len(input) == 0:
            # Nothing to score: return an empty frame with the documented columns.
            return pd.DataFrame(columns=self._column_names, index=input.index)

        def decode_img(value):
            # Base64 text normally arrives as str; bytes-like values are decoded as-is.
            if isinstance(value, str):
                value = value.encode("utf8")
            return base64.decodebytes(value)

        # Only the first column is scored. .iloc avoids the deprecated positional `row[0]`.
        images = input.iloc[:, 0].map(decode_img).to_frame()
        probs = self._predict_images(images)
        label_idx = np.argmax(probs, axis=1)
        labels = [str(self._domain[i]) for i in label_idx]
        # Build the frame column by column so the id and probability columns keep their numeric
        # dtypes; concatenating them with the string labels would coerce everything to str.
        res = pd.DataFrame(probs, columns=self._column_names[2:], index=input.index)
        res.insert(0, self._column_names[1], label_idx)
        res.insert(0, self._column_names[0], labels)
        return res

    def _predict_images(self, images):
        """
        Generate predictions for input images.

        Args:
            images: pandas.DataFrame whose first column holds raw (already base64-decoded)
                image bytes.

        Returns:
            Predicted probabilities for each class, as returned by the Keras model.
        """
        # NOTE(review): the first two image_dims entries are handed to PIL as (width, height)
        # while the resulting array is (height, width, ...); this only matters for non-square
        # inputs and must stay consistent with the training-side preprocessing.
        size = tuple(int(d) for d in self._image_dims[:2])
        # Models expecting 3 channels would otherwise receive RGBA/grayscale/palette arrays
        # of the wrong shape.
        mode = "RGB" if len(self._image_dims) > 2 and int(self._image_dims[2]) == 3 else None
        column = images[images.columns[0]]
        x = np.array([decode_and_resize_image(raw, size, mode) for raw in column])
        if self._graph is None or self._session is None:
            # TensorFlow 2: eager execution, no graph/session context required.
            return self._model.predict(x)
        with self._graph.as_default(), self._session.as_default():
            return self._model.predict(x)


def log_model(keras_model, signature, artifact_path, image_dims, domain):
    """
    Log a KerasImageClassifierPyfunc model as an MLflow artifact for the current run.

    Args:
        keras_model: Keras model to be saved.
        signature: Model signature.
        artifact_path: Run-relative artifact path this model is to be saved to.
        image_dims: Image dimensions the Keras model expects.
        domain: Labels for the classes this model can predict.

    Returns:
        Whatever ``mlflow.pyfunc.log_model`` returns for the logged model.
    """

    with TempDir() as tmp:
        data_path = tmp.path("image_model")
        os.mkdir(data_path)
        # Stored as plain YAML lists (rather than "/"-joined strings) so labels containing "/"
        # survive the round trip; ints/strs because yaml.safe_dump cannot represent numpy types.
        conf = {
            "image_dims": [int(x) for x in image_dims],
            "domain": [str(x) for x in domain],
        }
        with open(os.path.join(data_path, "conf.yaml"), "w") as f:
            yaml.safe_dump(conf, stream=f)
        keras_path = os.path.join(data_path, "keras_model")
        mlflow.tensorflow.save_model(model=keras_model, path=keras_path)
        conda_env = tmp.path("conda_env.yaml")
        with open(conda_env, "w") as f:
            f.write(
                conda_env_template.format(
                    python_version=PYTHON_VERSION,
                    keras_version=keras.__version__,
                    # Import name of the TensorFlow module; this does not detect alternative
                    # pip distributions such as tensorflow-gpu or tensorflow-cpu.
                    tf_name=tf.__name__,
                    tf_version=tf.__version__,
                    pip_version=importlib.metadata.version("pip"),
                    pillow_version=PIL.__version__,
                )
            )

        return mlflow.pyfunc.log_model(
            name=artifact_path,
            signature=signature,
            loader_module=__name__,
            code_paths=[__file__],
            data_path=data_path,
            conda_env=conda_env,
        )


def _load_pyfunc(path):
    """
    Load the KerasImageClassifierPyfunc model.

    Args:
        path: Local directory previously written as ``data_path`` by ``log_model``.

    Returns:
        A KerasImageClassifierPyfunc wrapping the loaded Keras model.
    """
    with open(os.path.join(path, "conf.yaml")) as f:
        conf = yaml.safe_load(f)
    keras_model_path = os.path.join(path, "keras_model")

    domain = conf["domain"]
    if isinstance(domain, str):
        # Legacy format: labels were stored as a single "/"-joined string.
        domain = domain.split("/")
    raw_dims = conf["image_dims"]
    if isinstance(raw_dims, str):
        # Legacy format: dimensions were stored as a single "/"-joined string.
        raw_dims = raw_dims.split("/")
    image_dims = np.array([int(x) for x in raw_dims], dtype=np.int32)

    if hasattr(tf, "Session"):
        # NOTE: TensorFlow 1.x based models depend on global state (Graph and Session) given by
        # the context. To make sure we score the model in the same session as we loaded it in,
        # we create a new session and a new graph here and store them with the model.
        with tf.Graph().as_default() as g:
            with tf.Session().as_default() as sess:
                keras.backend.set_session(sess)
                keras_model = mlflow.tensorflow.load_model(keras_model_path)
        return KerasImageClassifierPyfunc(g, sess, keras_model, image_dims, domain=domain)

    # TensorFlow 2 has no tf.Session / keras.backend.set_session; run eagerly instead.
    keras_model = mlflow.tensorflow.load_model(keras_model_path)
    return KerasImageClassifierPyfunc(None, None, keras_model, image_dims, domain=domain)


conda_env_template = """
name: flower_classifier
channels:
  - conda-forge
dependencies:
  - python=={python_version}
  - pip=={pip_version}
  - pip:
    - mlflow>=1.6
    - pillow=={pillow_version}
    - keras=={keras_version}
    - {tf_name}=={tf_version}
"""