# test_tensorflow2_core_model_export.py
import os
from typing import Any, NamedTuple
from unittest import mock

import numpy as np
import pytest
import tensorflow as tf

import mlflow.tensorflow
from mlflow.models import Model, infer_signature


class ToyModel(tf.Module):
    """Minimal affine model: y = flatten(x @ w + b)."""

    def __init__(self, w, b):
        super().__init__()
        self.w = w
        self.b = b

    @tf.function
    def __call__(self, x):
        affine = tf.add(tf.matmul(x, self.w), self.b)
        return tf.reshape(affine, [-1])


class TF2ModelInfo(NamedTuple):
    # Bundles a model with the data and expected predictions shared by the
    # tests below.
    model: Any
    inference_data: Any
    expected_results: Any


@pytest.fixture
def tf2_toy_model():
    """Build a deterministic ToyModel and precompute its predictions."""
    tf.random.set_seed(1337)  # fixed seed keeps the random weights reproducible
    weights = tf.random.uniform(shape=[3, 1], dtype=tf.float32)
    bias = tf.random.uniform(shape=[], dtype=tf.float32)

    features = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32)
    toy = ToyModel(weights, bias)
    return TF2ModelInfo(
        model=toy,
        inference_data=features,
        expected_results=toy(features),
    )


def test_save_and_load_tf2_module(tmp_path, tf2_toy_model):
    """Round-trip a tf.Module through save_model / load_model."""
    model_path = os.path.join(tmp_path, "model")
    mlflow.tensorflow.save_model(tf2_toy_model.model, model_path)

    reloaded = mlflow.tensorflow.load_model(model_path)
    np.testing.assert_allclose(
        reloaded(tf2_toy_model.inference_data).numpy(),
        tf2_toy_model.expected_results,
    )


def test_log_and_load_tf2_module(tf2_toy_model):
    """Log a tf.Module to a run, then load it back via both flavors."""
    with mlflow.start_run():
        model_info = mlflow.tensorflow.log_model(tf2_toy_model.model, name="model")
    model_uri = model_info.model_uri

    # Native tensorflow flavor: callable module returning a tensor.
    native = mlflow.tensorflow.load_model(model_uri)
    np.testing.assert_allclose(
        native(tf2_toy_model.inference_data).numpy(),
        tf2_toy_model.expected_results,
    )

    # Generic pyfunc flavor: predict() should yield a plain ndarray.
    pyfunc_model = mlflow.pyfunc.load_model(model_uri)
    pyfunc_preds = pyfunc_model.predict(tf2_toy_model.inference_data)
    assert isinstance(pyfunc_preds, np.ndarray)
    np.testing.assert_allclose(
        pyfunc_preds,
        tf2_toy_model.expected_results,
    )


def test_model_log_with_signature_inference(tf2_toy_model):
    """Passing an input example should make log_model infer the signature."""
    example = tf2_toy_model.inference_data

    with mlflow.start_run():
        model_info = mlflow.tensorflow.log_model(
            tf2_toy_model.model, name="model", input_example=example
        )

    expected_signature = infer_signature(
        tf2_toy_model.inference_data, tf2_toy_model.expected_results.numpy()
    )
    assert Model.load(model_info.model_uri).signature == expected_signature


def test_save_with_options(tmp_path, tf2_toy_model):
    """saved_model_kwargs must be forwarded to tf.saved_model.save verbatim."""
    model_path = os.path.join(tmp_path, "model")
    saved_model_kwargs = {
        "signatures": [tf.TensorSpec(shape=None, dtype=tf.float32)],
        "options": tf.saved_model.SaveOptions(save_debug_info=True),
    }

    with mock.patch("tensorflow.saved_model.save") as mock_save:
        # save_model path
        mlflow.tensorflow.save_model(
            tf2_toy_model.model, model_path, saved_model_kwargs=saved_model_kwargs
        )
        mock_save.assert_called_once_with(mock.ANY, mock.ANY, **saved_model_kwargs)

        mock_save.reset_mock()

        # log_model path — still inside the patch so the forwarding is captured.
        with mlflow.start_run():
            mlflow.tensorflow.log_model(
                tf2_toy_model.model, name="model", saved_model_kwargs=saved_model_kwargs
            )
        mock_save.assert_called_once_with(mock.ANY, mock.ANY, **saved_model_kwargs)


def test_load_with_options(tmp_path, tf2_toy_model):
    """saved_model_kwargs must be forwarded to tf.saved_model.load verbatim."""
    model_path = os.path.join(tmp_path, "model")
    mlflow.tensorflow.save_model(tf2_toy_model.model, model_path)

    saved_model_kwargs = {
        "options": tf.saved_model.LoadOptions(),
    }
    with mock.patch("tensorflow.saved_model.load") as mock_load:
        mlflow.tensorflow.load_model(model_path, saved_model_kwargs=saved_model_kwargs)
        mock_load.assert_called_once_with(mock.ANY, **saved_model_kwargs)