/ tests / pmdarima / test_pmdarima_model_export.py
test_pmdarima_model_export.py
  1  import json
  2  import os
  3  from pathlib import Path
  4  from unittest import mock
  5  
  6  import numpy as np
  7  import pandas as pd
  8  import pmdarima
  9  import pytest
 10  import yaml
 11  
 12  import mlflow.pmdarima
 13  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 14  from mlflow import pyfunc
 15  from mlflow.exceptions import MlflowException
 16  from mlflow.models import Model, ModelSignature, infer_signature
 17  from mlflow.models.utils import _read_example, load_serving_example
 18  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 19  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 20  from mlflow.types import DataType
 21  from mlflow.types.schema import ColSpec, Schema
 22  from mlflow.utils.environment import _mlflow_conda_env
 23  from mlflow.utils.model_utils import _get_flavor_configuration
 24  
 25  from tests.helper_functions import (
 26      _assert_pip_requirements,
 27      _compare_conda_env_requirements,
 28      _compare_logged_code_paths,
 29      _is_available_on_pypi,
 30      _mlflow_major_version_string,
 31      assert_register_model_called_with_local_model_path,
 32      pyfunc_serve_and_score_model,
 33  )
 34  from tests.prophet.test_prophet_model_export import DataGeneration
 35  
 36  EXTRA_PYFUNC_SERVING_TEST_ARGS = (
 37      [] if _is_available_on_pypi("pmdarima") else ["--env-manager", "local"]
 38  )
 39  
 40  
 41  @pytest.fixture
 42  def model_path(tmp_path):
 43      return tmp_path.joinpath("model")
 44  
 45  
 46  @pytest.fixture
 47  def pmdarima_custom_env(tmp_path):
 48      conda_env = tmp_path.joinpath("conda_env.yml")
 49      _mlflow_conda_env(conda_env, additional_pip_deps=["pmdarima"])
 50      return conda_env
 51  
 52  
 53  @pytest.fixture(scope="module")
 54  def test_data():
 55      data_conf = {
 56          "shift": False,
 57          "start": "2016-01-01",
 58          "size": 365 * 3,
 59          "seasonal_period": 7,
 60          "seasonal_freq": 0.1,
 61          "date_field": "date",
 62          "target_field": "orders",
 63      }
 64      raw = DataGeneration(**data_conf).create_series_df()
 65      return raw.set_index("date")
 66  
 67  
 68  @pytest.fixture(scope="module")
 69  def auto_arima_model(test_data):
 70      return pmdarima.auto_arima(
 71          test_data["orders"], max_d=1, suppress_warnings=True, error_action="raise"
 72      )
 73  
 74  
 75  @pytest.fixture(scope="module")
 76  def auto_arima_object_model(test_data):
 77      model = pmdarima.arima.ARIMA(order=(2, 1, 3), maxiter=25)
 78      return model.fit(test_data["orders"])
 79  
 80  
 81  def test_pmdarima_auto_arima_save_and_load(auto_arima_model, model_path):
 82      mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
 83  
 84      loaded_model = mlflow.pmdarima.load_model(model_uri=model_path)
 85  
 86      np.testing.assert_array_equal(auto_arima_model.predict(10), loaded_model.predict(10))
 87  
 88  
 89  def test_load_model_disallows_pickle_deserialization(auto_arima_model, model_path, monkeypatch):
 90      mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
 91  
 92      monkeypatch.setenv("MLFLOW_ALLOW_PICKLE_DESERIALIZATION", "false")
 93      with pytest.raises(MlflowException, match="MLFLOW_ALLOW_PICKLE_DESERIALIZATION"):
 94          mlflow.pmdarima.load_model(model_uri=model_path)
 95  
 96  
 97  def test_pmdarima_arima_object_save_and_load(auto_arima_object_model, model_path):
 98      mlflow.pmdarima.save_model(pmdarima_model=auto_arima_object_model, path=model_path)
 99  
100      loaded_model = mlflow.pmdarima.load_model(model_uri=model_path)
101  
102      np.testing.assert_array_equal(auto_arima_object_model.predict(30), loaded_model.predict(30))
103  
104  
def test_pmdarima_autoarima_pyfunc_save_and_load(auto_arima_model, model_path):
    """The pyfunc forecast columns match the native forecast and its interval bounds."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_path)

    native = auto_arima_model.predict(n_periods=60, return_conf_int=True, alpha=0.1)
    point_forecast = native[0]
    intervals = native[1]

    request = pd.DataFrame({"n_periods": 60, "return_conf_int": True, "alpha": 0.1}, index=[0])
    served = pyfunc_model.predict(request)

    np.testing.assert_array_equal(point_forecast, served["yhat"])
    # Transpose the per-period interval pairs into one lower and one upper sequence.
    lower_bounds, upper_bounds = zip(*intervals)
    np.testing.assert_array_equal(lower_bounds, served["yhat_lower"])
    np.testing.assert_array_equal(upper_bounds, served["yhat_upper"])
118  
119  
@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_pmdarima_signature_and_examples_saved_correctly(
    auto_arima_model, model_path, use_signature, use_example
):
    """The saved MLmodel carries the signature and input example passed to ``save_model``.

    A signature is supplied whenever either flag is set; the input example is supplied
    only when ``use_example`` is set.
    """
    # NB: Signature inference will only work on the first element of the tuple return
    prediction = auto_arima_model.predict(n_periods=20, return_conf_int=True, alpha=0.05)
    model_input = pd.DataFrame({"n_periods": [30]})

    signature = None
    if use_signature or use_example:
        signature = infer_signature(model_input, prediction[0])
    example = model_input if use_example else None

    mlflow.pmdarima.save_model(
        auto_arima_model, path=model_path, signature=signature, input_example=example
    )
    saved_model = Model.load(model_path)

    if signature is None and example is None:
        assert saved_model.signature is None
    else:
        assert saved_model.signature == signature

    if example is None:
        assert saved_model.saved_input_example_info is None
    else:
        restored_example = _read_example(saved_model, model_path).copy(deep=False)
        np.testing.assert_array_equal(restored_example, example)
143  
144  
@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_pmdarima_signature_and_example_for_confidence_interval_mode(
    auto_arima_model, model_path, use_signature, use_example
):
    """Signature and example are persisted for a request that asks for confidence intervals.

    The model is first saved and served through pyfunc to obtain a forecast from which
    the reference signature is inferred; it is then saved a second time with the
    optional signature and example, and that second save is what gets inspected.
    """
    primary_path = model_path / "primary"
    secondary_path = model_path / "secondary"

    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=primary_path)
    pyfunc_model = mlflow.pyfunc.load_model(model_uri=primary_path)
    request = pd.DataFrame([{"n_periods": 10, "return_conf_int": True, "alpha": 0.2}])
    forecast = pyfunc_model.predict(request)
    reference_signature = infer_signature(request, forecast)

    signature = reference_signature if use_signature else None
    example = request.copy(deep=False) if use_example else None
    mlflow.pmdarima.save_model(
        auto_arima_model, path=secondary_path, signature=signature, input_example=example
    )
    saved_model = Model.load(secondary_path)

    if signature is None and example is None:
        assert saved_model.signature is None
    else:
        # Compared against the reference even when only the example was supplied.
        assert saved_model.signature == reference_signature

    if example is None:
        assert saved_model.saved_input_example_info is None
    else:
        restored_example = _read_example(saved_model, secondary_path).copy(deep=False)
        np.testing.assert_array_equal(restored_example, example)
172  
173  
def test_pmdarima_load_from_remote_uri_succeeds(
    auto_arima_object_model, model_path, mock_s3_bucket
):
    """A model uploaded to the mocked S3 bucket loads from its ``s3://`` URI."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_object_model, path=model_path)

    bucket_root = f"s3://{mock_s3_bucket}"
    remote_dir = "model"
    S3ArtifactRepository(bucket_root).log_artifacts(model_path, artifact_path=remote_dir)

    # NB: cloudpathlib would need to be used here to handle object store uri
    remote_uri = os.path.join(bucket_root, remote_dir)
    reloaded = mlflow.pmdarima.load_model(model_uri=remote_uri)

    expected = auto_arima_object_model.predict(30)
    actual = reloaded.predict(30)
    np.testing.assert_array_equal(expected, actual)
191  
192  
@pytest.mark.parametrize("should_start_run", [True, False])
def test_pmdarima_log_model(auto_arima_model, tmp_path, should_start_run):
    """Log a model with a custom conda env, with and without an explicitly started run.

    Checks that the logged model reloads with an identical 20-period forecast and that
    the conda file referenced by the pyfunc flavor exists among the downloaded artifacts.
    """
    try:
        if should_start_run:
            mlflow.start_run()

        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["pmdarima"])
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="pmdarima", conda_env=str(env_file)
        )

        reloaded = mlflow.pmdarima.load_model(model_uri=model_info.model_uri)
        np.testing.assert_array_equal(auto_arima_model.predict(20), reloaded.predict(20))

        local_dir = Path(_download_artifact_from_uri(artifact_uri=model_info.model_uri))
        model_config = Model.load(str(local_dir / "MLmodel"))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
        assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
        conda_rel_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
        assert (local_dir / conda_rel_path).exists()
    finally:
        # Runs even when an assertion above fails.
        mlflow.end_run()
216  
217  
def test_pmdarima_log_model_calls_register_model(auto_arima_object_model, tmp_path):
    """Passing ``registered_model_name`` to ``log_model`` triggers the patched registration."""
    registry_target = "mlflow.tracking._model_registry.fluent._register_model"
    registered_name = "PmdarimaModel"

    with mlflow.start_run(), mock.patch(registry_target) as register_mock:
        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["pmdarima"])
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model,
            name="pmdarima",
            conda_env=str(env_file),
            registered_model_name=registered_name,
        )
        assert_register_model_called_with_local_model_path(
            register_mock,
            model_info.model_uri,
            registered_name,
        )
235  
236  
def test_pmdarima_log_model_no_registered_model_name(auto_arima_model, tmp_path):
    """Without ``registered_model_name`` the patched registration function is never called."""
    registry_target = "mlflow.tracking._model_registry.fluent._register_model"

    with mlflow.start_run(), mock.patch(registry_target) as register_mock:
        env_file = tmp_path / "conda_env.yaml"
        _mlflow_conda_env(env_file, additional_pip_deps=["pmdarima"])
        mlflow.pmdarima.log_model(auto_arima_model, name="pmdarima", conda_env=str(env_file))
        register_mock.assert_not_called()
245  
246  
def test_pmdarima_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    auto_arima_object_model, model_path, pmdarima_custom_env
):
    """The conda env given to ``save_model`` is stored as a separate, equal-content file."""
    mlflow.pmdarima.save_model(
        pmdarima_model=auto_arima_object_model, path=model_path, conda_env=str(pmdarima_custom_env)
    )

    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_file = model_path / flavor_conf[pyfunc.ENV]["conda"]
    assert saved_env_file.exists()
    # The saved env must be a distinct file, not the very file that was passed in.
    assert not pmdarima_custom_env.samefile(saved_env_file)

    expected_env = yaml.safe_load(pmdarima_custom_env.read_bytes())
    saved_env = yaml.safe_load(saved_env_file.read_bytes())
    assert saved_env == expected_env
261  
262  
def test_pmdarima_model_save_persists_requirements_in_mlflow_model_directory(
    auto_arima_model, model_path, pmdarima_custom_env
):
    """The saved ``requirements.txt`` is compared against the conda env passed to ``save_model``."""
    mlflow.pmdarima.save_model(
        pmdarima_model=auto_arima_model, path=model_path, conda_env=str(pmdarima_custom_env)
    )
    requirements_file = model_path / "requirements.txt"
    _compare_conda_env_requirements(pmdarima_custom_env, str(requirements_file))
271  
272  
def test_pmdarima_log_model_with_pip_requirements(auto_arima_object_model, tmp_path):
    """Check logged requirements for three forms of ``pip_requirements``.

    The forms are: a path to a requirements file, a list containing a ``-r`` reference,
    and a list containing a ``-c`` constraints reference.
    """
    mlflow_req = _mlflow_major_version_string()
    req_file = tmp_path / "requirements.txt"
    req_file.write_text("a")

    # Path to a requirements file
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [mlflow_req, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(model_info.model_uri, [mlflow_req, "a", "b"], strict=True)

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        expected_requirements = [mlflow_req, "b", "-c constraints.txt"]
        expected_constraints = ["a"]
        _assert_pip_requirements(
            model_info.model_uri,
            expected_requirements,
            expected_constraints,
            strict=True,
        )
303  
304  
def test_pmdarima_log_model_with_extra_pip_requirements(auto_arima_model, tmp_path):
    """Check logged requirements for three forms of ``extra_pip_requirements``.

    Each expectation includes the flavor's default pip requirements alongside the
    extras: a requirements file path, a list with ``-r``, and a list with ``-c``.
    """
    mlflow_req = _mlflow_major_version_string()
    base_reqs = [mlflow_req, *mlflow.pmdarima.get_default_pip_requirements()]

    # Path to a requirements file
    req_file = tmp_path / "requirements.txt"
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [*base_reqs, "a"])

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(model_info.model_uri, [*base_reqs, "a", "b"])

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_uri=model_info.model_uri,
            requirements=[*base_reqs, "b", "-c constraints.txt"],
            constraints=["a"],
            strict=False,
        )
340  
341  
def test_pmdarima_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
    auto_arima_model, model_path
):
    """Saving with no env arguments is checked against the flavor's default pip requirements."""
    mlflow.pmdarima.save_model(auto_arima_model, model_path)
    default_reqs = mlflow.pmdarima.get_default_pip_requirements()
    _assert_pip_requirements(model_path, default_reqs)
347  
348  
def test_pmdarima_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
    auto_arima_object_model,
):
    """Logging with no env arguments is checked against the flavor's default pip requirements."""
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(auto_arima_object_model, name="model")

    default_reqs = mlflow.pmdarima.get_default_pip_requirements()
    _assert_pip_requirements(model_info.model_uri, default_reqs)
356  
357  
def test_pmdarima_pyfunc_serve_and_score(auto_arima_model):
    """Predictions returned by the scoring helper match a local 30-period forecast."""
    input_example = pd.DataFrame({"n_periods": 30}, index=[0])
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model,
            name="model",
            input_example=input_example,
        )
    expected = auto_arima_model.predict(30)

    payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )

    # Decode the JSON body and flatten its "predictions" entry into a 1-D array.
    body = json.loads(response.content.decode("utf-8"))
    served = pd.DataFrame(data=body["predictions"]).to_numpy().flatten()
    np.testing.assert_array_almost_equal(served, expected)
382  
383  
def test_pmdarima_pyfunc_raises_invalid_df_input(auto_arima_model, model_path):
    """Malformed prediction frames are rejected with the expected ``MlflowException`` text."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_path)

    # (expected message fragment, rows of the offending request frame)
    invalid_requests = [
        ("The provided prediction pd.DataFrame ", [{"n_periods": 60}, {"n_periods": 100}]),
        ("The provided prediction configuration ", [{"invalid": True}]),
        ("The provided `n_periods` value ", [{"n_periods": "60"}]),
    ]
    for message, rows in invalid_requests:
        with pytest.raises(MlflowException, match=message):
            pyfunc_model.predict(pd.DataFrame(rows))
396  
397  
def test_pmdarima_pyfunc_return_correct_structure(auto_arima_model, model_path):
    """The pyfunc forecast is a 10-row frame with 1 column, or 3 with confidence intervals."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_path)

    # (return_conf_int flag, expected number of forecast columns)
    for with_intervals, expected_columns in ((False, 1), (True, 3)):
        request = pd.DataFrame([{"n_periods": 10, "return_conf_int": with_intervals}])
        forecast = pyfunc_model.predict(request)

        assert isinstance(forecast, pd.DataFrame)
        assert len(forecast) == 10
        assert len(forecast.columns.values) == expected_columns
415  
416  
def test_log_model_with_code_paths(auto_arima_model):
    """Logging with ``code_paths`` records this file, and loading calls the patched helper."""
    patch_target = "mlflow.pmdarima._add_code_from_conf_to_system_path"
    with (
        mlflow.start_run(),
        mock.patch(patch_target) as add_code_mock,
    ):
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.pmdarima.FLAVOR_NAME)

        mlflow.pmdarima.load_model(model_info.model_uri)
        add_code_mock.assert_called()
429  
430  
def test_virtualenv_subfield_points_to_correct_path(auto_arima_model, model_path):
    """The pyfunc flavor's ``virtualenv`` entry names an existing regular file."""
    mlflow.pmdarima.save_model(auto_arima_model, path=model_path)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)

    virtualenv_file = Path(model_path, flavor_conf[pyfunc.ENV]["virtualenv"])
    assert virtualenv_file.exists()
    assert virtualenv_file.is_file()
437  
438  
def test_model_save_load_with_metadata(auto_arima_model, model_path):
    """Metadata given to ``save_model`` is readable from the pyfunc-loaded model."""
    custom_metadata = {"metadata_key": "metadata_value"}
    mlflow.pmdarima.save_model(auto_arima_model, path=model_path, metadata=custom_metadata)

    pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert pyfunc_model.metadata.metadata["metadata_key"] == "metadata_value"
446  
447  
def test_model_log_with_metadata(auto_arima_model):
    """Metadata given to ``log_model`` is readable from the pyfunc-loaded model."""
    custom_metadata = {"metadata_key": "metadata_value"}

    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model,
            name="model",
            metadata=custom_metadata,
        )

    pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert pyfunc_model.metadata.metadata["metadata_key"] == "metadata_value"
460  
461  
def test_model_log_with_signature_inference(auto_arima_model):
    """Logging with only an input example stores the expected input and output schemas."""
    example = pd.DataFrame({"n_periods": 60, "return_conf_int": True, "alpha": 0.1}, index=[0])

    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", input_example=example
        )

    expected_inputs = Schema([
        ColSpec(name="n_periods", type=DataType.long),
        ColSpec(name="return_conf_int", type=DataType.boolean),
        ColSpec(name="alpha", type=DataType.double),
    ])
    expected_outputs = Schema([
        ColSpec(name="yhat", type=DataType.double),
        ColSpec(name="yhat_lower", type=DataType.double),
        ColSpec(name="yhat_upper", type=DataType.double),
    ])
    expected_signature = ModelSignature(inputs=expected_inputs, outputs=expected_outputs)

    logged_model = Model.load(model_info.model_uri)
    assert logged_model.signature == expected_signature