# tests/statsmodels/test_statsmodels_model_export.py
  1  import json
  2  import os
  3  from pathlib import Path
  4  from unittest import mock
  5  
  6  import numpy as np
  7  import pandas as pd
  8  import pytest
  9  import yaml
 10  
 11  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 12  import mlflow.statsmodels
 13  from mlflow import pyfunc
 14  from mlflow.models import Model
 15  from mlflow.models.utils import _read_example, load_serving_example
 16  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 17  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 18  from mlflow.utils.environment import _mlflow_conda_env
 19  from mlflow.utils.file_utils import TempDir
 20  from mlflow.utils.model_utils import _get_flavor_configuration
 21  
 22  from tests.helper_functions import (
 23      _assert_pip_requirements,
 24      _compare_conda_env_requirements,
 25      _compare_logged_code_paths,
 26      _is_available_on_pypi,
 27      _mlflow_major_version_string,
 28      assert_register_model_called_with_local_model_path,
 29      pyfunc_serve_and_score_model,
 30  )
 31  from tests.statsmodels.model_fixtures import (
 32      arma_model,
 33      gee_model,
 34      glm_model,
 35      gls_model,
 36      glsar_model,
 37      ols_model,
 38      ols_model_signature,
 39      recursivels_model,
 40      rolling_ols_model,
 41      rolling_wls_model,
 42      wls_model,
 43  )
 44  
 45  EXTRA_PYFUNC_SERVING_TEST_ARGS = (
 46      [] if _is_available_on_pypi("statsmodels") else ["--env-manager", "local"]
 47  )
 48  
 49  # The code in this file has been adapted from the test cases of the lightgbm flavor.
 50  
 51  
 52  def _get_dates_from_df(df):
 53      start_date = df["start"][0]
 54      end_date = df["end"][0]
 55      return start_date, end_date
 56  
 57  
 58  @pytest.fixture
 59  def model_path(tmp_path, subdir="model"):
 60      return os.path.join(tmp_path, subdir)
 61  
 62  
 63  @pytest.fixture
 64  def statsmodels_custom_env(tmp_path):
 65      conda_env = os.path.join(tmp_path, "conda_env.yml")
 66      _mlflow_conda_env(conda_env, additional_pip_deps=["pytest", "statsmodels"])
 67      return conda_env
 68  
 69  
 70  def _test_models_list(tmp_path, func_to_apply):
 71      from statsmodels.tsa.base.tsa_model import TimeSeriesModel
 72  
 73      fixtures = [
 74          ols_model,
 75          arma_model,
 76          glsar_model,
 77          gee_model,
 78          glm_model,
 79          gls_model,
 80          recursivels_model,
 81          rolling_ols_model,
 82          rolling_wls_model,
 83          wls_model,
 84      ]
 85  
 86      for algorithm in fixtures:
 87          name = algorithm.__name__
 88          path = os.path.join(tmp_path, name)
 89          model = algorithm()
 90          if isinstance(model.alg, TimeSeriesModel):
 91              start_date, end_date = _get_dates_from_df(model.inference_dataframe)
 92              func_to_apply(model, path, start_date, end_date)
 93          else:
 94              func_to_apply(model, path, model.inference_dataframe)
 95  
 96  
 97  def _test_model_save_load(statsmodels_model, model_path, *predict_args):
 98      mlflow.statsmodels.save_model(statsmodels_model=statsmodels_model.model, path=model_path)
 99      reloaded_model = mlflow.statsmodels.load_model(model_uri=model_path)
100      reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)
101  
102      if hasattr(statsmodels_model.model, "predict"):
103          np.testing.assert_array_almost_equal(
104              statsmodels_model.model.predict(*predict_args),
105              reloaded_model.predict(*predict_args),
106          )
107  
108          np.testing.assert_array_almost_equal(
109              reloaded_model.predict(*predict_args),
110              reloaded_pyfunc.predict(statsmodels_model.inference_dataframe),
111          )
112  
113  
def _test_model_log(statsmodels_model, model_path, *predict_args):
    """Log a statsmodels model, reload it, and check the logged pyfunc env config.

    Args:
        statsmodels_model: Fixture result exposing the fitted model as ``.model``.
        model_path: Unused. It is accepted only so this function matches the
            ``func_to_apply(model, path, *predict_args)`` call made by
            ``_test_models_list``; the model is logged via ``log_model`` rather
            than saved to this path.
        *predict_args: Positional arguments forwarded to ``predict``.
    """
    model = statsmodels_model.model
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        try:
            artifact_path = "model"
            conda_env = os.path.join(tmp.path(), "conda_env.yaml")
            _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])

            model_info = mlflow.statsmodels.log_model(
                model, name=artifact_path, conda_env=conda_env
            )
            reloaded_model = mlflow.statsmodels.load_model(model_uri=model_info.model_uri)
            if hasattr(model, "predict"):
                np.testing.assert_array_almost_equal(
                    model.predict(*predict_args), reloaded_model.predict(*predict_args)
                )

            # Deliberately not named ``model_path``: rebinding the parameter would
            # hide the fact that the argument is unused.
            local_model_dir = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
            model_config = Model.load(os.path.join(local_model_dir, "MLmodel"))
            assert pyfunc.FLAVOR_NAME in model_config.flavors
            assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
            env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
            assert os.path.exists(os.path.join(local_model_dir, env_path))
        finally:
            # No run is started explicitly here; presumably log_model may start one
            # implicitly, so end any active run to keep later iterations isolated.
            mlflow.end_run()
139  
140  
def test_models_save_load(tmp_path):
    """Run the save/load round-trip check against every fixture model."""
    _test_models_list(tmp_path, func_to_apply=_test_model_save_load)
143  
144  
def test_models_log(tmp_path):
    """Run the log/reload check against every fixture model."""
    _test_models_list(tmp_path, func_to_apply=_test_model_log)
147  
148  
def test_signature_and_examples_are_saved_correctly():
    """Check signature/input-example persistence for every combination of the two.

    A signature is expected whenever either one is supplied (it is inferred from
    the example when not given explicitly); without both it must be absent.
    """
    model, _, features = ols_model()
    expected_signature = ols_model_signature()
    sample_rows = features[0:3, :]

    combos = [
        (sig, ex) for sig in (None, expected_signature) for ex in (None, sample_rows)
    ]
    for signature, example in combos:
        with TempDir() as tmp:
            path = tmp.path("model")
            mlflow.statsmodels.save_model(
                model, path=path, signature=signature, input_example=example
            )
            mlflow_model = Model.load(path)

            nothing_supplied = signature is None and example is None
            if nothing_supplied:
                assert mlflow_model.signature is None
            else:
                assert mlflow_model.signature == expected_signature

            if example is not None:
                np.testing.assert_array_equal(_read_example(mlflow_model, path), example)
            else:
                assert mlflow_model.saved_input_example_info is None
170  
171  
def test_model_load_from_remote_uri_succeeds(model_path, mock_s3_bucket):
    """A model uploaded to (mocked) S3 loads from its s3:// URI with equal predictions."""
    model, _, inference_dataframe = arma_model()
    mlflow.statsmodels.save_model(statsmodels_model=model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    S3ArtifactRepository(artifact_root).log_artifacts(model_path, artifact_path=artifact_path)

    reloaded_model = mlflow.statsmodels.load_model(model_uri=f"{artifact_root}/{artifact_path}")
    start_date, end_date = _get_dates_from_df(inference_dataframe)
    window = {"start": start_date, "end": end_date}
    np.testing.assert_array_almost_equal(
        model.predict(**window),
        reloaded_model.predict(**window),
    )
188  
189  
def test_log_model_calls_register_model():
    """Passing ``registered_model_name`` to log_model triggers model registration."""
    # Adapted from lightgbm tests
    ols = ols_model()
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.tracking._model_registry.fluent._register_model") as register_mock,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])
        model_info = mlflow.statsmodels.log_model(
            ols.model,
            name=artifact_path,
            conda_env=conda_env,
            registered_model_name="OLSModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=register_mock,
            model_uri=model_info.model_uri,
            registered_model_name="OLSModel1",
        )
209  
210  
def test_log_model_no_registered_model_name():
    """Without ``registered_model_name``, log_model must not register anything."""
    ols = ols_model()
    with (
        mlflow.start_run(),
        mock.patch("mlflow.tracking._model_registry.fluent._register_model") as register_mock,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])
        mlflow.statsmodels.log_model(ols.model, name="model", conda_env=conda_env)
        register_mock.assert_not_called()
220  
221  
def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    model_path, statsmodels_custom_env
):
    """save_model copies a given conda env file into the model dir, content unchanged."""
    fitted = ols_model().model
    mlflow.statsmodels.save_model(
        statsmodels_model=fitted, path=model_path, conda_env=statsmodels_custom_env
    )

    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_path = os.path.join(model_path, flavor_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_env_path)
    # The env must be a copy inside the model directory, not the source file itself.
    assert saved_env_path != statsmodels_custom_env

    parsed_envs = []
    for env_file in (statsmodels_custom_env, saved_env_path):
        with open(env_file) as f:
            parsed_envs.append(yaml.safe_load(f))
    source_env, saved_env = parsed_envs
    assert saved_env == source_env
240  
241  
def test_model_save_persists_requirements_in_mlflow_model_directory(
    model_path, statsmodels_custom_env
):
    """requirements.txt written by save_model matches the custom conda env's pip deps."""
    fitted = ols_model().model
    mlflow.statsmodels.save_model(
        statsmodels_model=fitted, path=model_path, conda_env=statsmodels_custom_env
    )
    requirements_file = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(statsmodels_custom_env, requirements_file)
252  
253  
def test_log_model_with_pip_requirements(tmp_path):
    """Log with ``pip_requirements`` given in three forms and check what is recorded.

    The forms are a requirements-file path, a list with a ``-r`` include, and a
    list with a ``-c`` constraints file. All are checked with ``strict=True``.
    """
    expected_mlflow_version = _mlflow_major_version_string()
    ols = ols_model()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # Each case: (pip_requirements argument, positional expectations passed on to
    # _assert_pip_requirements after the model URI).
    cases = [
        # Path to a requirements file
        (str(req_file), ([expected_mlflow_version, "a"],)),
        # List of requirements
        ([f"-r {req_file}", "b"], ([expected_mlflow_version, "a", "b"],)),
        # Constraints file
        (
            [f"-c {req_file}", "b"],
            ([expected_mlflow_version, "b", "-c constraints.txt"], ["a"]),
        ),
    ]
    for pip_requirements, expectations in cases:
        with mlflow.start_run():
            model_info = mlflow.statsmodels.log_model(
                ols.model, name="model", pip_requirements=pip_requirements
            )
            _assert_pip_requirements(model_info.model_uri, *expectations, strict=True)
286  
287  
def test_log_model_with_extra_pip_requirements(tmp_path):
    """Log with ``extra_pip_requirements`` in three forms; defaults must be kept.

    Each form's requirements are expected in addition to the flavor's default
    pip requirements.
    """
    expected_mlflow_version = _mlflow_major_version_string()
    ols = ols_model()
    default_reqs = mlflow.statsmodels.get_default_pip_requirements()

    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    baseline = [expected_mlflow_version, *default_reqs]
    # Each case: (extra_pip_requirements argument, positional expectations passed
    # on to _assert_pip_requirements after the model URI).
    cases = [
        # Path to a requirements file
        (str(req_file), ([*baseline, "a"],)),
        # List of requirements
        ([f"-r {req_file}", "b"], ([*baseline, "a", "b"],)),
        # Constraints file
        ([f"-c {req_file}", "b"], ([*baseline, "b", "-c constraints.txt"], ["a"])),
    ]
    for extra_requirements, expectations in cases:
        with mlflow.start_run():
            model_info = mlflow.statsmodels.log_model(
                ols.model, name="model", extra_pip_requirements=extra_requirements
            )
            _assert_pip_requirements(model_info.model_uri, *expectations)
323  
324  
def test_model_save_accepts_conda_env_as_dict(model_path):
    """save_model accepts an in-memory conda env dict and writes it out unchanged."""
    ols = ols_model()
    # Shallow copy of the default env; "pytest" is appended to its dependency list.
    conda_env = {**mlflow.statsmodels.get_default_conda_env()}
    conda_env["dependencies"].append("pytest")
    mlflow.statsmodels.save_model(statsmodels_model=ols.model, path=model_path, conda_env=conda_env)

    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = os.path.join(model_path, flavor_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(env_file)

    with open(env_file) as f:
        written_env = yaml.safe_load(f)
    assert written_env == conda_env
338  
339  
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(statsmodels_custom_env):
    """log_model copies a given conda env file into the logged model, content unchanged."""

    def _load_yaml(path):
        with open(path) as f:
            return yaml.safe_load(f)

    ols = ols_model()
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", conda_env=statsmodels_custom_env
        )

    local_dir = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    flavor_conf = _get_flavor_configuration(model_path=local_dir, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_path = os.path.join(local_dir, flavor_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_env_path)
    assert saved_env_path != statsmodels_custom_env

    source_env = _load_yaml(statsmodels_custom_env)
    assert _load_yaml(saved_env_path) == source_env
360  
361  
def test_model_log_persists_requirements_in_mlflow_model_directory(statsmodels_custom_env):
    """requirements.txt of a logged model matches the custom conda env's pip deps."""
    fitted = ols_model().model
    with mlflow.start_run():
        logged = mlflow.statsmodels.log_model(
            fitted,
            name="model",
            conda_env=statsmodels_custom_env,
        )

    local_dir = _download_artifact_from_uri(artifact_uri=logged.model_uri)
    _compare_conda_env_requirements(
        statsmodels_custom_env, os.path.join(local_dir, "requirements.txt")
    )
375  
376  
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    model_path,
):
    """Without a conda env, save_model records the flavor's default pip requirements."""
    mlflow.statsmodels.save_model(statsmodels_model=ols_model().model, path=model_path)
    default_reqs = mlflow.statsmodels.get_default_pip_requirements()
    _assert_pip_requirements(model_path, default_reqs)
383  
384  
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies():
    """Without a conda env, log_model records the flavor's default pip requirements."""
    fitted = ols_model().model
    with mlflow.start_run():
        logged = mlflow.statsmodels.log_model(fitted, name="model")
    default_reqs = mlflow.statsmodels.get_default_pip_requirements()
    _assert_pip_requirements(logged.model_uri, default_reqs)
393  
394  
def test_pyfunc_serve_and_score():
    """Scores from the pyfunc scoring server match the in-process model's predictions."""
    model, _, inference_dataframe = ols_model()
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            model, name="model", input_example=inference_dataframe
        )

    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=load_serving_example(model_info.model_uri),
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    body = json.loads(resp.content.decode("utf-8"))
    scores = pd.DataFrame(data=body["predictions"]).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model.predict(inference_dataframe))
414  
415  
def test_log_model_with_code_paths():
    """``code_paths`` are logged with the model, and loading it invokes
    ``_add_code_from_conf_to_system_path``.
    """
    ols = ols_model()
    add_code_patch = mock.patch("mlflow.statsmodels._add_code_from_conf_to_system_path")
    with mlflow.start_run(), add_code_patch as add_mock:
        model_info = mlflow.statsmodels.log_model(ols.model, name="model", code_paths=[__file__])
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.statsmodels.FLAVOR_NAME)
        mlflow.statsmodels.load_model(model_info.model_uri)
        add_mock.assert_called()
429  
430  
def test_virtualenv_subfield_points_to_correct_path(model_path):
    """The pyfunc ``virtualenv`` entry names an existing regular file in the model dir."""
    mlflow.statsmodels.save_model(ols_model().model, path=model_path)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_file = Path(model_path) / flavor_conf[pyfunc.ENV]["virtualenv"]
    assert python_env_file.exists()
    assert python_env_file.is_file()
438  
439  
def test_model_save_load_with_metadata(model_path):
    """Metadata passed to save_model is exposed on the reloaded pyfunc model."""
    metadata = {"metadata_key": "metadata_value"}
    mlflow.statsmodels.save_model(ols_model().model, path=model_path, metadata=metadata)

    loaded = mlflow.pyfunc.load_model(model_uri=model_path)
    assert loaded.metadata.metadata["metadata_key"] == "metadata_value"
448  
449  
def test_model_log_with_metadata():
    """Metadata passed to log_model is exposed on the reloaded pyfunc model."""
    fitted = ols_model().model
    metadata = {"metadata_key": "metadata_value"}

    with mlflow.start_run():
        logged = mlflow.statsmodels.log_model(fitted, name="model", metadata=metadata)

    loaded = mlflow.pyfunc.load_model(model_uri=logged.model_uri)
    assert loaded.metadata.metadata["metadata_key"] == "metadata_value"
461  
462  
def test_model_log_with_signature_inference():
    """A signature inferred from the input example equals the fixture's expected signature."""
    model, _, features = ols_model()
    sample_rows = features[0:3, :]

    with mlflow.start_run():
        logged = mlflow.statsmodels.log_model(model, name="model", input_example=sample_rows)

    assert Model.load(logged.model_uri).signature == ols_model_signature()