# test_catboost_model_export.py
import json
import os
from pathlib import Path
from typing import Any, NamedTuple
from unittest import mock

import catboost as cb
import numpy as np
import pandas as pd
import pytest
import yaml
from packaging.version import Version
from sklearn import datasets
from sklearn.pipeline import Pipeline

import mlflow.catboost
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow import pyfunc
from mlflow.models import Model, ModelSignature
from mlflow.models.utils import _read_example, load_serving_example
from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import DataType
from mlflow.types.schema import ColSpec, Schema, TensorSpec
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import (
    _assert_pip_requirements,
    _compare_conda_env_requirements,
    _compare_logged_code_paths,
    _is_available_on_pypi,
    _mlflow_major_version_string,
    assert_register_model_called_with_local_model_path,
    pyfunc_serve_and_score_model,
)

# Extra CLI args for the model-serving tests: when `_is_available_on_pypi` reports
# catboost as unavailable on PyPI, serve with `--env-manager local` (presumably
# because a freshly built environment could not install it).
EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("catboost") else ["--env-manager", "local"]
)


class ModelWithData(NamedTuple):
    """A fitted model paired with a DataFrame that can be passed to its ``predict``."""

    model: Any
    inference_dataframe: Any


def get_iris():
    """Return the first two Iris features as a DataFrame and the class labels as a Series."""
    iris = datasets.load_iris()
    X = pd.DataFrame(iris.data[:, :2], columns=iris.feature_names[:2])
    y = pd.Series(iris.target)
    return X, y


def read_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    The encoding is explicit so the result does not depend on the platform default.
    """
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)


# Shared hyper-parameters: few iterations keep training fast, and
# `allow_writing_files=False` stops CatBoost from writing training files to disk.
MODEL_PARAMS = {"allow_writing_files": False, "iterations": 10}


def iter_models():
    """Yield a fitted ``CatBoost``, ``CatBoostClassifier`` and ``CatBoostRegressor``.

    The yield order must match the ``ids`` declared on the ``cb_model`` fixture.
    """
    X, y = get_iris()
    model = cb.CatBoost(MODEL_PARAMS).fit(X, y)
    yield ModelWithData(model=model, inference_dataframe=X)

    model = cb.CatBoostClassifier(**MODEL_PARAMS).fit(X, y)
    yield ModelWithData(model=model, inference_dataframe=X)

    model = cb.CatBoostRegressor(**MODEL_PARAMS).fit(X, y)
    yield ModelWithData(model=model, inference_dataframe=X)


@pytest.fixture(
    scope="module",
    params=iter_models(),
    ids=["CatBoost", "CatBoostClassifier", "CatBoostRegressor"],
)
def cb_model(request):
    """Module-scoped fixture parametrized over every model yielded by ``iter_models``."""
    return request.param


@pytest.fixture
def reg_model():
    """A ``CatBoostRegressor`` fitted on the two-feature Iris data."""
    model = cb.CatBoostRegressor(**MODEL_PARAMS)
    X, y = get_iris()
    return ModelWithData(model=model.fit(X, y), inference_dataframe=X)


def get_reg_model_signature():
    """Return the signature expected for ``reg_model``: two double inputs, one double output."""
    return ModelSignature(
        inputs=Schema([
            ColSpec(name="sepal length (cm)", type=DataType.double),
            ColSpec(name="sepal width (cm)", type=DataType.double),
        ]),
        outputs=Schema([ColSpec(type=DataType.double)]),
    )


@pytest.fixture
def model_path(tmp_path):
    """A ``model`` path inside the per-test temporary directory."""
    return os.path.join(tmp_path, "model")


@pytest.fixture
def custom_env(tmp_path):
    """Write a conda env file with extra pip deps ``catboost`` and ``pytest``; return its path."""
    conda_env_path = os.path.join(tmp_path, "conda_env.yml")
    _mlflow_conda_env(conda_env_path, additional_pip_deps=["catboost", "pytest"])
    return conda_env_path


@pytest.mark.parametrize("model_type", ["CatBoost", "CatBoostClassifier", "CatBoostRegressor"])
def test_init_model(model_type):
    """``_init_model`` returns an instance of the CatBoost class named by ``model_type``."""
    model = mlflow.catboost._init_model(model_type)
    assert model.__class__.__name__ == model_type


@pytest.mark.skipif(
    Version(cb.__version__) < Version("0.26.0"),
    reason="catboost < 0.26.0 does not support CatBoostRanker",
)
def test_log_catboost_ranker():
    """
    This is a separate test for the CatBoostRanker model.
    It is separate because the ranking task requires a group_id column, which makes
    the fitting code different.
    """
    # The ranking task requires a group_id. This dummy grouping is meaningless for the
    # Iris dataset, but it is enough to check that logging and loading work.
    X, y = get_iris()
    dummy_group_id = np.arange(len(X)) % 3
    # CatBoost expects rows of the same group to be contiguous, hence the sort.
    dummy_group_id.sort()

    model = cb.CatBoostRanker(**MODEL_PARAMS, subsample=1.0)
    model.fit(X, y, group_id=dummy_group_id)

    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(model, name="model")
        loaded_model = mlflow.catboost.load_model(model_info.model_uri)
        assert isinstance(loaded_model, cb.CatBoostRanker)
        np.testing.assert_array_almost_equal(model.predict(X), loaded_model.predict(X))


def test_init_model_throws_for_invalid_model_type():
    """An unknown model type name raises ``TypeError``."""
    with pytest.raises(TypeError, match="Invalid model type"):
        mlflow.catboost._init_model("unsupported")


def test_model_save_load(cb_model, model_path):
    """A saved model reloads through both the catboost flavor and pyfunc with equal predictions."""
    model, inference_dataframe = cb_model
    mlflow.catboost.save_model(cb_model=model, path=model_path)

    loaded_model = mlflow.catboost.load_model(model_uri=model_path)
    np.testing.assert_array_almost_equal(
        model.predict(inference_dataframe),
        loaded_model.predict(inference_dataframe),
    )

    loaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    np.testing.assert_array_almost_equal(
        loaded_model.predict(inference_dataframe),
        loaded_pyfunc.predict(inference_dataframe),
    )


def test_log_model_logs_model_type(cb_model):
    """The flavor config records the logged model's class name under ``model_type``."""
    with mlflow.start_run():
        artifact_path = "model"
        model_info = mlflow.catboost.log_model(cb_model.model, name=artifact_path)

    flavor_conf = Model.load(model_info.model_uri).flavors["catboost"]
    assert "model_type" in flavor_conf
    assert flavor_conf["model_type"] == cb_model.model.__class__.__name__


# Supported serialization formats:
# https://catboost.ai/docs/concepts/python-reference_catboost_save_model.html
# Formats that `mlflow.catboost.load_model` can read back. The extra entries in
# `save_formats` can be exported but are expected to fail on load
# (see test_log_model_logs_save_format).
SUPPORTS_DESERIALIZATION = ["cbm", "coreml", "json", "onnx"]
save_formats = SUPPORTS_DESERIALIZATION + ["python", "cpp", "pmml"]


@pytest.mark.allow_infer_pip_requirements_fallback
@pytest.mark.parametrize("save_format", save_formats)
def test_log_model_logs_save_format(reg_model, save_format):
    """``log_model(format=...)`` is recorded as ``save_format`` in the flavor config.

    Formats in ``SUPPORTS_DESERIALIZATION`` must load back; the others must fail to
    load with a ``CatBoostError``.
    """
    with mlflow.start_run():
        artifact_path = "model"
        model_info = mlflow.catboost.log_model(
            reg_model.model, name=artifact_path, format=save_format
        )

    flavor_conf = Model.load(model_info.model_uri).flavors["catboost"]
    assert "save_format" in flavor_conf
    assert flavor_conf["save_format"] == save_format

    if save_format in SUPPORTS_DESERIALIZATION:
        mlflow.catboost.load_model(model_info.model_uri)
    else:
        with pytest.raises(cb.CatBoostError, match="deserialization not supported or missing"):
            mlflow.catboost.load_model(model_info.model_uri)


@pytest.mark.parametrize("signature", [None, get_reg_model_signature()])
@pytest.mark.parametrize("input_example", [None, get_iris()[0].head(3)])
def test_signature_and_examples_are_saved_correctly(
    reg_model, model_path, signature, input_example
):
    """Check how the signature and input example are persisted.

    An explicit signature is stored as given. With only an input example, the
    inferred signature must equal the same expected schema.
    """
    mlflow.catboost.save_model(
        reg_model.model, model_path, signature=signature, input_example=input_example
    )
    mlflow_model = Model.load(model_path)
    if signature is None and input_example is None:
        assert mlflow_model.signature is None
    else:
        assert mlflow_model.signature == get_reg_model_signature()
    if input_example is None:
        assert mlflow_model.saved_input_example_info is None
    else:
        pd.testing.assert_frame_equal(_read_example(mlflow_model, model_path), input_example)


def test_model_load_from_remote_uri_succeeds(reg_model, model_path, mock_s3_bucket):
    """A model copied to a (mocked) S3 bucket loads from its ``s3://`` URI."""
    model, inference_dataframe = reg_model
    mlflow.catboost.save_model(cb_model=model, path=model_path)
    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_repo = S3ArtifactRepository(artifact_root)
    artifact_path = "model"
    artifact_repo.log_artifacts(model_path, artifact_path=artifact_path)

    model_uri = f"{artifact_root}/{artifact_path}"
    loaded_model = mlflow.catboost.load_model(model_uri=model_uri)
    np.testing.assert_array_almost_equal(
        model.predict(inference_dataframe),
        loaded_model.predict(inference_dataframe),
    )


def test_log_model(cb_model, tmp_path):
    """Check that a logged model round-trips and its MLmodel references an existing conda env file."""
    model, inference_dataframe = cb_model
    with mlflow.start_run():
        artifact_path = "model"
        conda_env = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["catboost"])

        model_info = mlflow.catboost.log_model(model, name=artifact_path, conda_env=conda_env)

        loaded_model = mlflow.catboost.load_model(model_info.model_uri)
        np.testing.assert_array_almost_equal(
            model.predict(inference_dataframe),
            loaded_model.predict(inference_dataframe),
        )

        local_path = _download_artifact_from_uri(model_info.model_uri)
        model_config = Model.load(os.path.join(local_path, "MLmodel"))
        assert pyfunc.FLAVOR_NAME in model_config.flavors
        assert pyfunc.ENV in model_config.flavors[pyfunc.FLAVOR_NAME]
        env_path = model_config.flavors[pyfunc.FLAVOR_NAME][pyfunc.ENV]["conda"]
        assert os.path.exists(os.path.join(local_path, env_path))


def test_log_model_calls_register_model(cb_model, tmp_path):
    """Passing ``registered_model_name`` registers the logged model."""
    artifact_path = "model"
    registered_model_name = "registered_model"
    with (
        mlflow.start_run(),
        mock.patch(
            "mlflow.tracking._model_registry.fluent._register_model"
        ) as register_model_mock,
    ):
        conda_env_path = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(conda_env_path, additional_pip_deps=["catboost"])
        model_info = mlflow.catboost.log_model(
            cb_model.model,
            name=artifact_path,
            conda_env=conda_env_path,
            registered_model_name=registered_model_name,
        )
        # Use the mock handle returned by `patch` directly instead of re-reading
        # the patched module attribute.
        assert_register_model_called_with_local_model_path(
            register_model_mock=register_model_mock,
            model_uri=model_info.model_uri,
            registered_model_name=registered_model_name,
        )


def test_log_model_no_registered_model_name(cb_model, tmp_path):
    """Without ``registered_model_name``, ``mlflow.register_model`` is not called."""
    with mlflow.start_run(), mock.patch("mlflow.register_model") as register_model_mock:
        artifact_path = "model"
        conda_env_path = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(conda_env_path, additional_pip_deps=["catboost"])
        mlflow.catboost.log_model(cb_model.model, name=artifact_path, conda_env=conda_env_path)
        register_model_mock.assert_not_called()


def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    reg_model, model_path, custom_env
):
    """``save_model`` copies the given conda env into the model directory unchanged."""
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=custom_env)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != custom_env
    assert read_yaml(saved_conda_env_path) == read_yaml(custom_env)


def test_model_save_persists_requirements_in_mlflow_model_directory(
    reg_model, model_path, custom_env
):
    """``save_model`` writes a requirements.txt matching the conda env's pip deps."""
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=custom_env)

    saved_pip_req_path = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(custom_env, saved_pip_req_path)


def test_model_save_accepts_conda_env_as_dict(reg_model, model_path):
    """A conda env given as a dict is saved verbatim as YAML."""
    conda_env = mlflow.catboost.get_default_conda_env()
    conda_env["dependencies"].append("pytest")
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=conda_env)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert read_yaml(saved_conda_env_path) == conda_env


def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(reg_model, custom_env):
    """``log_model`` copies the given conda env into the logged artifacts unchanged."""
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name=artifact_path, conda_env=custom_env
        )

    local_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=local_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_conda_env_path = os.path.join(local_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    assert saved_conda_env_path != custom_env
    assert read_yaml(saved_conda_env_path) == read_yaml(custom_env)


def test_model_log_persists_requirements_in_mlflow_model_directory(reg_model, custom_env):
    """``log_model`` writes a requirements.txt matching the conda env's pip deps."""
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(reg_model.model, name="model", conda_env=custom_env)

    local_path = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    saved_pip_req_path = os.path.join(local_path, "requirements.txt")
    _compare_conda_env_requirements(custom_env, saved_pip_req_path)


def test_log_model_with_pip_requirements(reg_model, tmp_path):
    """``pip_requirements`` accepts a file path, a list with ``-r``, or a list with ``-c``."""
    expected_mlflow_version = _mlflow_major_version_string()
    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", pip_requirements=str(req_file)
        )
        _assert_pip_requirements(model_info.model_uri, [expected_mlflow_version, "a"], strict=True)

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, "a", "b"], strict=True
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, "b", "-c constraints.txt"],
            ["a"],
            strict=True,
        )


def test_log_model_with_extra_pip_requirements(reg_model, tmp_path):
    """``extra_pip_requirements`` are appended to the flavor's default requirements."""
    expected_mlflow_version = _mlflow_major_version_string()
    default_reqs = mlflow.catboost.get_default_pip_requirements()

    # Path to a requirements file
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", extra_pip_requirements=str(req_file)
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a"]
        )

    # List of requirements
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", extra_pip_requirements=[f"-r {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri, [expected_mlflow_version, *default_reqs, "a", "b"]
        )

    # Constraints file
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            reg_model.model, name="model", extra_pip_requirements=[f"-c {req_file}", "b"]
        )
        _assert_pip_requirements(
            model_info.model_uri,
            [expected_mlflow_version, *default_reqs, "b", "-c constraints.txt"],
            ["a"],
        )


def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    reg_model, model_path
):
    """Without a conda env, ``save_model`` records the default pip requirements."""
    mlflow.catboost.save_model(reg_model.model, model_path)
    _assert_pip_requirements(model_path, mlflow.catboost.get_default_pip_requirements())


def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    reg_model,
):
    """Without a conda env, ``log_model`` records the default pip requirements."""
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(reg_model.model, name="model")

    _assert_pip_requirements(model_info.model_uri, mlflow.catboost.get_default_pip_requirements())


def test_pyfunc_serve_and_score(reg_model):
    """Scores returned by the pyfunc scoring server match in-process predictions."""
    model, inference_dataframe = reg_model
    artifact_path = "model"
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            model, name=artifact_path, input_example=inference_dataframe
        )

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model.predict(inference_dataframe))


def test_pyfunc_serve_and_score_sklearn(reg_model):
    """A CatBoost model wrapped in an sklearn ``Pipeline`` serves and scores correctly."""
    model, inference_dataframe = reg_model
    # Use a separate name rather than shadowing the unpacked CatBoost model.
    pipeline = Pipeline([("model", model)])
    example = inference_dataframe.head(3)

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(pipeline, name="model", input_example=example)

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, pipeline.predict(example))


def test_log_model_with_code_paths(cb_model):
    """Check that ``code_paths`` files are logged and put on the system path when loading."""
    artifact_path = "model"
    with (
        mlflow.start_run(),
        mock.patch("mlflow.catboost._add_code_from_conf_to_system_path") as add_mock,
    ):
        model_info = mlflow.catboost.log_model(
            cb_model.model, name=artifact_path, code_paths=[__file__]
        )
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.catboost.FLAVOR_NAME)
        mlflow.catboost.load_model(model_uri=model_info.model_uri)
        add_mock.assert_called()


def test_virtualenv_subfield_points_to_correct_path(cb_model, model_path):
    """The pyfunc ``virtualenv`` entry points to an existing file."""
    mlflow.catboost.save_model(cb_model.model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    python_env_path = Path(model_path, pyfunc_conf[pyfunc.ENV]["virtualenv"])
    assert python_env_path.exists()
    assert python_env_path.is_file()


def test_model_save_load_with_metadata(cb_model, model_path):
    """Metadata passed to ``save_model`` is available on the reloaded pyfunc model."""
    mlflow.catboost.save_model(
        cb_model.model, path=model_path, metadata={"metadata_key": "metadata_value"}
    )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_metadata(cb_model):
    """Metadata passed to ``log_model`` is available on the reloaded pyfunc model."""
    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            cb_model.model, name="model", metadata={"metadata_key": "metadata_value"}
        )

    reloaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded_model.metadata.metadata["metadata_key"] == "metadata_value"


def test_model_log_with_signature_inference(cb_model):
    """A signature is inferred from the input example when none is given."""
    artifact_path = "model"
    example = cb_model.inference_dataframe.head(3)

    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(
            cb_model.model, name=artifact_path, input_example=example
        )

    loaded_model_info = Model.load(model_info.model_uri)
    assert loaded_model_info.signature.inputs == Schema([
        ColSpec(name="sepal length (cm)", type=DataType.double),
        ColSpec(name="sepal width (cm)", type=DataType.double),
    ])
    assert loaded_model_info.signature.outputs in [
        # when the model output is a 1D numpy array, it is cast into a `ColSpec`
        Schema([ColSpec(type=DataType.double)]),
        # when the model output is a higher dimensional numpy array, it remains a `TensorSpec`
        Schema([TensorSpec(np.dtype("int64"), (-1, 1))]),
    ]