/ tests / tensorflow / test_tensorflow2_core_model_export.py
test_tensorflow2_core_model_export.py
  1  import os
  2  from typing import Any, NamedTuple
  3  from unittest import mock
  4  
  5  import numpy as np
  6  import pytest
  7  import tensorflow as tf
  8  
  9  import mlflow.tensorflow
 10  from mlflow.models import Model, infer_signature
 11  
 12  
 13  class ToyModel(tf.Module):
 14      def __init__(self, w, b):
 15          super().__init__()
 16          self.w = w
 17          self.b = b
 18  
 19      @tf.function
 20      def __call__(self, x):
 21          return tf.reshape(tf.add(tf.matmul(x, self.w), self.b), [-1])
 22  
 23  
 24  class TF2ModelInfo(NamedTuple):
 25      model: Any
 26      inference_data: Any
 27      expected_results: Any
 28  
 29  
 30  @pytest.fixture
 31  def tf2_toy_model():
 32      tf.random.set_seed(1337)
 33      rand_w = tf.random.uniform(shape=[3, 1], dtype=tf.float32)
 34      rand_b = tf.random.uniform(shape=[], dtype=tf.float32)
 35  
 36      inference_data = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32)
 37      model = ToyModel(rand_w, rand_b)
 38      expected_results = model(inference_data)
 39  
 40      return TF2ModelInfo(
 41          model=model,
 42          inference_data=inference_data,
 43          expected_results=expected_results,
 44      )
 45  
 46  
 47  def test_save_and_load_tf2_module(tmp_path, tf2_toy_model):
 48      model_path = os.path.join(tmp_path, "model")
 49      mlflow.tensorflow.save_model(tf2_toy_model.model, model_path)
 50  
 51      loaded_model = mlflow.tensorflow.load_model(model_path)
 52  
 53      predictions = loaded_model(tf2_toy_model.inference_data).numpy()
 54      np.testing.assert_allclose(
 55          predictions,
 56          tf2_toy_model.expected_results,
 57      )
 58  
 59  
 60  def test_log_and_load_tf2_module(tf2_toy_model):
 61      with mlflow.start_run():
 62          model_info = mlflow.tensorflow.log_model(tf2_toy_model.model, name="model")
 63  
 64      model_uri = model_info.model_uri
 65      loaded_model = mlflow.tensorflow.load_model(model_uri)
 66      predictions = loaded_model(tf2_toy_model.inference_data).numpy()
 67      np.testing.assert_allclose(
 68          predictions,
 69          tf2_toy_model.expected_results,
 70      )
 71  
 72      loaded_model2 = mlflow.pyfunc.load_model(model_uri)
 73      predictions2 = loaded_model2.predict(tf2_toy_model.inference_data)
 74      assert isinstance(predictions2, np.ndarray)
 75      np.testing.assert_allclose(
 76          predictions2,
 77          tf2_toy_model.expected_results,
 78      )
 79  
 80  
 81  def test_model_log_with_signature_inference(tf2_toy_model):
 82      artifact_path = "model"
 83      example = tf2_toy_model.inference_data
 84  
 85      with mlflow.start_run():
 86          model_info = mlflow.tensorflow.log_model(
 87              tf2_toy_model.model, name=artifact_path, input_example=example
 88          )
 89  
 90      mlflow_model = Model.load(model_info.model_uri)
 91      assert mlflow_model.signature == infer_signature(
 92          tf2_toy_model.inference_data, tf2_toy_model.expected_results.numpy()
 93      )
 94  
 95  
 96  def test_save_with_options(tmp_path, tf2_toy_model):
 97      model_path = os.path.join(tmp_path, "model")
 98  
 99      saved_model_kwargs = {
100          "signatures": [tf.TensorSpec(shape=None, dtype=tf.float32)],
101          "options": tf.saved_model.SaveOptions(save_debug_info=True),
102      }
103  
104      with mock.patch("tensorflow.saved_model.save") as mock_save:
105          mlflow.tensorflow.save_model(
106              tf2_toy_model.model, model_path, saved_model_kwargs=saved_model_kwargs
107          )
108          mock_save.assert_called_once_with(mock.ANY, mock.ANY, **saved_model_kwargs)
109  
110          mock_save.reset_mock()
111  
112          with mlflow.start_run():
113              mlflow.tensorflow.log_model(
114                  tf2_toy_model.model, name="model", saved_model_kwargs=saved_model_kwargs
115              )
116  
117          mock_save.assert_called_once_with(mock.ANY, mock.ANY, **saved_model_kwargs)
118  
119  
def test_load_with_options(tmp_path, tf2_toy_model):
    """``saved_model_kwargs`` given to ``load_model`` reach ``tf.saved_model.load``."""
    model_path = os.path.join(tmp_path, "model")
    mlflow.tensorflow.save_model(tf2_toy_model.model, model_path)

    load_kwargs = {"options": tf.saved_model.LoadOptions()}
    with mock.patch("tensorflow.saved_model.load") as mock_load:
        mlflow.tensorflow.load_model(model_path, saved_model_kwargs=load_kwargs)
    # The mock keeps its call history after the patch is undone.
    mock_load.assert_called_once_with(mock.ANY, **load_kwargs)