/ tests / catboost / test_catboost_model_export.py
test_catboost_model_export.py
  1  import json
  2  import os
  3  from pathlib import Path
  4  from typing import Any, NamedTuple
  5  from unittest import mock
  6  
  7  import catboost as cb
  8  import numpy as np
  9  import pandas as pd
 10  import pytest
 11  import yaml
 12  from packaging.version import Version
 13  from sklearn import datasets
 14  from sklearn.pipeline import Pipeline
 15  
 16  import mlflow.catboost
 17  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 18  from mlflow import pyfunc
 19  from mlflow.models import Model, ModelSignature
 20  from mlflow.models.utils import _read_example, load_serving_example
 21  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 22  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 23  from mlflow.types import DataType
 24  from mlflow.types.schema import ColSpec, Schema, TensorSpec
 25  from mlflow.utils.environment import _mlflow_conda_env
 26  from mlflow.utils.model_utils import _get_flavor_configuration
 27  
 28  from tests.helper_functions import (
 29      _assert_pip_requirements,
 30      _compare_conda_env_requirements,
 31      _compare_logged_code_paths,
 32      _is_available_on_pypi,
 33      _mlflow_major_version_string,
 34      assert_register_model_called_with_local_model_path,
 35      pyfunc_serve_and_score_model,
 36  )
 37  
 38  EXTRA_PYFUNC_SERVING_TEST_ARGS = (
 39      [] if _is_available_on_pypi("catboost") else ["--env-manager", "local"]
 40  )
 41  
 42  
 43  class ModelWithData(NamedTuple):
 44      model: Any
 45      inference_dataframe: Any
 46  
 47  
 48  def get_iris():
 49      iris = datasets.load_iris()
 50      X = pd.DataFrame(iris.data[:, :2], columns=iris.feature_names[:2])
 51      y = pd.Series(iris.target)
 52      return X, y
 53  
 54  
 55  def read_yaml(path):
 56      with open(path) as f:
 57          return yaml.safe_load(f)
 58  
 59  
 60  MODEL_PARAMS = {"allow_writing_files": False, "iterations": 10}
 61  
 62  
 63  def iter_models():
 64      X, y = get_iris()
 65      model = cb.CatBoost(MODEL_PARAMS).fit(X, y)
 66      yield ModelWithData(model=model, inference_dataframe=X)
 67  
 68      model = cb.CatBoostClassifier(**MODEL_PARAMS).fit(X, y)
 69      yield ModelWithData(model=model, inference_dataframe=X)
 70  
 71      model = cb.CatBoostRegressor(**MODEL_PARAMS).fit(X, y)
 72      yield ModelWithData(model=model, inference_dataframe=X)
 73  
 74  
 75  @pytest.fixture(
 76      scope="module",
 77      params=iter_models(),
 78      ids=["CatBoost", "CatBoostClassifier", "CatBoostRegressor"],
 79  )
 80  def cb_model(request):
 81      return request.param
 82  
 83  
 84  @pytest.fixture
 85  def reg_model():
 86      model = cb.CatBoostRegressor(**MODEL_PARAMS)
 87      X, y = get_iris()
 88      return ModelWithData(model=model.fit(X, y), inference_dataframe=X)
 89  
 90  
 91  def get_reg_model_signature():
 92      return ModelSignature(
 93          inputs=Schema([
 94              ColSpec(name="sepal length (cm)", type=DataType.double),
 95              ColSpec(name="sepal width (cm)", type=DataType.double),
 96          ]),
 97          outputs=Schema([ColSpec(type=DataType.double)]),
 98      )
 99  
100  
@pytest.fixture
def model_path(tmp_path):
    """String path of a not-yet-created ``model`` directory inside ``tmp_path``."""
    return str(tmp_path / "model")
104  
105  
@pytest.fixture
def custom_env(tmp_path):
    """Write a conda env file with catboost and pytest pip deps; return its path."""
    env_file = str(tmp_path / "conda_env.yml")
    _mlflow_conda_env(env_file, additional_pip_deps=["catboost", "pytest"])
    return env_file
111  
112  
@pytest.mark.parametrize("model_type", ["CatBoost", "CatBoostClassifier", "CatBoostRegressor"])
def test_init_model(model_type):
    """``_init_model`` instantiates the estimator class whose name it is given."""
    instance = mlflow.catboost._init_model(model_type)
    assert type(instance).__name__ == model_type
117  
118  
@pytest.mark.skipif(
    Version(cb.__version__) < Version("0.26.0"),
    reason="catboost < 0.26.0 does not support CatBoostRanker",
)
def test_log_catboost_ranker():
    """
    Log and reload a CatBoostRanker.

    Kept apart from the parametrized model tests because ranking needs a
    ``group_id`` at fit time, which the shared fixtures don't provide.
    """
    X, y = get_iris()
    # Three contiguous dummy groups: meaningless for iris, but enough to let the
    # ranker fit so the logging/loading path can be exercised.
    group_ids = np.sort(np.arange(len(X)) % 3)

    ranker = cb.CatBoostRanker(**MODEL_PARAMS, subsample=1.0)
    ranker.fit(X, y, group_id=group_ids)

    with mlflow.start_run():
        info = mlflow.catboost.log_model(ranker, name="model")
        reloaded = mlflow.catboost.load_model(info.model_uri)
        assert isinstance(reloaded, cb.CatBoostRanker)
        np.testing.assert_array_almost_equal(ranker.predict(X), reloaded.predict(X))
143  
144  
def test_init_model_throws_for_invalid_model_type():
    """An unrecognized model type name is rejected with a TypeError."""
    bogus_type = "unsupported"
    with pytest.raises(TypeError, match="Invalid model type"):
        mlflow.catboost._init_model(bogus_type)
148  
149  
def test_model_save_load(cb_model, model_path):
    """After ``save_model``, both the native and the pyfunc loader reproduce predictions."""
    model, inference_dataframe = cb_model
    mlflow.catboost.save_model(cb_model=model, path=model_path)

    expected = model.predict(inference_dataframe)
    native_model = mlflow.catboost.load_model(model_uri=model_path)
    native_preds = native_model.predict(inference_dataframe)
    np.testing.assert_array_almost_equal(expected, native_preds)

    pyfunc_model = pyfunc.load_model(model_uri=model_path)
    np.testing.assert_array_almost_equal(native_preds, pyfunc_model.predict(inference_dataframe))
165  
166  
def test_log_model_logs_model_type(cb_model):
    """The catboost flavor config records the estimator class name as ``model_type``."""
    with mlflow.start_run():
        info = mlflow.catboost.log_model(cb_model.model, name="model")

    catboost_conf = Model.load(info.model_uri).flavors["catboost"]
    assert "model_type" in catboost_conf
    assert catboost_conf["model_type"] == type(cb_model.model).__name__
175  
176  
# Serialization formats accepted by CatBoost's ``save_model``:
# https://catboost.ai/docs/concepts/python-reference_catboost_save_model.html
# Formats that can also be loaded back into a model object:
SUPPORTS_DESERIALIZATION = ["cbm", "coreml", "json", "onnx"]
# All formats under test; the extra ones are export-only, and loading them is
# expected to raise ``CatBoostError``.
save_formats = [*SUPPORTS_DESERIALIZATION, "python", "cpp", "pmml"]
181  
182  
@pytest.mark.allow_infer_pip_requirements_fallback
@pytest.mark.parametrize("save_format", save_formats)
def test_log_model_logs_save_format(reg_model, save_format):
    """``format`` is stored in the flavor config; only deserializable formats reload."""
    with mlflow.start_run():
        info = mlflow.catboost.log_model(reg_model.model, name="model", format=save_format)

    catboost_conf = Model.load(info.model_uri).flavors["catboost"]
    assert "save_format" in catboost_conf
    assert catboost_conf["save_format"] == save_format

    if save_format not in SUPPORTS_DESERIALIZATION:
        # Export-only formats cannot be read back by CatBoost.
        with pytest.raises(cb.CatBoostError, match="deserialization not supported or missing"):
            mlflow.catboost.load_model(info.model_uri)
        return

    mlflow.catboost.load_model(info.model_uri)
201  
202  
@pytest.mark.parametrize("signature", [None, get_reg_model_signature()])
@pytest.mark.parametrize("input_example", [None, get_iris()[0].head(3)])
def test_signature_and_examples_are_saved_correctly(
    reg_model, model_path, signature, input_example
):
    """Saved signature and input example reflect what was passed to ``save_model``."""
    mlflow.catboost.save_model(
        reg_model.model, model_path, signature=signature, input_example=input_example
    )
    saved = Model.load(model_path)

    if signature is not None or input_example is not None:
        # An explicit signature, or one inferred from the example, must match.
        assert saved.signature == get_reg_model_signature()
    else:
        assert saved.signature is None

    if input_example is not None:
        pd.testing.assert_frame_equal(_read_example(saved, model_path), input_example)
    else:
        assert saved.saved_input_example_info is None
220  
221  
def test_model_load_from_remote_uri_succeeds(reg_model, model_path, mock_s3_bucket):
    """A model uploaded to the ``mock_s3_bucket`` bucket loads directly from its s3:// URI."""
    model, inference_dataframe = reg_model
    mlflow.catboost.save_model(cb_model=model, path=model_path)

    root_uri = f"s3://{mock_s3_bucket}"
    subdir = "model"
    S3ArtifactRepository(root_uri).log_artifacts(model_path, artifact_path=subdir)

    remote_model = mlflow.catboost.load_model(model_uri=f"{root_uri}/{subdir}")
    np.testing.assert_array_almost_equal(
        model.predict(inference_dataframe), remote_model.predict(inference_dataframe)
    )
236  
237  
def test_log_model(cb_model, tmp_path):
    """Log with an explicit conda env, reload the model, and check the env file was logged."""
    model, inference_dataframe = cb_model
    with mlflow.start_run():
        env_file = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(env_file, additional_pip_deps=["catboost"])

        info = mlflow.catboost.log_model(model, name="model", conda_env=env_file)

        reloaded = mlflow.catboost.load_model(info.model_uri)
        np.testing.assert_array_almost_equal(
            model.predict(inference_dataframe), reloaded.predict(inference_dataframe)
        )

        local_dir = _download_artifact_from_uri(info.model_uri)
        mlmodel = Model.load(os.path.join(local_dir, "MLmodel"))
        assert pyfunc.FLAVOR_NAME in mlmodel.flavors
        pyfunc_conf = mlmodel.flavors[pyfunc.FLAVOR_NAME]
        assert pyfunc.ENV in pyfunc_conf
        conda_file = pyfunc_conf[pyfunc.ENV]["conda"]
        assert os.path.exists(os.path.join(local_dir, conda_file))
259  
260  
def test_log_model_calls_register_model(cb_model, tmp_path):
    """Passing ``registered_model_name`` makes ``log_model`` register the logged model."""
    model_name = "registered_model"
    register_hook = "mlflow.tracking._model_registry.fluent._register_model"
    with mlflow.start_run(), mock.patch(register_hook) as register_mock:
        env_file = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(env_file, additional_pip_deps=["catboost"])
        info = mlflow.catboost.log_model(
            cb_model.model,
            name="model",
            conda_env=env_file,
            registered_model_name=model_name,
        )
        assert_register_model_called_with_local_model_path(
            register_model_mock=register_mock,
            model_uri=info.model_uri,
            registered_model_name=model_name,
        )
281  
282  
def test_log_model_no_registered_model_name(cb_model, tmp_path):
    """Without ``registered_model_name``, ``log_model`` must not register the model.

    Patches the same internal hook that ``test_log_model_calls_register_model``
    observes during registration. The previous target, the public
    ``mlflow.register_model``, is not shown to be on that code path, so asserting
    it was never called could pass vacuously.
    """
    register_hook = "mlflow.tracking._model_registry.fluent._register_model"
    with mlflow.start_run(), mock.patch(register_hook) as register_model_mock:
        conda_env_path = os.path.join(tmp_path, "conda_env.yaml")
        _mlflow_conda_env(conda_env_path, additional_pip_deps=["catboost"])
        mlflow.catboost.log_model(cb_model.model, name="model", conda_env=conda_env_path)
        register_model_mock.assert_not_called()
290  
291  
def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    reg_model, model_path, custom_env
):
    """``save_model`` copies the given conda env file into the model directory."""
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=custom_env)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    stored_env = os.path.join(model_path, flavor_conf[pyfunc.ENV]["conda"])

    assert os.path.exists(stored_env)
    # The env must be a copy inside the model dir, not a reference to the source file.
    assert stored_env != custom_env
    assert read_yaml(stored_env) == read_yaml(custom_env)
301  
302  
def test_model_save_persists_requirements_in_mlflow_model_directory(
    reg_model, model_path, custom_env
):
    """The saved ``requirements.txt`` matches the pip deps of the given conda env."""
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=custom_env)

    requirements_file = os.path.join(model_path, "requirements.txt")
    _compare_conda_env_requirements(custom_env, requirements_file)
310  
311  
def test_model_save_accepts_conda_env_as_dict(reg_model, model_path):
    """A conda env given as a dict is written out verbatim as the model's env file."""
    env_dict = mlflow.catboost.get_default_conda_env()
    env_dict["dependencies"].append("pytest")
    mlflow.catboost.save_model(cb_model=reg_model.model, path=model_path, conda_env=env_dict)

    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    stored_env = os.path.join(model_path, flavor_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(stored_env)
    assert read_yaml(stored_env) == env_dict
321  
322  
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(reg_model, custom_env):
    """``log_model`` uploads a copy of the given conda env alongside the model."""
    with mlflow.start_run():
        info = mlflow.catboost.log_model(reg_model.model, name="model", conda_env=custom_env)

    local_dir = _download_artifact_from_uri(artifact_uri=info.model_uri)
    flavor_conf = _get_flavor_configuration(model_path=local_dir, flavor_name=pyfunc.FLAVOR_NAME)
    stored_env = os.path.join(local_dir, flavor_conf[pyfunc.ENV]["conda"])

    assert os.path.exists(stored_env)
    assert stored_env != custom_env
    assert read_yaml(stored_env) == read_yaml(custom_env)
336  
337  
def test_model_log_persists_requirements_in_mlflow_model_directory(reg_model, custom_env):
    """The logged ``requirements.txt`` matches the pip deps of the given conda env."""
    with mlflow.start_run():
        info = mlflow.catboost.log_model(reg_model.model, name="model", conda_env=custom_env)

    local_dir = _download_artifact_from_uri(artifact_uri=info.model_uri)
    _compare_conda_env_requirements(custom_env, os.path.join(local_dir, "requirements.txt"))
345  
346  
def test_log_model_with_pip_requirements(reg_model, tmp_path):
    """``pip_requirements`` replaces the defaults: file path, ``-r`` include, ``-c`` constraints."""
    mlflow_req = _mlflow_major_version_string()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # (pip_requirements argument, expected requirements, extra positional args)
    cases = [
        (str(req_file), [mlflow_req, "a"], ()),
        ([f"-r {req_file}", "b"], [mlflow_req, "a", "b"], ()),
        ([f"-c {req_file}", "b"], [mlflow_req, "b", "-c constraints.txt"], (["a"],)),
    ]
    for pip_reqs, expected_reqs, extra_args in cases:
        with mlflow.start_run():
            info = mlflow.catboost.log_model(
                reg_model.model, name="model", pip_requirements=pip_reqs
            )
            _assert_pip_requirements(info.model_uri, expected_reqs, *extra_args, strict=True)
378  
379  
def test_log_model_with_extra_pip_requirements(reg_model, tmp_path):
    """``extra_pip_requirements`` is appended to the flavor's default requirements."""
    mlflow_req = _mlflow_major_version_string()
    defaults = mlflow.catboost.get_default_pip_requirements()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # (extra_pip_requirements argument, expected requirements, extra positional args)
    cases = [
        (str(req_file), [mlflow_req, *defaults, "a"], ()),
        ([f"-r {req_file}", "b"], [mlflow_req, *defaults, "a", "b"], ()),
        (
            [f"-c {req_file}", "b"],
            [mlflow_req, *defaults, "b", "-c constraints.txt"],
            (["a"],),
        ),
    ]
    for extra_reqs, expected_reqs, extra_args in cases:
        with mlflow.start_run():
            info = mlflow.catboost.log_model(
                reg_model.model, name="model", extra_pip_requirements=extra_reqs
            )
            _assert_pip_requirements(info.model_uri, expected_reqs, *extra_args)
414  
415  
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    reg_model, model_path
):
    """With no env given, ``save_model`` records the flavor's default pip requirements."""
    mlflow.catboost.save_model(reg_model.model, model_path)
    expected_reqs = mlflow.catboost.get_default_pip_requirements()
    _assert_pip_requirements(model_path, expected_reqs)
421  
422  
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    reg_model,
):
    """With no env given, ``log_model`` records the flavor's default pip requirements."""
    with mlflow.start_run():
        info = mlflow.catboost.log_model(reg_model.model, name="model")

    expected_reqs = mlflow.catboost.get_default_pip_requirements()
    _assert_pip_requirements(info.model_uri, expected_reqs)
430  
431  
def test_pyfunc_serve_and_score(reg_model):
    """Serve the logged regressor over the pyfunc scoring server; predictions must match."""
    model, inference_dataframe = reg_model
    with mlflow.start_run():
        info = mlflow.catboost.log_model(
            model, name="model", input_example=inference_dataframe
        )

    payload = load_serving_example(info.model_uri)
    response = pyfunc_serve_and_score_model(
        info.model_uri,
        data=payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    predictions = json.loads(response.content.decode("utf-8"))["predictions"]
    scores = pd.DataFrame(data=predictions).values.squeeze()
    np.testing.assert_array_almost_equal(scores, model.predict(inference_dataframe))
451  
452  
def test_pyfunc_serve_and_score_sklearn(reg_model):
    """Serve a CatBoost regressor wrapped in an sklearn ``Pipeline`` via the sklearn flavor.

    Predictions returned by the scoring server must match the pipeline's local
    predictions on the same three-row example.
    """
    pipeline = Pipeline([("model", reg_model.model)])
    # Build the example once; it is both the logged input example and the local
    # prediction input.
    example = reg_model.inference_dataframe.head(3)

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(pipeline, name="model", input_example=example)

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        model_info.model_uri,
        inference_payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.DataFrame(
        data=json.loads(resp.content.decode("utf-8"))["predictions"]
    ).values.squeeze()
    np.testing.assert_array_almost_equal(scores, pipeline.predict(example))
473  
474  
def test_log_model_with_code_paths(cb_model):
    """``code_paths`` files are logged with the model, and loading invokes the code-path hook."""
    hook_target = "mlflow.catboost._add_code_from_conf_to_system_path"
    with mlflow.start_run(), mock.patch(hook_target) as add_code_mock:
        info = mlflow.catboost.log_model(cb_model.model, name="model", code_paths=[__file__])
        _compare_logged_code_paths(__file__, info.model_uri, mlflow.catboost.FLAVOR_NAME)
        mlflow.catboost.load_model(model_uri=info.model_uri)
        add_code_mock.assert_called()
487  
488  
def test_virtualenv_subfield_points_to_correct_path(cb_model, model_path):
    """The pyfunc ``virtualenv`` entry names an existing file inside the model directory."""
    mlflow.catboost.save_model(cb_model.model, path=model_path)
    flavor_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    env_file = Path(model_path) / flavor_conf[pyfunc.ENV]["virtualenv"]
    assert env_file.exists()
    assert env_file.is_file()
495  
496  
def test_model_save_load_with_metadata(cb_model, model_path):
    """Metadata passed to ``save_model`` is readable from the reloaded pyfunc model."""
    user_metadata = {"metadata_key": "metadata_value"}
    mlflow.catboost.save_model(cb_model.model, path=model_path, metadata=user_metadata)

    reloaded = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
504  
505  
def test_model_log_with_metadata(cb_model):
    """Metadata passed to ``log_model`` is readable from the reloaded pyfunc model."""
    user_metadata = {"metadata_key": "metadata_value"}
    with mlflow.start_run():
        info = mlflow.catboost.log_model(cb_model.model, name="model", metadata=user_metadata)

    reloaded = mlflow.pyfunc.load_model(model_uri=info.model_uri)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
514  
515  
def test_model_log_with_signature_inference(cb_model):
    """Logging with an input example infers the model signature from it."""
    example = cb_model.inference_dataframe.head(3)
    with mlflow.start_run():
        info = mlflow.catboost.log_model(cb_model.model, name="model", input_example=example)

    signature = Model.load(info.model_uri).signature
    expected_inputs = Schema([
        ColSpec(name=col, type=DataType.double)
        for col in ("sepal length (cm)", "sepal width (cm)")
    ])
    assert signature.inputs == expected_inputs

    acceptable_outputs = [
        # 1-D numpy predictions are cast to a single double ColSpec ...
        Schema([ColSpec(type=DataType.double)]),
        # ... while higher-dimensional predictions remain a TensorSpec.
        Schema([TensorSpec(np.dtype("int64"), (-1, 1))]),
    ]
    assert signature.outputs in acceptable_outputs