/ examples / sktime / test_sktime_model_export.py
test_sktime_model_export.py
  1  import os
  2  from pathlib import Path
  3  from unittest import mock
  4  
  5  import boto3
  6  import flavor
  7  import moto
  8  import numpy as np
  9  import pandas as pd
 10  import pytest
 11  from botocore.config import Config
 12  from sktime.datasets import load_airline, load_longley
 13  from sktime.datatypes import convert
 14  from sktime.forecasting.arima import AutoARIMA
 15  from sktime.forecasting.model_selection import temporal_train_test_split
 16  from sktime.forecasting.naive import NaiveForecaster
 17  
 18  import mlflow
 19  from mlflow import pyfunc
 20  from mlflow.exceptions import MlflowException
 21  from mlflow.models import Model, infer_signature
 22  from mlflow.models.utils import _read_example
 23  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 24  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
 25  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 26  from mlflow.utils.environment import _mlflow_conda_env
 27  
# Shared prediction arguments used by the tests in this module.
# Forecasting horizon passed as ``fh=`` to the predict* methods.
FH = [1, 2, 3]
# Passed as ``coverage=`` to ``predict_interval``.
COVERAGE = [0.1, 0.5, 0.9]
# Passed as ``alpha=`` to ``predict_quantiles``.
ALPHA = [0.1, 0.5, 0.9]
# Passed as ``cov=`` to ``predict_var``.
COV = False
 32  
 33  
 34  @pytest.fixture
 35  def model_path(tmp_path):
 36      """Create a temporary path to save/log model."""
 37      return tmp_path.joinpath("model")
 38  
 39  
 40  @pytest.fixture(scope="module")
 41  def mock_s3_bucket():
 42      """Create a mock S3 bucket using moto.
 43  
 44      Returns
 45      -------
 46      string with name of mock S3 bucket
 47      """
 48      with moto.mock_s3():
 49          bucket_name = "mock-bucket"
 50          my_config = Config(region_name="us-east-1")
 51          s3_client = boto3.client("s3", config=my_config)
 52          s3_client.create_bucket(Bucket=bucket_name)
 53          yield bucket_name
 54  
 55  
 56  @pytest.fixture
 57  def sktime_custom_env(tmp_path):
 58      """Create a conda environment and returns path to conda environment yml file."""
 59      conda_env = tmp_path.joinpath("conda_env.yml")
 60      _mlflow_conda_env(conda_env, additional_pip_deps=["sktime"])
 61      return conda_env
 62  
 63  
 64  @pytest.fixture(scope="module")
 65  def data_airline():
 66      """Create sample data for univariate model without exogenous regressor."""
 67      return load_airline()
 68  
 69  
 70  @pytest.fixture(scope="module")
 71  def data_longley():
 72      """Create sample data for univariate model with exogenous regressor."""
 73      y, X = load_longley()
 74      y_train, y_test, X_train, X_test = temporal_train_test_split(y, X)
 75      return y_train, y_test, X_train, X_test
 76  
 77  
 78  @pytest.fixture(scope="module")
 79  def auto_arima_model(data_airline):
 80      """Create instance of fitted auto arima model."""
 81      return AutoARIMA(sp=12, d=0, max_p=2, max_q=2, suppress_warnings=True).fit(data_airline)
 82  
 83  
 84  @pytest.fixture(scope="module")
 85  def naive_forecaster_model_with_regressor(data_longley):
 86      """Create instance of fitted naive forecaster model."""
 87      y_train, _, X_train, _ = data_longley
 88      model = NaiveForecaster()
 89      return model.fit(y_train, X_train)
 90  
 91  
 92  @pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
 93  def test_auto_arima_model_save_and_load(auto_arima_model, model_path, serialization_format):
 94      flavor.save_model(
 95          sktime_model=auto_arima_model,
 96          path=model_path,
 97          serialization_format=serialization_format,
 98      )
 99      loaded_model = flavor.load_model(
100          model_uri=model_path,
101      )
102  
103      np.testing.assert_array_equal(auto_arima_model.predict(fh=FH), loaded_model.predict(fh=FH))
104  
105  
@pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
def test_auto_arima_model_pyfunc_output(auto_arima_model, model_path, serialization_format):
    """The pyfunc wrapper reproduces each native prediction method of the model.

    The model is saved with the parametrized serialization format, reloaded via
    ``flavor.pyfunc.load_model`` and queried with single-row configuration
    DataFrames for ``predict``, ``predict_interval``, ``predict_quantiles`` and
    ``predict_var``.
    """
    flavor.save_model(
        sktime_model=auto_arima_model, path=model_path, serialization_format=serialization_format
    )
    pyfunc_model = flavor.pyfunc.load_model(model_uri=model_path)

    # Point forecasts.
    native_point = auto_arima_model.predict(fh=FH)
    point_conf = pd.DataFrame([{"fh": FH, "predict_method": "predict"}])
    pyfunc_point = pyfunc_model.predict(point_conf)
    np.testing.assert_array_equal(native_point, pyfunc_point)

    # Prediction intervals.
    native_interval = auto_arima_model.predict_interval(fh=FH, coverage=COVERAGE)
    interval_row = {"fh": FH, "predict_method": "predict_interval", "coverage": COVERAGE}
    pyfunc_interval = pyfunc_model.predict(pd.DataFrame([interval_row]))
    np.testing.assert_array_equal(native_interval.values, pyfunc_interval.values)

    # Quantile forecasts.
    native_quantiles = auto_arima_model.predict_quantiles(fh=FH, alpha=ALPHA)
    quantiles_row = {"fh": FH, "predict_method": "predict_quantiles", "alpha": ALPHA}
    pyfunc_quantiles = pyfunc_model.predict(pd.DataFrame([quantiles_row]))
    np.testing.assert_array_equal(native_quantiles.values, pyfunc_quantiles.values)

    # Variance forecasts.
    native_var = auto_arima_model.predict_var(fh=FH, cov=COV)
    var_row = {"fh": FH, "predict_method": "predict_var", "cov": COV}
    pyfunc_var = pyfunc_model.predict(pd.DataFrame([var_row]))
    np.testing.assert_array_equal(native_var.values, pyfunc_var.values)
146  
147  
def test_naive_forecaster_model_with_regressor_pyfunc_output(
    naive_forecaster_model_with_regressor, model_path, data_longley
):
    """The pyfunc wrapper matches native predictions when exogenous data X is supplied.

    The native model receives ``X_test`` directly, while the pyfunc model gets
    it converted to a numpy array inside the configuration DataFrame.
    """
    X_test = data_longley[3]
    native_model = naive_forecaster_model_with_regressor

    flavor.save_model(sktime_model=native_model, path=model_path)
    pyfunc_model = flavor.pyfunc.load_model(model_uri=model_path)

    exog_array = convert(X_test, "pd.DataFrame", "np.ndarray")

    # Point forecasts.
    native_point = native_model.predict(fh=FH, X=X_test)
    point_row = {"fh": FH, "predict_method": "predict", "X": exog_array}
    pyfunc_point = pyfunc_model.predict(pd.DataFrame([point_row]))
    np.testing.assert_array_equal(native_point, pyfunc_point)

    # Prediction intervals.
    native_interval = native_model.predict_interval(fh=FH, coverage=COVERAGE, X=X_test)
    interval_row = {
        "fh": FH,
        "predict_method": "predict_interval",
        "coverage": COVERAGE,
        "X": exog_array,
    }
    pyfunc_interval = pyfunc_model.predict(pd.DataFrame([interval_row]))
    np.testing.assert_array_equal(native_interval.values, pyfunc_interval.values)

    # Quantile forecasts.
    native_quantiles = native_model.predict_quantiles(fh=FH, alpha=ALPHA, X=X_test)
    quantiles_row = {
        "fh": FH,
        "predict_method": "predict_quantiles",
        "alpha": ALPHA,
        "X": exog_array,
    }
    pyfunc_quantiles = pyfunc_model.predict(pd.DataFrame([quantiles_row]))
    np.testing.assert_array_equal(native_quantiles.values, pyfunc_quantiles.values)

    # Variance forecasts.
    native_var = native_model.predict_var(fh=FH, cov=COV, X=X_test)
    var_row = {
        "fh": FH,
        "predict_method": "predict_var",
        "cov": COV,
        "X": exog_array,
    }
    pyfunc_var = pyfunc_model.predict(pd.DataFrame([var_row]))
    np.testing.assert_array_equal(native_var.values, pyfunc_var.values)
202  
203  
@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_signature_and_examples_saved_correctly(
    auto_arima_model, data_airline, model_path, use_signature, use_example
):
    """Signature and input example passed to save_model are stored with the model."""
    # Note: Signature inference fails on native model predict_interval/predict_quantiles
    prediction = auto_arima_model.predict(fh=FH)

    signature = None
    if use_signature:
        signature = infer_signature(data_airline, prediction)
    example = None
    if use_example:
        example = pd.DataFrame(data_airline[0:5].copy(deep=False))

    flavor.save_model(auto_arima_model, path=model_path, signature=signature, input_example=example)
    saved_model = Model.load(model_path)

    assert signature == saved_model.signature
    if example is not None:
        stored_example = _read_example(saved_model, model_path).copy(deep=False)
        np.testing.assert_array_equal(stored_example, example)
    else:
        assert saved_model.saved_input_example_info is None
221  
222  
@pytest.mark.parametrize("use_signature", [True, False])
def test_predict_var_signature_saved_correctly(
    auto_arima_model, data_airline, model_path, use_signature
):
    """A signature inferred from ``predict_var`` output is stored with the model."""
    variance = auto_arima_model.predict_var(fh=FH)

    signature = None
    if use_signature:
        signature = infer_signature(data_airline, variance)

    flavor.save_model(auto_arima_model, path=model_path, signature=signature)
    saved_model = Model.load(model_path)
    assert signature == saved_model.signature
232  
233  
@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_signature_and_example_for_pyfunc_predict_interval(
    auto_arima_model, model_path, data_airline, use_signature, use_example
):
    """Signature inferred from pyfunc ``predict_interval`` output is saved correctly.

    A first copy of the model is saved only to obtain a pyfunc forecast; the
    signature and example under test are stored with a second copy.
    """
    primary_path = model_path / "primary"
    secondary_path = model_path / "secondary"

    flavor.save_model(sktime_model=auto_arima_model, path=primary_path)
    pyfunc_model = flavor.pyfunc.load_model(model_uri=primary_path)
    interval_row = {"fh": FH, "predict_method": "predict_interval", "coverage": COVERAGE}
    forecast = pyfunc_model.predict(pd.DataFrame([interval_row]))

    signature = None
    if use_signature:
        signature = infer_signature(data_airline, forecast)
    example = None
    if use_example:
        example = pd.DataFrame(data_airline[0:5].copy(deep=False))

    flavor.save_model(
        auto_arima_model, path=secondary_path, signature=signature, input_example=example
    )
    saved_model = Model.load(secondary_path)

    assert signature == saved_model.signature
    if example is not None:
        stored_example = _read_example(saved_model, secondary_path).copy(deep=False)
        np.testing.assert_array_equal(stored_example, example)
    else:
        assert saved_model.saved_input_example_info is None
266  
267  
@pytest.mark.parametrize("use_signature", [True, False])
def test_signature_for_pyfunc_predict_quantiles(
    auto_arima_model, model_path, data_airline, use_signature
):
    """Signature inferred from pyfunc ``predict_quantiles`` output is saved correctly.

    A first copy of the model is saved only to obtain a pyfunc forecast; the
    signature under test is stored with a second copy.
    """
    primary_path = model_path / "primary"
    secondary_path = model_path / "secondary"

    flavor.save_model(sktime_model=auto_arima_model, path=primary_path)
    pyfunc_model = flavor.pyfunc.load_model(model_uri=primary_path)
    quantiles_row = {"fh": FH, "predict_method": "predict_quantiles", "alpha": ALPHA}
    forecast = pyfunc_model.predict(pd.DataFrame([quantiles_row]))

    signature = None
    if use_signature:
        signature = infer_signature(data_airline, forecast)

    flavor.save_model(auto_arima_model, path=secondary_path, signature=signature)
    saved_model = Model.load(secondary_path)
    assert signature == saved_model.signature
288  
289  
def test_load_from_remote_uri_succeeds(auto_arima_model, model_path, mock_s3_bucket):
    """A model uploaded to the mock S3 bucket can be loaded from its s3:// URI.

    The model is saved locally, uploaded with ``S3ArtifactRepository`` and then
    reloaded through ``flavor.load_model``; predictions of the original and
    the reloaded model must be equal.
    """
    flavor.save_model(sktime_model=auto_arima_model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    # URIs always use "/" as separator. os.path.join would insert the OS path
    # separator (a backslash on Windows) and produce an invalid s3:// URI.
    model_uri = f"{artifact_root}/{artifact_path}"
    reloaded_sktime_model = flavor.load_model(model_uri=model_uri)

    np.testing.assert_array_equal(
        auto_arima_model.predict(fh=FH),
        reloaded_sktime_model.predict(fh=FH),
    )
305  
306  
@pytest.mark.parametrize("should_start_run", [True, False])
@pytest.mark.parametrize("serialization_format", ["pickle", "cloudpickle"])
def test_log_model(auto_arima_model, tmp_path, should_start_run, serialization_format):
    """A logged model has the expected runs:/ URI, reloads, and has a pyfunc flavor.

    Runs both with and without an explicitly started MLflow run; the run is
    always ended afterwards.
    """
    try:
        if should_start_run:
            mlflow.start_run()

        artifact_path = "sktime"
        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["sktime"])

        model_info = flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(env_file),
            serialization_format=serialization_format,
        )
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
        assert model_info.model_uri == model_uri

        reloaded = flavor.load_model(model_uri=model_uri)
        np.testing.assert_array_equal(auto_arima_model.predict(), reloaded.predict())

        # Inspect the MLmodel file of the downloaded artifact for the pyfunc flavor.
        local_dir = Path(_download_artifact_from_uri(artifact_uri=model_uri))
        mlmodel_file = local_dir / "MLmodel"
        model_config = Model.load(str(mlmodel_file))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
    finally:
        mlflow.end_run()
333  
334  
def test_log_model_calls_register_model(auto_arima_model, tmp_path):
    """Passing ``registered_model_name`` to log_model triggers one register_model call."""
    artifact_path = "sktime"
    # Inside the block the yielded mock is the object bound to mlflow.register_model.
    with mlflow.start_run(), mock.patch("mlflow.register_model") as register_model_mock:
        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["sktime"])
        flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(env_file),
            registered_model_name="SktimeModel",
        )
        model_uri = f"runs:/{mlflow.active_run().info.run_id}/{artifact_path}"
        register_model_mock.assert_called_once_with(
            model_uri,
            "SktimeModel",
            await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
        )
353  
354  
def test_log_model_no_registered_model_name(auto_arima_model, tmp_path):
    """Without ``registered_model_name``, log_model must not call register_model."""
    artifact_path = "sktime"
    # Inside the block the yielded mock is the object bound to mlflow.register_model.
    with mlflow.start_run(), mock.patch("mlflow.register_model") as register_model_mock:
        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["sktime"])
        flavor.log_model(
            sktime_model=auto_arima_model,
            artifact_path=artifact_path,
            conda_env=str(env_file),
        )
        register_model_mock.assert_not_called()
367  
368  
def test_sktime_pyfunc_raises_invalid_df_input(auto_arima_model, model_path):
    """Invalid configuration DataFrames make the pyfunc model raise MlflowException."""
    flavor.save_model(sktime_model=auto_arima_model, path=model_path)
    pyfunc_model = flavor.pyfunc.load_model(model_uri=model_path)

    # Each case pairs the expected error-message pattern with the rows of the
    # configuration DataFrame that should trigger it.
    invalid_cases = [
        ("The provided prediction pd.DataFrame ", [{"predict_method": "predict"}, {"fh": FH}]),
        ("The provided prediction configuration ", [{"invalid": True}]),
        ("Invalid `predict_method` value", [{"predict_method": "predict_proba"}]),
    ]
    for expected_message, rows in invalid_cases:
        with pytest.raises(MlflowException, match=expected_message):
            pyfunc_model.predict(pd.DataFrame(rows))
381  
382  
def test_sktime_save_model_raises_invalid_serialization_format(auto_arima_model, model_path):
    """save_model raises MlflowException when given the serialization format "json"."""
    save_kwargs = {
        "sktime_model": auto_arima_model,
        "path": model_path,
        "serialization_format": "json",
    }
    with pytest.raises(MlflowException, match="Unrecognized serialization format: "):
        flavor.save_model(**save_kwargs)