# test_sktime_model_export.py
"""Tests for saving, logging and loading sktime forecasters via the local ``flavor`` module."""

import os
from pathlib import Path
from unittest import mock

import boto3
import flavor
import moto
import numpy as np
import pandas as pd
import pytest
from botocore.config import Config
from sktime.datasets import load_airline, load_longley
from sktime.datatypes import convert
from sktime.forecasting.arima import AutoARIMA
from sktime.forecasting.model_selection import temporal_train_test_split
from sktime.forecasting.naive import NaiveForecaster

import mlflow
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model, infer_signature
from mlflow.models.utils import _read_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env

# Shared prediction arguments used across the tests below.
FH = [1, 2, 3]  # passed as ``fh`` to every predict* call
COVERAGE = [0.1, 0.5, 0.9]  # passed as ``coverage`` to predict_interval
ALPHA = [0.1, 0.5, 0.9]  # passed as ``alpha`` to predict_quantiles
COV = False  # passed as ``cov`` to predict_var


@pytest.fixture
def model_path(tmp_path):
    """Create a temporary path to save/log model."""
    return tmp_path.joinpath("model")


@pytest.fixture(scope="module")
def mock_s3_bucket():
    """Create a mock S3 bucket using moto.

    Yields
    ------
    string with name of mock S3 bucket
    """
    # The bucket only exists while the moto context is active, so the fixture
    # yields from inside the ``with`` block.
    with moto.mock_s3():
        bucket_name = "mock-bucket"
        my_config = Config(region_name="us-east-1")
        s3_client = boto3.client("s3", config=my_config)
        s3_client.create_bucket(Bucket=bucket_name)
        yield bucket_name


@pytest.fixture
def sktime_custom_env(tmp_path):
    """Create a conda environment and return the path to its yml file."""
    conda_env = tmp_path.joinpath("conda_env.yml")
    _mlflow_conda_env(conda_env, additional_pip_deps=["sktime"])
    return conda_env


@pytest.fixture(scope="module")
def data_airline():
    """Create sample data for univariate model without exogenous regressor."""
    return load_airline()


@pytest.fixture(scope="module")
def data_longley():
    """Create sample data for univariate model with exogenous regressor.

    Returns
    -------
    tuple of (y_train, y_test, X_train, X_test)
    """
    y, X = load_longley()
    y_train, y_test, X_train, X_test = temporal_train_test_split(y, X)
    return y_train, y_test, X_train, X_test


@pytest.fixture(scope="module")
def auto_arima_model(data_airline):
    """Create instance of fitted auto arima model."""
    return AutoARIMA(sp=12, d=0, max_p=2, max_q=2, suppress_warnings=True).fit(data_airline)


@pytest.fixture(scope="module")
def naive_forecaster_model_with_regressor(data_longley):
    """Create instance of fitted naive forecaster model."""
    y_train, _, X_train, _ = data_longley
    model = NaiveForecaster()
    return model.fit(y_train, X_train)


@pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
def test_auto_arima_model_save_and_load(auto_arima_model, model_path, serialization_format):
    """A model saved with either serialization format reloads with equal predictions."""
    flavor.save_model(
        sktime_model=auto_arima_model,
        path=model_path,
        serialization_format=serialization_format,
    )
    loaded_model = flavor.load_model(
        model_uri=model_path,
    )

    np.testing.assert_array_equal(auto_arima_model.predict(fh=FH), loaded_model.predict(fh=FH))


@pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
def test_auto_arima_model_pyfunc_output(auto_arima_model, model_path, serialization_format):
    """The pyfunc wrapper reproduces every native predict* method of the model.

    Each prediction is requested through a one-row configuration DataFrame
    whose ``predict_method`` column selects the native method to call.
    """
    flavor.save_model(
        sktime_model=auto_arima_model,
        path=model_path,
        serialization_format=serialization_format,
    )
    loaded_pyfunc = flavor.pyfunc.load_model(model_uri=model_path)

    model_predict = auto_arima_model.predict(fh=FH)
    predict_conf = pd.DataFrame([{"fh": FH, "predict_method": "predict"}])
    pyfunc_predict = loaded_pyfunc.predict(predict_conf)
    np.testing.assert_array_equal(model_predict, pyfunc_predict)

    model_predict_interval = auto_arima_model.predict_interval(fh=FH, coverage=COVERAGE)
    predict_interval_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_interval",
            "coverage": COVERAGE,
        }
    ])
    pyfunc_predict_interval = loaded_pyfunc.predict(predict_interval_conf)
    np.testing.assert_array_equal(model_predict_interval.values, pyfunc_predict_interval.values)

    model_predict_quantiles = auto_arima_model.predict_quantiles(fh=FH, alpha=ALPHA)
    predict_quantiles_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_quantiles",
            "alpha": ALPHA,
        }
    ])
    pyfunc_predict_quantiles = loaded_pyfunc.predict(predict_quantiles_conf)
    np.testing.assert_array_equal(model_predict_quantiles.values, pyfunc_predict_quantiles.values)

    model_predict_var = auto_arima_model.predict_var(fh=FH, cov=COV)
    predict_var_conf = pd.DataFrame([{"fh": FH, "predict_method": "predict_var", "cov": COV}])
    pyfunc_predict_var = loaded_pyfunc.predict(predict_var_conf)
    np.testing.assert_array_equal(model_predict_var.values, pyfunc_predict_var.values)


def test_naive_forecaster_model_with_regressor_pyfunc_output(
    naive_forecaster_model_with_regressor, model_path, data_longley
):
    """The pyfunc wrapper matches the native model when exogenous data ``X`` is supplied."""
    _, _, _, X_test = data_longley

    flavor.save_model(sktime_model=naive_forecaster_model_with_regressor, path=model_path)
    loaded_pyfunc = flavor.pyfunc.load_model(model_uri=model_path)

    # The configuration DataFrame carries X as a numpy array, while the native
    # model is called with the original pandas object.
    X_test_array = convert(X_test, "pd.DataFrame", "np.ndarray")

    model_predict = naive_forecaster_model_with_regressor.predict(fh=FH, X=X_test)
    predict_conf = pd.DataFrame([{"fh": FH, "predict_method": "predict", "X": X_test_array}])
    pyfunc_predict = loaded_pyfunc.predict(predict_conf)
    np.testing.assert_array_equal(model_predict, pyfunc_predict)

    model_predict_interval = naive_forecaster_model_with_regressor.predict_interval(
        fh=FH, coverage=COVERAGE, X=X_test
    )
    predict_interval_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_interval",
            "coverage": COVERAGE,
            "X": X_test_array,
        }
    ])
    pyfunc_predict_interval = loaded_pyfunc.predict(predict_interval_conf)
    np.testing.assert_array_equal(model_predict_interval.values, pyfunc_predict_interval.values)

    model_predict_quantiles = naive_forecaster_model_with_regressor.predict_quantiles(
        fh=FH, alpha=ALPHA, X=X_test
    )
    predict_quantiles_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_quantiles",
            "alpha": ALPHA,
            "X": X_test_array,
        }
    ])
    pyfunc_predict_quantiles = loaded_pyfunc.predict(predict_quantiles_conf)
    np.testing.assert_array_equal(model_predict_quantiles.values, pyfunc_predict_quantiles.values)

    model_predict_var = naive_forecaster_model_with_regressor.predict_var(fh=FH, cov=COV, X=X_test)
    predict_var_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_var",
            "cov": COV,
            "X": X_test_array,
        }
    ])
    pyfunc_predict_var = loaded_pyfunc.predict(predict_var_conf)
    np.testing.assert_array_equal(model_predict_var.values, pyfunc_predict_var.values)


@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_signature_and_examples_saved_correctly(
    auto_arima_model, data_airline, model_path, use_signature, use_example
):
    """Signature and input example passed to ``save_model`` round-trip through the MLmodel."""
    # Note: Signature inference fails on native model predict_interval/predict_quantiles
    prediction = auto_arima_model.predict(fh=FH)
    signature = infer_signature(data_airline, prediction) if use_signature else None
    example = pd.DataFrame(data_airline[0:5].copy(deep=False)) if use_example else None
    flavor.save_model(auto_arima_model, path=model_path, signature=signature, input_example=example)
    mlflow_model = Model.load(model_path)
    assert signature == mlflow_model.signature
    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        r_example = _read_example(mlflow_model, model_path).copy(deep=False)
        np.testing.assert_array_equal(r_example, example)


@pytest.mark.parametrize("use_signature", [True, False])
def test_predict_var_signature_saved_correctly(
    auto_arima_model, data_airline, model_path, use_signature
):
    """A signature inferred from native ``predict_var`` output is stored unchanged."""
    prediction = auto_arima_model.predict_var(fh=FH)
    signature = infer_signature(data_airline, prediction) if use_signature else None
    flavor.save_model(auto_arima_model, path=model_path, signature=signature)
    mlflow_model = Model.load(model_path)
    assert signature == mlflow_model.signature


@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_signature_and_example_for_pyfunc_predict_interval(
    auto_arima_model, model_path, data_airline, use_signature, use_example
):
    """A signature inferred from pyfunc ``predict_interval`` output is stored unchanged.

    The model is saved twice: once ("primary") to obtain a pyfunc forecast to
    infer the signature from, and once ("secondary") with that signature and
    the optional input example attached.
    """
    model_path_primary = model_path.joinpath("primary")
    model_path_secondary = model_path.joinpath("secondary")
    flavor.save_model(sktime_model=auto_arima_model, path=model_path_primary)
    loaded_pyfunc = flavor.pyfunc.load_model(model_uri=model_path_primary)
    predict_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_interval",
            "coverage": COVERAGE,
        }
    ])
    forecast = loaded_pyfunc.predict(predict_conf)
    signature = infer_signature(data_airline, forecast) if use_signature else None
    example = pd.DataFrame(data_airline[0:5].copy(deep=False)) if use_example else None
    flavor.save_model(
        auto_arima_model,
        path=model_path_secondary,
        signature=signature,
        input_example=example,
    )
    mlflow_model = Model.load(model_path_secondary)
    assert signature == mlflow_model.signature
    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        r_example = _read_example(mlflow_model, model_path_secondary).copy(deep=False)
        np.testing.assert_array_equal(r_example, example)


@pytest.mark.parametrize("use_signature", [True, False])
def test_signature_for_pyfunc_predict_quantiles(
    auto_arima_model, model_path, data_airline, use_signature
):
    """A signature inferred from pyfunc ``predict_quantiles`` output is stored unchanged."""
    model_path_primary = model_path.joinpath("primary")
    model_path_secondary = model_path.joinpath("secondary")
    flavor.save_model(sktime_model=auto_arima_model, path=model_path_primary)
    loaded_pyfunc = flavor.pyfunc.load_model(model_uri=model_path_primary)
    predict_conf = pd.DataFrame([
        {
            "fh": FH,
            "predict_method": "predict_quantiles",
            "alpha": ALPHA,
        }
    ])
    forecast = loaded_pyfunc.predict(predict_conf)
    signature = infer_signature(data_airline, forecast) if use_signature else None
    flavor.save_model(auto_arima_model, path=model_path_secondary, signature=signature)
    mlflow_model = Model.load(model_path_secondary)
    assert signature == mlflow_model.signature


def test_load_from_remote_uri_succeeds(auto_arima_model, model_path, mock_s3_bucket):
    """A model uploaded to the (mocked) S3 bucket can be loaded via its ``s3://`` URI."""
    flavor.save_model(sktime_model=auto_arima_model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    # URIs always use "/" as separator; os.path.join would insert a backslash
    # on Windows and produce an invalid S3 URI.
    model_uri = f"{artifact_root}/{artifact_path}"
    reloaded_sktime_model = flavor.load_model(model_uri=model_uri)

    np.testing.assert_array_equal(
        auto_arima_model.predict(fh=FH),
        reloaded_sktime_model.predict(fh=FH),
    )


@pytest.mark.parametrize("should_start_run", [True, False])
@pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
def test_log_model(auto_arima_model, tmp_path, should_start_run, serialization_format):
    """``log_model`` records a loadable model with or without a pre-started run."""
    try:
        if should_start_run:
            mlflow.start_run()
        artifact_path = "sktime"
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["sktime"])
        model_info = flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(conda_env),
            serialization_format=serialization_format,
        )
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
        assert model_info.model_uri == model_uri
        reloaded_model = flavor.load_model(
            model_uri=model_uri,
        )
        # The fixture model is fitted without a forecasting horizon, so pass
        # ``fh`` explicitly instead of relying on one left behind by an earlier
        # test's predict call on the shared module-scoped fixture.
        np.testing.assert_array_equal(
            auto_arima_model.predict(fh=FH), reloaded_model.predict(fh=FH)
        )
        model_path = Path(_download_artifact_from_uri(artifact_uri=model_uri))
        model_config = Model.load(str(model_path.joinpath("MLmodel")))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
    finally:
        # Always close the run so a failure here does not leak an active run
        # into the following tests.
        mlflow.end_run()


def test_log_model_calls_register_model(auto_arima_model, tmp_path):
    """Passing ``registered_model_name`` triggers exactly one ``mlflow.register_model`` call."""
    artifact_path = "sktime"
    register_model_patch = mock.patch("mlflow.register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["sktime"])
        flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(conda_env),
            registered_model_name="SktimeModel",
        )
        # Both the active run and the patched ``mlflow.register_model`` are only
        # available inside the ``with`` block.
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
        mlflow.register_model.assert_called_once_with(
            model_uri,
            "SktimeModel",
            await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
        )


def test_log_model_no_registered_model_name(auto_arima_model, tmp_path):
    """Without ``registered_model_name`` the model is logged but never registered."""
    artifact_path = "sktime"
    register_model_patch = mock.patch("mlflow.register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["sktime"])
        flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(conda_env),
        )
        mlflow.register_model.assert_not_called()


def test_sktime_pyfunc_raises_invalid_df_input(auto_arima_model, model_path):
    """Malformed prediction configurations are rejected with ``MlflowException``."""
    flavor.save_model(sktime_model=auto_arima_model, path=model_path)
    loaded_pyfunc = flavor.pyfunc.load_model(model_uri=model_path)

    # Configuration spread over more than one row.
    with pytest.raises(MlflowException, match="The provided prediction pd.DataFrame "):
        loaded_pyfunc.predict(pd.DataFrame([{"predict_method": "predict"}, {"fh": FH}]))

    # Configuration with an unknown column.
    with pytest.raises(MlflowException, match="The provided prediction configuration "):
        loaded_pyfunc.predict(pd.DataFrame([{"invalid": True}]))

    # Configuration naming an unsupported predict method.
    with pytest.raises(MlflowException, match="Invalid `predict_method` value"):
        loaded_pyfunc.predict(pd.DataFrame([{"predict_method": "predict_proba"}]))


def test_sktime_save_model_raises_invalid_serialization_format(auto_arima_model, model_path):
    """``save_model`` rejects a serialization format other than the supported ones."""
    with pytest.raises(MlflowException, match="Unrecognized serialization format: "):
        flavor.save_model(
            sktime_model=auto_arima_model, path=model_path, serialization_format="json"
        )