# test_pmdarima_model_export.py
"""Tests for saving, logging, loading and serving models with the ``mlflow.pmdarima`` flavor."""

import json
import os
from pathlib import Path
from unittest import mock

import numpy as np
import pandas as pd
import pmdarima
import pytest
import yaml

import mlflow.pmdarima
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow import pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelSignature, infer_signature
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import DataType
from mlflow.types.schema import ColSpec, Schema
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _is_available_on_pypi,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)
from tests.prophet.test_prophet_model_export import DataGeneration

# Extra CLI arguments passed to the serving helper: the local env manager is
# requested whenever `_is_available_on_pypi("pmdarima")` is falsy.
EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("pmdarima") else ["--env-manager", "local"]
)


@pytest.fixture
def model_path(tmp_path):
    """Return a ``model`` path under the per-test ``tmp_path`` (the path is not created here)."""
    return tmp_path.joinpath("model")


@pytest.fixture
def pmdarima_custom_env(tmp_path):
    """Write a conda env file listing ``pmdarima`` as an extra pip dep and return its path."""
    conda_env = tmp_path.joinpath("conda_env.yml")
    _mlflow_conda_env(conda_env, additional_pip_deps=["pmdarima"])
    return conda_env


@pytest.fixture(scope="module")
def test_data():
    """Build the training frame from ``DataGeneration`` and index it by its ``date`` column."""
    data_conf = {
        "shift": False,
        "start": "2016-01-01",
        "size": 365 * 3,
        "seasonal_period": 7,
        "seasonal_freq": 0.1,
        "date_field": "date",
        "target_field": "orders",
    }
    raw = DataGeneration(**data_conf).create_series_df()
    return raw.set_index("date")


@pytest.fixture(scope="module")
def auto_arima_model(test_data):
    """Return the model produced by ``pmdarima.auto_arima`` on the ``orders`` column."""
    return pmdarima.auto_arima(
        test_data["orders"], max_d=1, suppress_warnings=True, error_action="raise"
    )


@pytest.fixture(scope="module")
def auto_arima_object_model(test_data):
    """Return an explicitly configured ``ARIMA(order=(2, 1, 3))`` fit on the ``orders`` column."""
    model = pmdarima.arima.ARIMA(order=(2, 1, 3), maxiter=25)
    return model.fit(test_data["orders"])


def test_pmdarima_auto_arima_save_and_load(auto_arima_model, model_path):
    """A saved-then-loaded auto_arima model predicts the same values as the original."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)

    loaded_model = mlflow.pmdarima.load_model(model_uri=model_path)

    np.testing.assert_array_equal(auto_arima_model.predict(10), loaded_model.predict(10))


def test_load_model_disallows_pickle_deserialization(auto_arima_model, model_path, monkeypatch):
    """``load_model`` raises when ``MLFLOW_ALLOW_PICKLE_DESERIALIZATION`` is set to ``false``."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)

    monkeypatch.setenv("MLFLOW_ALLOW_PICKLE_DESERIALIZATION", "false")
    with pytest.raises(MlflowException, match="MLFLOW_ALLOW_PICKLE_DESERIALIZATION"):
        mlflow.pmdarima.load_model(model_uri=model_path)


def test_pmdarima_arima_object_save_and_load(auto_arima_object_model, model_path):
    """A saved-then-loaded ``ARIMA`` object predicts the same values as the original."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_object_model, path=model_path)

    loaded_model = mlflow.pmdarima.load_model(model_uri=model_path)

    np.testing.assert_array_equal(auto_arima_object_model.predict(30), loaded_model.predict(30))


def test_pmdarima_autoarima_pyfunc_save_and_load(auto_arima_model, model_path):
    """The pyfunc wrapper returns the native forecast and interval bounds as named columns."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    loaded_pyfunc = mlflow.pyfunc.load_model(model_uri=model_path)

    model_predict = auto_arima_model.predict(n_periods=60, return_conf_int=True, alpha=0.1)

    predict_conf = pd.DataFrame({"n_periods": 60, "return_conf_int": True, "alpha": 0.1}, index=[0])
    pyfunc_predict = loaded_pyfunc.predict(predict_conf)

    np.testing.assert_array_equal(model_predict[0], pyfunc_predict["yhat"])
    # Transpose the per-period (low, high) pairs into one sequence per bound.
    yhat_low, yhat_high = zip(*model_predict[1])
    np.testing.assert_array_equal(yhat_low, pyfunc_predict["yhat_lower"])
    np.testing.assert_array_equal(yhat_high, pyfunc_predict["yhat_upper"])


@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_pmdarima_signature_and_examples_saved_correctly(
    auto_arima_model, model_path, use_signature, use_example
):
    """The signature and input example given to ``save_model`` round-trip via the MLmodel."""
    # NB: Signature inference will only work on the first element of the tuple return
    prediction = auto_arima_model.predict(n_periods=20, return_conf_int=True, alpha=0.05)
    # Named `input_df` (not `test_data`) so it does not shadow the module-level fixture.
    input_df = pd.DataFrame({"n_periods": [30]})
    signature = infer_signature(input_df, prediction[0]) if use_signature or use_example else None
    example = input_df if use_example else None
    mlflow.pmdarima.save_model(
        auto_arima_model, path=model_path, signature=signature, input_example=example
    )
    mlflow_model = Model.load(model_path)
    if signature is None and example is None:
        assert mlflow_model.signature is None
    else:
        assert mlflow_model.signature == signature
    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        r_example = _read_example(mlflow_model, model_path).copy(deep=False)
        np.testing.assert_array_equal(r_example, example)


@pytest.mark.parametrize("use_signature", [True, False])
@pytest.mark.parametrize("use_example", [True, False])
def test_pmdarima_signature_and_example_for_confidence_interval_mode(
    auto_arima_model, model_path, use_signature, use_example
):
    """Signature/example persistence when the prediction config requests confidence intervals."""
    model_path_primary = model_path.joinpath("primary")
    model_path_secondary = model_path.joinpath("secondary")
    # The primary save only exists to obtain a pyfunc forecast to infer the signature from.
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path_primary)
    loaded_pyfunc = mlflow.pyfunc.load_model(model_uri=model_path_primary)
    predict_conf = pd.DataFrame([{"n_periods": 10, "return_conf_int": True, "alpha": 0.2}])
    forecast = loaded_pyfunc.predict(predict_conf)
    signature_ = infer_signature(predict_conf, forecast)
    signature = signature_ if use_signature else None
    example = predict_conf.copy(deep=False) if use_example else None
    mlflow.pmdarima.save_model(
        auto_arima_model, path=model_path_secondary, signature=signature, input_example=example
    )
    mlflow_model = Model.load(model_path_secondary)
    if signature is None and example is None:
        assert mlflow_model.signature is None
    else:
        # Compared against `signature_` (not `signature`): with only an example supplied,
        # the stored signature is expected to equal the one inferred above.
        assert mlflow_model.signature == signature_
    if example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        r_example = _read_example(mlflow_model, model_path_secondary).copy(deep=False)
        np.testing.assert_array_equal(r_example, example)


def test_pmdarima_load_from_remote_uri_succeeds(
    auto_arima_object_model, model_path, mock_s3_bucket
):
    """A model uploaded to the mocked S3 bucket can be loaded back through its ``s3://`` URI."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_object_model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    # NB: URIs always use "/" as the separator, so join explicitly. `os.path.join` would
    # insert the host OS separator (a backslash on Windows) and corrupt the URI.
    model_uri = f"{artifact_root}/{artifact_path}"
    reloaded_pmdarima_model = mlflow.pmdarima.load_model(model_uri=model_uri)

    np.testing.assert_array_equal(
        auto_arima_object_model.predict(30), reloaded_pmdarima_model.predict(30)
    )


@pytest.mark.parametrize("should_start_run", [True, False])
def test_pmdarima_log_model(auto_arima_model, tmp_path, should_start_run):
    """``log_model`` works with or without an explicitly started run and stores the conda env."""
    try:
        if should_start_run:
            mlflow.start_run()
        artifact_path = "pmdarima"
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["pmdarima"])
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model,
            name=artifact_path,
            conda_env=str(conda_env),
        )
        reloaded_model = mlflow.pmdarima.load_model(model_uri=model_info.model_uri)
        np.testing.assert_array_equal(auto_arima_model.predict(20), reloaded_model.predict(20))
        model_path = Path(_download_artifact_from_uri(artifact_uri=model_info.model_uri))
        model_config = Model.load(str(model_path.joinpath("MLmodel")))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
        assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
        env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
        assert model_path.joinpath(env_path).exists()
    finally:
        # Always close the active run so it cannot leak into later tests.
        mlflow.end_run()


def test_pmdarima_log_model_calls_register_model(auto_arima_object_model, tmp_path):
    """Passing ``registered_model_name`` makes ``log_model`` call ``_register_model``."""
    artifact_path = "pmdarima"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["pmdarima"])
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model,
            name=artifact_path,
            conda_env=str(conda_env),
            registered_model_name="PmdarimaModel",
        )
        assert_register_model_called_with_local_model_path(
            mlflow.tracking._model_registry.fluent._register_model,
            model_info.model_uri,
            "PmdarimaModel",
        )


def test_pmdarima_log_model_no_registered_model_name(auto_arima_model, tmp_path):
    """Without ``registered_model_name``, ``log_model`` must not call ``_register_model``."""
    artifact_path = "pmdarima"
    register_model_patch = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), register_model_patch:
        conda_env = tmp_path.joinpath("conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["pmdarima"])
        mlflow.pmdarima.log_model(auto_arima_model, name=artifact_path, conda_env=str(conda_env))
        mlflow.tracking._model_registry.fluent._register_model.assert_not_called()


def test_pmdarima_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    auto_arima_object_model, model_path, pmdarima_custom_env
):
    """The supplied conda env is copied into the model directory with identical content."""
    mlflow.pmdarima.save_model(
        pmdarima_model=auto_arima_object_model, path=model_path, conda_env=str(pmdarima_custom_env)
    )
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = model_path.joinpath(pyfunc_conf[pyfunc.ENV]["conda"])
    assert saved_conda_env_path.exists()
    # The saved env must be a separate file, not the caller's original.
    assert not pmdarima_custom_env.samefile(saved_conda_env_path)

    pmdarima_custom_env_parsed = yaml.safe_load(pmdarima_custom_env.read_bytes())
    saved_conda_env_parsed = yaml.safe_load(saved_conda_env_path.read_bytes())
    assert saved_conda_env_parsed == pmdarima_custom_env_parsed


def test_pmdarima_model_save_persists_requirements_in_mlflow_model_directory(
    auto_arima_model, model_path, pmdarima_custom_env
):
    """The saved ``requirements.txt`` is compared against the supplied conda env."""
    mlflow.pmdarima.save_model(
        pmdarima_model=auto_arima_model, path=model_path, conda_env=str(pmdarima_custom_env)
    )
    saved_pip_req_path = model_path.joinpath("requirements.txt")
    _compare_conda_env_requirements(pmdarima_custom_env, str(saved_pip_req_path))


def test_pmdarima_log_model_with_pip_requirements(auto_arima_object_model, tmp_path):
    """``pip_requirements`` accepts a requirements file path, a list, and a constraints entry."""
    expected_mlflow_version = _mlflow_major_version_string()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_object_model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )


def test_pmdarima_log_model_with_extra_pip_requirements(auto_arima_model, tmp_path):
    """``extra_pip_requirements`` entries are expected alongside the flavor's default reqs."""
    expected_mlflow_version = _mlflow_major_version_string()
    default_reqs = mlflow.pmdarima.get_default_pip_requirements()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_uri=model_info.model_uri,
            requirements=[expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            constraints=["a"],
            strict=False,
        )


def test_pmdarima_model_save_without_conda_env_uses_default_env_with_expected_dependencies(
    auto_arima_model, model_path
):
    """Saving without an explicit env is checked against the flavor's default pip reqs."""
    mlflow.pmdarima.save_model(auto_arima_model, model_path)
    _assert_pip_requirements(model_path, mlflow.pmdarima.get_default_pip_requirements())


def test_pmdarima_model_log_without_conda_env_uses_default_env_with_expected_dependencies(
    auto_arima_object_model,
):
    """Logging without an explicit env is checked against the flavor's default pip reqs."""
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(auto_arima_object_model, name=artifact_path)
    _assert_pip_requirements(model_info.model_uri, mlflow.pmdarima.get_default_pip_requirements())


def test_pmdarima_pyfunc_serve_and_score(auto_arima_model):
    """Scores returned by the serving helper match the local 30-period forecast."""
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model,
            name=artifact_path,
            input_example=pd.DataFrame({"n_periods": 30}, index=[0]),
        )
    local_predict = auto_arima_model.predict(30)

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = (
        pd.DataFrame(data=json.loads(resp.content.decode("utf-8"))["predictions"])
        .to_numpy()
        .flatten()
    )
    # Almost-equal rather than exact: the served values pass through a JSON round trip.
    np.testing.assert_array_almost_equal(scores, local_predict)


def test_pmdarima_pyfunc_raises_invalid_df_input(auto_arima_model, model_path):
    """The pyfunc wrapper raises ``MlflowException`` for each malformed prediction config."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    loaded_pyfunc = mlflow.pyfunc.load_model(model_uri=model_path)

    # More than one configuration row.
    with pytest.raises(MlflowException, match="The provided prediction pd.DataFrame "):
        loaded_pyfunc.predict(pd.DataFrame([{"n_periods": 60}, {"n_periods": 100}]))

    # No `n_periods` column.
    with pytest.raises(MlflowException, match="The provided prediction configuration "):
        loaded_pyfunc.predict(pd.DataFrame([{"invalid": True}]))

    # `n_periods` supplied as a string.
    with pytest.raises(MlflowException, match="The provided `n_periods` value "):
        loaded_pyfunc.predict(pd.DataFrame([{"n_periods": "60"}]))


def test_pmdarima_pyfunc_return_correct_structure(auto_arima_model, model_path):
    """The pyfunc forecast has 1 column without and 3 columns with confidence intervals."""
    mlflow.pmdarima.save_model(pmdarima_model=auto_arima_model, path=model_path)
    loaded_pyfunc = mlflow.pyfunc.load_model(model_uri=model_path)

    predict_conf_no_ci = pd.DataFrame([{"n_periods": 10, "return_conf_int": False}])
    forecast_no_ci = loaded_pyfunc.predict(predict_conf_no_ci)

    assert isinstance(forecast_no_ci, pd.DataFrame)
    assert len(forecast_no_ci) == 10
    assert len(forecast_no_ci.columns) == 1

    predict_conf_with_ci = pd.DataFrame([{"n_periods": 10, "return_conf_int": True}])
    forecast_with_ci = loaded_pyfunc.predict(predict_conf_with_ci)

    assert isinstance(forecast_with_ci, pd.DataFrame)
    assert len(forecast_with_ci) == 10
    assert len(forecast_with_ci.columns) == 3


def test_log_model_with_code_paths(auto_arima_model):
    """``code_paths`` are logged with the model and the code-path hook runs on load."""
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.pmdarima._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.pmdarima.FLAVOR_NAME)
        mlflow.pmdarima.load_model(model_info.model_uri)
        add_mock.assert_called()


def test_virtualenv_subfield_points_to_correct_path(auto_arima_model, model_path):
    """The pyfunc ``virtualenv`` env entry points at an existing file in the model directory."""
    mlflow.pmdarima.save_model(auto_arima_model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


def test_model_save_load_with_metadata(auto_arima_model, model_path):
    """Metadata passed to ``save_model`` is available on the reloaded pyfunc model."""
    mlflow.pmdarima.save_model(
        auto_arima_model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata(auto_arima_model):
    """Metadata passed to ``log_model`` is available on the reloaded pyfunc model."""
    artifact_path = "model"

    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model,
            name=artifact_path,
            metadata={"metadata_key": "metadata_value"},
        )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_signature_inference(auto_arima_model):
    """Logging with only an input example stores the expected input/output signature."""
    artifact_path = "model"
    example = pd.DataFrame({"n_periods": 60, "return_conf_int": True, "alpha": 0.1}, index=[0])

    with mlflow.start_run():
        model_info = mlflow.pmdarima.log_model(
            auto_arima_model, name=artifact_path, input_example=example
        )

    model_info_loaded = Model.load(model_info.model_uri)
    assert model_info_loaded.signature == ModelSignature(
        inputs=Schema([
            ColSpec(name="n_periods", type=DataType.long),
            ColSpec(name="return_conf_int", type=DataType.boolean),
            ColSpec(name="alpha", type=DataType.double),
        ]),
        outputs=Schema([
            ColSpec(name="yhat", type=DataType.double),
            ColSpec(name="yhat_lower", type=DataType.double),
            ColSpec(name="yhat_upper", type=DataType.double),
        ]),
    )