# test_statsmodels_model_export.py
import json
import os
from pathlib import Path
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import yaml

import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.statsmodels
from mlflow import pyfunc
from mlflow.models import Model
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _is_available_on_pypi,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)
from tests.statsmodels.model_fixtures import (
    arma_model,
    gee_model,
    glm_model,
    gls_model,
    glsar_model,
    ols_model,
    ols_model_signature,
    recursivels_model,
    rolling_ols_model,
    rolling_wls_model,
    wls_model,
)

# Extra CLI arguments for the pyfunc scoring server used in serving tests. When
# `_is_available_on_pypi("statsmodels")` reports False, the server is told to use
# `--env-manager local` (the current environment) instead of building a new one.
if _is_available_on_pypi("statsmodels"):
    EXTRA_PYFUNC_SERVING_TEST_ARGS = []
else:
    EXTRA_PYFUNC_SERVING_TEST_ARGS = ["--env-manager", "local"]

# The code in this file has been adapted from the test cases of the lightgbm flavor.
def _get_dates_from_df(df):
    """Return the ``start`` and ``end`` values from the first row of ``df``.

    Time-series fixtures describe their prediction window with an inference DataFrame
    holding ``start`` and ``end`` columns; the returned pair is passed to
    ``predict(start, end)``.
    """
    # Positional access: ``df["start"][0]`` would be a *label* lookup and only hits the
    # first row when the index happens to contain the label 0.
    start_date = df["start"].iloc[0]
    end_date = df["end"].iloc[0]
    return start_date, end_date


@pytest.fixture
def model_path(tmp_path, subdir="model"):
    """Return a path under ``tmp_path`` (not created here) where a test can save a model."""
    return os.path.join(tmp_path, subdir)


@pytest.fixture
def statsmodels_custom_env(tmp_path):
    """Write a conda env file with extra pip deps ``pytest`` and ``statsmodels``; return its path."""
    conda_env = os.path.join(tmp_path, "conda_env.yml")
    _mlflow_conda_env(conda_env, additional_pip_deps=["pytest", "statsmodels"])
    return conda_env


def _test_models_list(tmp_path, func_to_apply):
    """Apply ``func_to_apply`` to every statsmodels fixture model.

    Each fixture is built fresh and given its own output path under ``tmp_path``
    (named after the fixture function). Time-series models are called with the
    ``(start, end)`` window taken from their inference DataFrame; every other model is
    called with the inference DataFrame itself.

    Args:
        tmp_path: Base directory for per-fixture output paths.
        func_to_apply: Callable ``(fixture_result, path, *predict_args)``.
    """
    from statsmodels.tsa.base.tsa_model import TimeSeriesModel

    fixtures = [
        ols_model,
        arma_model,
        glsar_model,
        gee_model,
        glm_model,
        gls_model,
        recursivels_model,
        rolling_ols_model,
        rolling_wls_model,
        wls_model,
    ]

    for fixture in fixtures:
        path = os.path.join(tmp_path, fixture.__name__)
        fixture_result = fixture()
        if isinstance(fixture_result.alg, TimeSeriesModel):
            start_date, end_date = _get_dates_from_df(fixture_result.inference_dataframe)
            func_to_apply(fixture_result, path, start_date, end_date)
        else:
            func_to_apply(fixture_result, path, fixture_result.inference_dataframe)


def _test_model_save_load(statsmodels_model, model_path, *predict_args):
    """Save a fixture model to ``model_path`` and check reloaded predictions match.

    The original model's predictions are compared with the native statsmodels reload,
    and the native reload is compared with the pyfunc reload. ``predict_args`` are
    forwarded to the native ``predict``; the pyfunc model receives the fixture's
    inference DataFrame instead.
    """
    mlflow.statsmodels.save_model(statsmodels_model=statsmodels_model.model, path=model_path)
    reloaded_model = mlflow.statsmodels.load_model(model_uri=model_path)
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)

    # Both comparisons call ``predict`` on the reloaded native model, so both need the guard.
    if hasattr(statsmodels_model.model, "predict"):
        np.testing.assert_array_almost_equal(
            statsmodels_model.model.predict(*predict_args),
            reloaded_model.predict(*predict_args),
        )

        np.testing.assert_array_almost_equal(
            reloaded_model.predict(*predict_args),
            reloaded_pyfunc.predict(statsmodels_model.inference_dataframe),
        )


def _test_model_log(statsmodels_model, model_path, *predict_args):
    """Log a fixture model with a custom conda env and validate the logged artifacts.

    ``model_path`` is accepted for signature compatibility with ``_test_models_list``
    but is not used: the model is logged via ``log_model`` and then downloaded.
    """
    model = statsmodels_model.model
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        try:
            conda_env = os.path.join(tmp.path(), "conda_env.yaml")
            _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])

            model_info = mlflow.statsmodels.log_model(model, name="model", conda_env=conda_env)
            reloaded_model = mlflow.statsmodels.load_model(model_uri=model_info.model_uri)
            if hasattr(model, "predict"):
                np.testing.assert_array_almost_equal(
                    model.predict(*predict_args), reloaded_model.predict(*predict_args)
                )

            downloaded_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
            model_config = Model.load(os.path.join(downloaded_path, "MLmodel"))
            assert pyfunc.FLAVOR_NAME in model_config.flavors
            assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
            env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
            assert os.path.exists(os.path.join(downloaded_path, env_path))
        finally:
            # No explicit start_run here; end any active run (presumably opened implicitly
            # by log_model) so it does not carry over into the next fixture.
            mlflow.end_run()


def test_models_save_load(tmp_path):
    """Every fixture model round-trips through ``save_model`` / ``load_model``."""
    _test_models_list(tmp_path, _test_model_save_load)


def test_models_log(tmp_path):
    """Every fixture model round-trips through ``log_model`` / ``load_model``."""
    _test_models_list(tmp_path, _test_model_log)


def test_signature_and_examples_are_saved_correctly():
    """Signature and input example are persisted for every combination of the two.

    When only an input example is supplied, the saved signature is expected to equal
    the fixture's reference signature.
    """
    model, _, X = ols_model()
    signature_ = ols_model_signature()
    example_ = X[0:3, :]

    for signature in (None, signature_):
        for example in (None, example_):
            with TempDir() as tmp:
                path = tmp.path("model")
                mlflow.statsmodels.save_model(
                    model, path=path, signature=signature, input_example=example
                )
                mlflow_model = Model.load(path)
                if signature is None and example is None:
                    assert mlflow_model.signature is None
                else:
                    assert mlflow_model.signature == signature_
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    np.testing.assert_array_equal(_read_example(mlflow_model, path), example)


def test_model_load_from_remote_uri_succeeds(model_path, mock_s3_bucket):
    """A model uploaded to ``mock_s3_bucket`` via S3ArtifactRepository loads from its s3:// URI."""
    model, _, inference_dataframe = arma_model()
    mlflow.statsmodels.save_model(statsmodels_model=model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    model_uri = artifact_root + "/" + artifact_path
    reloaded_model = mlflow.statsmodels.load_model(model_uri=model_uri)
    start_date, end_date = _get_dates_from_df(inference_dataframe)
    np.testing.assert_array_almost_equal(
        model.predict(start=start_date, end=end_date),
        reloaded_model.predict(start=start_date, end=end_date),
    )


def test_log_model_calls_register_model():
    """Passing ``registered_model_name`` to ``log_model`` triggers model registration."""
    # Adapted from lightgbm tests
    ols = ols_model()
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with (
        mlflow.start_run(),
        register_model_patch as register_model_mock,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])
        model_info = mlflow.statsmodels.log_model(
            ols.model,
            name=artifact_path,
            conda_env=conda_env,
            registered_model_name="OLSModel1",
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=register_model_mock,
            model_uri=model_info.model_uri,
            registered_model_name="OLSModel1",
        )


def test_log_model_no_registered_model_name():
    """Without ``registered_model_name``, ``log_model`` does not register the model."""
    ols = ols_model()
    artifact_path = "model"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with (
        mlflow.start_run(),
        register_model_patch as register_model_mock,
        TempDir(chdr=True, remove_on_exit=True) as tmp,
    ):
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["statsmodels"])
        mlflow.statsmodels.log_model(ols.model, name=artifact_path, conda_env=conda_env)
        register_model_mock.assert_not_called()


def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    model_path, statsmodels_custom_env
):
    """``save_model`` copies the given conda env file into the model directory unchanged."""
    ols = ols_model()
    mlflow.statsmodels.save_model(
        statsmodels_model=ols.model, path=model_path, conda_env=statsmodels_custom_env
    )

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != statsmodels_custom_env

    with open(statsmodels_custom_env) as f:
        statsmodels_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == statsmodels_custom_env_parsed


def test_model_save_persists_requirements_in_mlflow_model_directory(
    model_path, statsmodels_custom_env
):
    """``save_model`` writes a requirements.txt consistent with the given conda env."""
    ols = ols_model()
    mlflow.statsmodels.save_model(
        statsmodels_model=ols.model, path=model_path, conda_env=statsmodels_custom_env
    )

    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(statsmodels_custom_env, saved_pip_req_path)


def test_log_model_with_pip_requirements(tmp_path):
    """``pip_requirements`` accepts a file path, a list with ``-r``, and a ``-c`` constraint."""
    expected_mlflow_version = _mlflow_major_version_string()
    ols = ols_model()
    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )


def test_log_model_with_extra_pip_requirements(tmp_path):
    """``extra_pip_requirements`` are appended to the flavor's default pip requirements."""
    expected_mlflow_version = _mlflow_major_version_string()
    ols = ols_model()
    default_reqs = mlflow.statsmodels.get_default_pip_requirements()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )


def test_model_save_accepts_conda_env_as_dict(model_path):
    """``save_model`` accepts a conda env given as a dict and persists it verbatim."""
    ols = ols_model()
    default_env = mlflow.statsmodels.get_default_conda_env()
    # Build a new dependency list instead of appending in place: ``dict(...)`` is only a
    # shallow copy, so ``append`` would also mutate the list inside ``default_env``.
    conda_env = {**default_env, "dependencies": [*default_env["dependencies"], "pytest"]}
    mlflow.statsmodels.save_model(statsmodels_model=ols.model, path=model_path, conda_env=conda_env)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == conda_env


def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(statsmodels_custom_env):
    """``log_model`` copies the given conda env file into the logged model unchanged."""
    ols = ols_model()
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model,
            name="model",
            conda_env=statsmodels_custom_env,
        )

    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != statsmodels_custom_env

    with open(statsmodels_custom_env) as f:
        statsmodels_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == statsmodels_custom_env_parsed


def test_model_log_persists_requirements_in_mlflow_model_directory(statsmodels_custom_env):
    """``log_model`` writes a requirements.txt consistent with the given conda env."""
    ols = ols_model()
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model,
            name=artifact_path,
            conda_env=statsmodels_custom_env,
        )

    model_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(statsmodels_custom_env, saved_pip_req_path)


def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    model_path,
):
    """Without ``conda_env``, ``save_model`` records the flavor's default pip requirements."""
    ols = ols_model()
    mlflow.statsmodels.save_model(statsmodels_model=ols.model, path=model_path)
    _assert_pip_requirements(model_path, mlflow.statsmodels.get_default_pip_requirements())


def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies():
    """Without ``conda_env``, ``log_model`` records the flavor's default pip requirements."""
    ols = ols_model()
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(ols.model, name=artifact_path)
    _assert_pip_requirements(
        model_info.model_uri, mlflow.statsmodels.get_default_pip_requirements()
    )


def test_pyfunc_serve_and_score():
    """The pyfunc scoring server returns the same predictions as the in-memory model."""
    model, _, inference_dataframe = ols_model()
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            model, name=artifact_path, input_example=inference_dataframe
        )

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model.predict(inference_dataframe))


def test_log_model_with_code_paths():
    """``code_paths`` are logged with the model and added to sys.path on load."""
    artifact_path = "model"
    ols = ols_model()
    with (
        mlflow.start_run(),
        mock.patch("mlflow.statsmodels._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.statsmodels.log_model(
            ols.model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.statsmodels.FLAVOR_NAME)
        mlflow.statsmodels.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_virtualenv_subfield_points_to_correct_path(model_path):
    """The pyfunc ``virtualenv`` env entry points to an existing file in the model dir."""
    ols = ols_model()
    mlflow.statsmodels.save_model(ols.model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


def test_model_save_load_with_metadata(model_path):
    """Metadata passed to ``save_model`` is available on the reloaded pyfunc model."""
    ols = ols_model()
    mlflow.statsmodels.save_model(
        ols.model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata():
    """Metadata passed to ``log_model`` is available on the reloaded pyfunc model."""
    ols = ols_model()
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(
            ols.model, name=artifact_path, metadata={"metadata_key": "metadata_value"}
        )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_signature_inference():
    """Logging with only an input example yields the fixture's reference signature."""
    model, _, X = ols_model()

    artifact_path = "model"
    example = X[0:3, :]

    with mlflow.start_run():
        model_info = mlflow.statsmodels.log_model(model, name=artifact_path, input_example=example)

    loaded_model = Model.load(model_info.model_uri)
    assert loaded_model.signature == ols_model_signature()