/ examples / flower_classifier / image_pyfunc.py
image_pyfunc.py
  1  """
Example of a custom Python function implementing an image classifier with image preprocessing
embedded in the model.
  4  """
  5  
  6  import base64
  7  import importlib.metadata
  8  import os
  9  from io import BytesIO
 10  from typing import Any
 11  
 12  import keras
 13  import numpy as np
 14  import pandas as pd
 15  import PIL
 16  import tensorflow as tf
 17  import yaml
 18  from PIL import Image
 19  
 20  import mlflow
 21  from mlflow.utils import PYTHON_VERSION
 22  from mlflow.utils.file_utils import TempDir
 23  
 24  
 25  def decode_and_resize_image(raw_bytes, size):
 26      """
 27      Read, decode and resize raw image bytes (e.g. raw content of a jpeg file).
 28  
 29      Args:
 30          raw_bytes: Image bits, e.g. jpeg image.
 31          size: Requested output dimensions.
 32  
 33      Returns:
 34          Multidimensional numpy array representing the resized image.
 35      """
 36      return np.asarray(Image.open(BytesIO(raw_bytes)).resize(size), dtype=np.float32)
 37  
 38  
 39  class KerasImageClassifierPyfunc:
 40      """
 41      Image classification model with embedded pre-processing.
 42  
 43      This class is essentially an MLflow custom python function wrapper around a Keras model.
 44      The wrapper provides image preprocessing so that the model can be applied to images directly.
 45      The input to the model is base64 encoded image binary data (e.g. contents of a jpeg file).
 46      The output is the predicted class label, predicted class id followed by probabilities for each
 47      class.
 48  
 49      The model declares current local versions of Keras, Tensorlow and pillow as dependencies in its
 50      conda environment file.
 51      """
 52  
 53      def __init__(self, graph, session, model, image_dims, domain):
 54          self._graph = graph
 55          self._session = session
 56          self._model = model
 57          self._image_dims = image_dims
 58          self._domain = domain
 59          probs_names = [f"p({x})" for x in domain]
 60          self._column_names = ["predicted_label", "predicted_label_id"] + probs_names
 61  
 62      def predict(
 63          self,
 64          input,
 65          params: dict[str, Any] | None = None,
 66      ):
 67          """
 68          Generate predictions for the data.
 69  
 70          Args:
 71              input: pandas.DataFrame with one column containing images to be scored. The image
 72                  column must contain base64 encoded binary content of the image files. The image
 73                  format must be supported by PIL (e.g. jpeg or png).
 74              params: Additional parameters to pass to the model for inference.
 75  
 76          Returns:
 77              pandas.DataFrame containing predictions with the following schema:
 78                  Predicted class: string,
 79                  Predicted class index: int,
 80                  Probability(class==0): float,
 81                  ...,
 82                  Probability(class==N): float,
 83          """
 84  
 85          # decode image bytes from base64 encoding
 86          def decode_img(x):
 87              return pd.Series(base64.decodebytes(bytearray(x[0], encoding="utf8")))
 88  
 89          images = input.apply(axis=1, func=decode_img)
 90          probs = self._predict_images(images)
 91          m, n = probs.shape
 92          label_idx = np.argmax(probs, axis=1)
 93          labels = np.array([self._domain[i] for i in label_idx], dtype=str).reshape(m, 1)
 94          output_data = np.concatenate((labels, label_idx.reshape(m, 1), probs), axis=1)
 95          res = pd.DataFrame(columns=self._column_names, data=output_data)
 96          res.index = input.index
 97          return res
 98  
 99      def _predict_images(self, images):
100          """
101          Generate predictions for input images.
102  
103          Args:
104              images: Binary image data.
105  
106          Returns:
107              Predicted probabilities for each class.
108          """
109  
110          def preprocess_f(z):
111              return decode_and_resize_image(z, self._image_dims[:2])
112  
113          x = np.array(images[images.columns[0]].apply(preprocess_f).tolist())
114          with self._graph.as_default():
115              with self._session.as_default():
116                  return self._model.predict(x)
117  
118  
def log_model(keras_model, signature, artifact_path, image_dims, domain):
    """
    Log a KerasImageClassifierPyfunc model as an MLflow artifact for the current run.

    The Keras model and a small ``conf.yaml`` (image dimensions and class labels) are staged in a
    temporary directory together with a generated conda environment file, and the result is
    logged as a pyfunc model that uses this module as its loader.

    Args:
        keras_model: Keras model to be saved.
        signature: Model signature.
        artifact_path: Run-relative artifact path this model is to be saved to.
        image_dims: Image dimensions the Keras model expects.
        domain: Labels for the classes this model can predict.
    """
    with TempDir() as tmp:
        data_path = tmp.path("image_model")
        os.mkdir(data_path)

        # Both settings are flattened into "/"-separated strings before being written.
        wrapper_conf = {
            "image_dims": "/".join(str(dim) for dim in image_dims),
            "domain": "/".join(str(label) for label in domain),
        }
        with open(os.path.join(data_path, "conf.yaml"), "w") as conf_file:
            yaml.safe_dump(wrapper_conf, stream=conf_file)

        mlflow.tensorflow.save_model(
            model=keras_model, path=os.path.join(data_path, "keras_model")
        )

        # Pin the environment to the library versions installed locally.
        conda_env = tmp.path("conda_env.yaml")
        with open(conda_env, "w") as env_file:
            env_text = conda_env_template.format(
                python_version=PYTHON_VERSION,
                keras_version=keras.__version__,
                # NOTE(review): this is the imported module's name; confirm it always matches
                # the pip distribution that should be installed.
                tf_name=tf.__name__,
                tf_version=tf.__version__,
                pip_version=importlib.metadata.version("pip"),
                pillow_version=PIL.__version__,
            )
            env_file.write(env_text)

        mlflow.pyfunc.log_model(
            name=artifact_path,
            signature=signature,
            loader_module=__name__,
            code_paths=[__file__],
            data_path=data_path,
            conda_env=conda_env,
        )
160  
161  
def _load_pyfunc(path):
    """
    Load the KerasImageClassifierPyfunc model.

    Args:
        path: Directory containing ``conf.yaml`` and the saved ``keras_model``.

    Returns:
        A KerasImageClassifierPyfunc holding the loaded Keras model together with the graph and
        session it was loaded in.
    """
    with open(os.path.join(path, "conf.yaml")) as conf_file:
        settings = yaml.safe_load(conf_file)
    saved_model_dir = os.path.join(path, "keras_model")

    # Both settings are stored as "/"-separated strings.
    domain = settings["domain"].split("/")
    image_dims = np.array(list(map(int, settings["image_dims"].split("/"))), dtype=np.int32)

    # NOTE: TensorFlow based models depend on global state (Graph and Session) given by the
    # context. A fresh graph and session are created here and kept with the model so that scoring
    # happens in the same session the model was loaded in.
    # NOTE(review): tf.Session and keras.backend.set_session look like TF 1.x-era APIs - confirm
    # which TensorFlow/Keras versions this example is meant to run on.
    with tf.Graph().as_default() as graph, tf.Session().as_default() as session:
        keras.backend.set_session(session)
        loaded_model = mlflow.tensorflow.load_model(saved_model_dir)
    return KerasImageClassifierPyfunc(graph, session, loaded_model, image_dims, domain=domain)
179  
180  
# Template for the conda environment file written by ``log_model``. The ``{...}`` placeholders
# (python_version, pip_version, pillow_version, keras_version, tf_name, tf_version) are filled in
# with ``str.format`` from the versions found in the local environment at logging time.
conda_env_template = """
name: flower_classifier
channels:
  - conda-forge
dependencies:
  - python=={python_version}
  - pip=={pip_version}
  - pip:
    - mlflow>=1.6
    - pillow=={pillow_version}
    - keras=={keras_version}
    - {tf_name}=={tf_version}
"""