/ tests / xgboost / test_xgboost_model_export.py
test_xgboost_model_export.py
  1  import json
  2  import os
  3  from pathlib import Path
  4  from typing import Any, NamedTuple
  5  from unittest import mock
  6  
  7  import numpy as np
  8  import pandas as pd
  9  import pytest
 10  import xgboost as xgb
 11  import yaml
 12  from sklearn import datasets
 13  from sklearn.pipeline import Pipeline
 14  
 15  import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
 16  import mlflow.utils
 17  import mlflow.xgboost
 18  from mlflow import pyfunc
 19  from mlflow.models import Model, ModelSignature, infer_signature
 20  from mlflow.models.utils import _read_example, load_serving_example
 21  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 22  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 23  from mlflow.types import DataType
 24  from mlflow.types.schema import ColSpec, Schema, TensorSpec
 25  from mlflow.utils.environment import _mlflow_conda_env
 26  from mlflow.utils.file_utils import TempDir
 27  from mlflow.utils.model_utils import _get_flavor_configuration
 28  from mlflow.utils.proto_json_utils import dataframe_from_parsed_json
 29  from mlflow.xgboost import _exclude_unrecognized_kwargs
 30  
 31  from tests.helper_functions import (
 32      _assert_pip_requirements,
 33      _compare_conda_env_requirements,
 34      _compare_logged_code_paths,
 35      _is_available_on_pypi,
 36      _mlflow_major_version_string,
 37      assert_register_model_called_with_local_model_path,
 38      pyfunc_serve_and_score_model,
 39  )
 40  
 41  EXTRA_PYFUNC_SERVING_TEST_ARGS = (
 42      [] if _is_available_on_pypi("xgboost") else ["--env-manager", "local"]
 43  )
 44  
 45  
 46  class ModelWithData(NamedTuple):
 47      model: Any
 48      inference_dataframe: pd.DataFrame
 49      inference_dmatrix: xgb.DMatrix
 50  
 51  
 52  @pytest.fixture(scope="module")
 53  def xgb_model():
 54      iris = datasets.load_iris()
 55      X = pd.DataFrame(
 56          iris.data[:, :2],
 57          columns=iris.feature_names[:2],  # we only take the first two features.
 58      )
 59      y = iris.target
 60      dtrain = xgb.DMatrix(X, y)
 61      model = xgb.train({"objective": "multi:softprob", "num_class": 3}, dtrain)
 62      return ModelWithData(model=model, inference_dataframe=X, inference_dmatrix=dtrain)
 63  
 64  
 65  @pytest.fixture(scope="module")
 66  def xgb_model_signature():
 67      return ModelSignature(
 68          inputs=Schema([
 69              ColSpec(name="sepal length (cm)", type=DataType.double),
 70              ColSpec(name="sepal width (cm)", type=DataType.double),
 71          ]),
 72          outputs=Schema([TensorSpec(np.dtype("float32"), (-1, 3))]),
 73      )
 74  
 75  
 76  @pytest.fixture(scope="module")
 77  def xgb_sklearn_model():
 78      wine = datasets.load_wine()
 79      X = pd.DataFrame(wine.data, columns=wine.feature_names)
 80      y = pd.Series(wine.target)
 81      regressor = xgb.XGBRegressor(n_estimators=10)
 82      regressor.fit(X, y)
 83      return ModelWithData(model=regressor, inference_dataframe=X, inference_dmatrix=None)
 84  
 85  
 86  @pytest.fixture
 87  def model_path(tmp_path):
 88      return os.path.join(tmp_path, "model")
 89  
 90  
 91  @pytest.fixture
 92  def xgb_custom_env(tmp_path):
 93      conda_env = os.path.join(tmp_path, "conda_env.yml")
 94      _mlflow_conda_env(conda_env, additional_pip_deps=["xgboost", "pytest"])
 95      return conda_env
 96  
 97  
 98  def test_model_save_load(xgb_model, model_path):
 99      model = xgb_model.model
100  
101      mlflow.xgboost.save_model(xgb_model=model, path=model_path)
102      reloaded_model = mlflow.xgboost.load_model(model_uri=model_path)
103      reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)
104  
105      np.testing.assert_array_almost_equal(
106          model.predict(xgb_model.inference_dmatrix),
107          reloaded_model.predict(xgb_model.inference_dmatrix),
108      )
109  
110      np.testing.assert_array_almost_equal(
111          reloaded_model.predict(xgb_model.inference_dmatrix),
112          reloaded_pyfunc.predict(xgb_model.inference_dataframe),
113      )
114  
115  
def test_sklearn_model_save_load(xgb_sklearn_model, model_path):
    """A saved sklearn-API model reloads natively and via pyfunc with matching predictions."""
    estimator = xgb_sklearn_model.model
    frame = xgb_sklearn_model.inference_dataframe

    mlflow.xgboost.save_model(xgb_model=estimator, path=model_path)
    native_reloaded = mlflow.xgboost.load_model(model_uri=model_path)
    pyfunc_reloaded = pyfunc.load_model(model_uri=model_path)

    # Original estimator vs. the natively reloaded one.
    original_preds = estimator.predict(frame)
    native_preds = native_reloaded.predict(frame)
    np.testing.assert_array_almost_equal(original_preds, native_preds)

    # Natively reloaded estimator vs. the pyfunc wrapper.
    pyfunc_preds = pyfunc_reloaded.predict(frame)
    np.testing.assert_array_almost_equal(native_preds, pyfunc_preds)
131  
132  
def test_signature_and_examples_are_saved_correctly(xgb_model, xgb_model_signature):
    """Check the stored signature and input example for each signature/example combination."""
    booster = xgb_model.model
    signature_options = (None, xgb_model_signature)
    example_options = (None, xgb_model.inference_dataframe.head(3))
    combinations = [(sig, ex) for sig in signature_options for ex in example_options]

    for signature, example in combinations:
        with TempDir() as tmp:
            path = tmp.path("model")
            mlflow.xgboost.save_model(
                xgb_model=booster, path=path, signature=signature, input_example=example
            )
            mlflow_model = Model.load(path)

            # The signature is expected to be missing only when neither a signature
            # nor an input example was supplied.
            if signature is not None or example is not None:
                assert mlflow_model.signature == xgb_model_signature
            else:
                assert mlflow_model.signature is None

            if example is not None:
                assert all((_read_example(mlflow_model, path) == example).all())
            else:
                assert mlflow_model.saved_input_example_info is None
151  
152  
def test_model_load_from_remote_uri_succeeds(xgb_model, model_path, mock_s3_bucket):
    """A model uploaded to the mocked S3 bucket loads from its ``s3://`` URI."""
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path)

    artifact_root = f"s3://{mock_s3_bucket}"
    artifact_path = "model"
    S3ArtifactRepository(artifact_root).log_artifacts(model_path, artifact_path=artifact_path)

    reloaded = mlflow.xgboost.load_model(model_uri=f"{artifact_root}/{artifact_path}")
    dmatrix = xgb_model.inference_dmatrix
    np.testing.assert_array_almost_equal(
        xgb_model.model.predict(dmatrix),
        reloaded.predict(dmatrix),
    )
167  
168  
def test_model_log(xgb_model, model_path):
    """Log the booster with and without an explicitly started run and verify the result.

    The ``model_path`` fixture is requested but its value is not read by this test.
    """
    booster = xgb_model.model
    dmatrix = xgb_model.inference_dmatrix
    with TempDir(chdr=True, remove_on_exit=True) as tmp:
        for should_start_run in (False, True):
            try:
                if should_start_run:
                    mlflow.start_run()

                conda_env = os.path.join(tmp.path(), "conda_env.yaml")
                _mlflow_conda_env(conda_env, additional_pip_deps=["xgboost"])

                model_info = mlflow.xgboost.log_model(booster, name="model", conda_env=conda_env)
                reloaded = mlflow.xgboost.load_model(model_uri=model_info.model_uri)
                np.testing.assert_array_almost_equal(
                    booster.predict(dmatrix),
                    reloaded.predict(dmatrix),
                )

                # The logged model must expose a pyfunc flavor whose conda env file exists.
                local_dir = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
                flavors = Model.load(os.path.join(local_dir, "MLmodel")).flavors
                assert pyfunc.FLAVOR_NAME in flavors
                pyfunc_conf = flavors[pyfunc.FLAVOR_NAME]
                assert pyfunc.ENV in pyfunc_conf
                env_path = pyfunc_conf[pyfunc.ENV]["conda"]
                assert os.path.exists(os.path.join(local_dir, env_path))
            finally:
                # Always close the run so the next iteration starts clean.
                mlflow.end_run()
199  
200  
def test_log_model_calls_register_model(xgb_model):
    """Passing ``registered_model_name`` to ``log_model`` triggers model registration."""
    registered_name = "AdsModel1"
    patcher = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), patcher, TempDir(chdr=True, remove_on_exit=True) as tmp:
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["xgboost"])
        model_info = mlflow.xgboost.log_model(
            xgb_model.model,
            name="model",
            conda_env=conda_env,
            registered_model_name=registered_name,
        )
        # While the patch is active, the module attribute is the mock itself.
        register_mock = mlflow.tracking._model_registry.fluent._register_model
        assert_register_model_called_with_local_model_path(
            register_model_mock=register_mock,
            model_uri=model_info.model_uri,
            registered_model_name=registered_name,
        )
218  
219  
def test_log_model_no_registered_model_name(xgb_model):
    """Without ``registered_model_name``, ``log_model`` must not register anything."""
    patcher = mock.patch("mlflow.tracking._model_registry.fluent._register_model")
    with mlflow.start_run(), patcher, TempDir(chdr=True, remove_on_exit=True) as tmp:
        conda_env = os.path.join(tmp.path(), "conda_env.yaml")
        _mlflow_conda_env(conda_env, additional_pip_deps=["xgboost"])
        mlflow.xgboost.log_model(xgb_model.model, name="model", conda_env=conda_env)
        # While the patch is active, the module attribute is the mock itself.
        register_mock = mlflow.tracking._model_registry.fluent._register_model
        register_mock.assert_not_called()
228  
229  
def test_model_save_persists_specified_conda_env_in_mlflow_model_directory(
    xgb_model, model_path, xgb_custom_env
):
    """The custom conda env is copied into the saved model directory with equal content."""
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path, conda_env=xgb_custom_env)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_file = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_env_file)
    # The saved env must be a separate file, not the one that was passed in.
    assert saved_env_file != xgb_custom_env

    parsed_envs = []
    for env_file in (xgb_custom_env, saved_env_file):
        with open(env_file) as f:
            parsed_envs.append(yaml.safe_load(f))
    expected_env, saved_env = parsed_envs
    assert saved_env == expected_env
245  
246  
def test_model_save_persists_requirements_in_mlflow_model_directory(
    xgb_model, model_path, xgb_custom_env
):
    """The saved ``requirements.txt`` is compared against the custom conda env."""
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path, conda_env=xgb_custom_env)
    _compare_conda_env_requirements(xgb_custom_env, os.path.join(model_path, "requirements.txt"))
254  
255  
def test_save_model_with_pip_requirements(xgb_model, tmp_path):
    """``pip_requirements`` accepts a file path, a requirement list and a constraints entry."""
    mlflow_req = _mlflow_major_version_string()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # Each case: (output sub-directory, pip_requirements value, positional expectations
    # forwarded to ``_assert_pip_requirements`` after the model location).
    cases = [
        # Path to a requirements file
        ("1", str(req_file), ([mlflow_req, "a"],)),
        # List of requirements
        ("2", [f"-r {req_file}", "b"], ([mlflow_req, "a", "b"],)),
        # Constraints file
        ("3", [f"-c {req_file}", "b"], ([mlflow_req, "b", "-c constraints.txt"], ["a"])),
    ]
    for subdir, pip_requirements, expected in cases:
        target_dir = tmp_path.joinpath(subdir)
        mlflow.xgboost.save_model(xgb_model.model, target_dir, pip_requirements=pip_requirements)
        _assert_pip_requirements(target_dir, *expected, strict=True)
276  
277  
def test_save_model_with_extra_pip_requirements(xgb_model, tmp_path):
    """``extra_pip_requirements`` entries are expected alongside the flavor's default ones."""
    mlflow_req = _mlflow_major_version_string()
    default_reqs = mlflow.xgboost.get_default_pip_requirements()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # Each case: (output sub-directory, extra_pip_requirements value, positional
    # expectations forwarded to ``_assert_pip_requirements`` after the model location).
    cases = [
        # Path to a requirements file
        ("1", str(req_file), ([mlflow_req, *default_reqs, "a"],)),
        # List of requirements
        ("2", [f"-r {req_file}", "b"], ([mlflow_req, *default_reqs, "a", "b"],)),
        # Constraints file
        (
            "3",
            [f"-c {req_file}", "b"],
            ([mlflow_req, *default_reqs, "b", "-c constraints.txt"], ["a"]),
        ),
    ]
    for subdir, extra_requirements, expected in cases:
        target_dir = tmp_path.joinpath(subdir)
        mlflow.xgboost.save_model(
            xgb_model.model, target_dir, extra_pip_requirements=extra_requirements
        )
        _assert_pip_requirements(target_dir, *expected)
304  
305  
def test_log_model_with_pip_requirements(xgb_model, tmp_path):
    """``log_model`` handles ``pip_requirements`` given as a path, a list, or with constraints."""
    mlflow_req = _mlflow_major_version_string()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # Each case: (pip_requirements value, positional expectations forwarded to
    # ``_assert_pip_requirements`` after the model URI).
    cases = [
        # Path to a requirements file
        (str(req_file), ([mlflow_req, "a"],)),
        # List of requirements
        ([f"-r {req_file}", "b"], ([mlflow_req, "a", "b"],)),
        # Constraints file
        ([f"-c {req_file}", "b"], ([mlflow_req, "b", "-c constraints.txt"], ["a"])),
    ]
    for pip_requirements, expected in cases:
        # A fresh run per case; the check happens while that run is still active.
        with mlflow.start_run():
            model_info = mlflow.xgboost.log_model(
                xgb_model.model, name="model", pip_requirements=pip_requirements
            )
            _assert_pip_requirements(model_info.model_uri, *expected, strict=True)
337  
338  
def test_log_model_with_extra_pip_requirements(xgb_model, tmp_path):
    """``log_model`` expects ``extra_pip_requirements`` alongside the flavor's default ones."""
    mlflow_req = _mlflow_major_version_string()
    default_reqs = mlflow.xgboost.get_default_pip_requirements()
    req_file = tmp_path.joinpath("requirements.txt")
    req_file.write_text("a")

    # Each case: (extra_pip_requirements value, positional expectations forwarded to
    # ``_assert_pip_requirements`` after the model URI).
    cases = [
        # Path to a requirements file
        (str(req_file), ([mlflow_req, *default_reqs, "a"],)),
        # List of requirements
        ([f"-r {req_file}", "b"], ([mlflow_req, *default_reqs, "a", "b"],)),
        # Constraints file
        (
            [f"-c {req_file}", "b"],
            ([mlflow_req, *default_reqs, "b", "-c constraints.txt"], ["a"]),
        ),
    ]
    for extra_requirements, expected in cases:
        # A fresh run per case; the check happens while that run is still active.
        with mlflow.start_run():
            model_info = mlflow.xgboost.log_model(
                xgb_model.model, name="model", extra_pip_requirements=extra_requirements
            )
            _assert_pip_requirements(model_info.model_uri, *expected)
373  
374  
def test_model_save_accepts_conda_env_as_dict(xgb_model, model_path):
    """``save_model`` accepts a conda env given as a dict and persists it unchanged."""
    env_dict = dict(mlflow.xgboost.get_default_conda_env())
    env_dict["dependencies"].append("pytest")
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path, conda_env=env_dict)

    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_file = os.path.join(model_path, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_env_file)

    with open(saved_env_file) as f:
        saved_env = yaml.safe_load(f)
    assert saved_env == env_dict
387  
388  
def test_model_log_persists_specified_conda_env_in_mlflow_model_directory(
    xgb_model, xgb_custom_env
):
    """The custom conda env is stored with the logged model and has equal content."""
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            xgb_model.model, name="model", conda_env=xgb_custom_env
        )

    local_dir = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    pyfunc_conf = _get_flavor_configuration(model_path=local_dir, flavor_name=pyfunc.FLAVOR_NAME)
    saved_env_file = os.path.join(local_dir, pyfunc_conf[pyfunc.ENV]["conda"])
    assert os.path.exists(saved_env_file)
    # The saved env must be a separate file, not the one that was passed in.
    assert saved_env_file != xgb_custom_env

    parsed_envs = []
    for env_file in (xgb_custom_env, saved_env_file):
        with open(env_file) as f:
            parsed_envs.append(yaml.safe_load(f))
    expected_env, saved_env = parsed_envs
    assert saved_env == expected_env
408  
409  
def test_model_log_persists_requirements_in_mlflow_model_directory(xgb_model, xgb_custom_env):
    """The logged model's ``requirements.txt`` is compared against the custom conda env."""
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            xgb_model.model, name="model", conda_env=xgb_custom_env
        )

    local_dir = _download_artifact_from_uri(artifact_uri=model_info.model_uri)
    _compare_conda_env_requirements(xgb_custom_env, os.path.join(local_dir, "requirements.txt"))
420  
421  
def test_model_save_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    xgb_model, model_path
):
    """Saving without an env is checked against the flavor's default pip requirements."""
    mlflow.xgboost.save_model(xgb_model=xgb_model.model, path=model_path)
    default_reqs = mlflow.xgboost.get_default_pip_requirements()
    _assert_pip_requirements(model_path, default_reqs)
427  
428  
def test_model_log_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    xgb_model,
):
    """Logging without an env is checked against the flavor's default pip requirements."""
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(xgb_model.model, name="model")

    default_reqs = mlflow.xgboost.get_default_pip_requirements()
    _assert_pip_requirements(model_info.model_uri, default_reqs)
437  
438  
def test_pyfunc_serve_and_score(xgb_model):
    """Score the logged booster through the serving helper and compare with direct predictions."""
    booster = xgb_model.model
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            booster, name="model", input_example=xgb_model.inference_dataframe
        )

    # The serving payload is derived from the input example stored with the model.
    payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        data=payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    predictions = json.loads(response.content.decode("utf-8"))["predictions"]
    scores = pd.DataFrame(data=predictions).values.squeeze()
    np.testing.assert_array_almost_equal(scores, booster.predict(xgb_model.inference_dmatrix))
458  
459  
def get_sklearn_models():
    """Return an unfitted ``XGBClassifier`` and a ``Pipeline`` wrapping that same instance."""
    classifier = xgb.XGBClassifier(objective="multi:softmax", n_estimators=10)
    return [classifier, Pipeline([("model", classifier)])]
464  
465  
@pytest.mark.parametrize("model", get_sklearn_models())
def test_pyfunc_serve_and_score_sklearn(model):
    """Serve an XGBoost model logged via the sklearn flavor and compare scores exactly."""
    features, labels = datasets.load_iris(return_X_y=True, as_frame=True)
    model.fit(features, labels)
    example = features.head(3)

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(model, name="model", input_example=example)

    # The serving payload is derived from the input example stored with the model.
    payload = load_serving_example(model_info.model_uri)
    response = pyfunc_serve_and_score_model(
        model_info.model_uri,
        payload,
        pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    predictions = json.loads(response.content.decode("utf-8"))["predictions"]
    scores = pd.DataFrame(data=predictions).values.squeeze()
    np.testing.assert_array_equal(scores, model.predict(features.head(3)))
485  
486  
def test_load_pyfunc_succeeds_for_older_models_with_pyfunc_data_field(xgb_model, model_path):
    """
    Verify that xgboost models saved by older versions of MLflow can still be loaded with
    ``mlflow.pyfunc.load_model``. Such models carry a pyfunc ``data`` field that points
    directly at an XGBoost model file, whereas newer models additionally record
    ``model_class`` in the XGBoost flavor.
    """
    booster = xgb_model.model
    dmatrix = xgb_model.inference_dmatrix
    # The "xgb" format is requested explicitly: this test is about backward
    # compatibility with old models.
    mlflow.xgboost.save_model(xgb_model=booster, path=model_path, model_format="xgb")

    mlmodel_file = os.path.join(model_path, "MLmodel")
    model_conf = Model.load(mlmodel_file)
    pyfunc_conf = model_conf.flavors.get(pyfunc.FLAVOR_NAME)
    xgboost_conf = model_conf.flavors.get(mlflow.xgboost.FLAVOR_NAME)

    # Freshly saved model: the xgboost flavor has both keys, the pyfunc flavor only DATA.
    assert xgboost_conf is not None
    for key in ("model_class", "data"):
        assert key in xgboost_conf
    assert pyfunc_conf is not None
    assert "model_class" not in pyfunc_conf
    assert pyfunc.DATA in pyfunc_conf

    # Rewrite the flavor the way an old MLmodel file looked (no ``model_class``).
    legacy_flavor = {"xgb_version": xgb.__version__, "data": "model.xgb"}
    model_conf.flavors["xgboost"] = legacy_flavor
    model_conf.save(mlmodel_file)
    xgboost_conf = Model.load(mlmodel_file).flavors.get(mlflow.xgboost.FLAVOR_NAME)
    assert "data" in xgboost_conf
    assert xgboost_conf["data"] == "model.xgb"

    # Both loaders must still yield a Booster from the legacy configuration.
    reloaded_pyfunc = pyfunc.load_model(model_uri=model_path)
    assert isinstance(reloaded_pyfunc._model_impl.xgb_model, xgb.Booster)
    reloaded_xgb = mlflow.xgboost.load_model(model_uri=model_path)
    assert isinstance(reloaded_xgb, xgb.Booster)

    original_preds = booster.predict(dmatrix)
    pyfunc_preds = reloaded_pyfunc.predict(xgb_model.inference_dataframe)
    np.testing.assert_array_almost_equal(original_preds, pyfunc_preds)

    np.testing.assert_array_almost_equal(reloaded_xgb.predict(dmatrix), pyfunc_preds)
531  
532  
def test_log_model_with_code_paths(xgb_model):
    """Logging with ``code_paths`` stores the code, and loading calls the sys-path helper."""
    patch_target = "mlflow.xgboost._add_code_from_conf_to_system_path"
    with mlflow.start_run(), mock.patch(patch_target) as add_mock:
        model_info = mlflow.xgboost.log_model(xgb_model.model, name="model", code_paths=[__file__])
        _compare_logged_code_paths(__file__, model_info.model_uri, mlflow.xgboost.FLAVOR_NAME)
        mlflow.xgboost.load_model(model_uri=model_info.model_uri)
        add_mock.assert_called()
545  
546  
def test_virtualenv_subfield_points_to_correct_path(xgb_model, model_path):
    """The pyfunc env's ``virtualenv`` entry names an existing regular file in the model dir."""
    mlflow.xgboost.save_model(xgb_model.model, path=model_path)
    pyfunc_conf = _get_flavor_configuration(model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME)
    relative_env_file = pyfunc_conf[pyfunc.ENV]["virtualenv"]
    env_file = Path(model_path, relative_env_file)
    assert env_file.exists()
    assert env_file.is_file()
553  
554  
@pytest.mark.parametrize("model_format", ["xgb", "json", "ubj"])
def test_log_model_with_model_format(xgb_model, model_format):
    """A booster logged in each parametrized ``model_format`` reloads with equal predictions."""
    booster = xgb_model.model
    dmatrix = xgb_model.inference_dmatrix
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(booster, name="model", model_format=model_format)
        reloaded = mlflow.xgboost.load_model(model_info.model_uri)
        np.testing.assert_array_almost_equal(
            booster.predict(dmatrix),
            reloaded.predict(dmatrix),
        )
566  
567  
def test_model_save_load_with_metadata(xgb_model, model_path):
    """Metadata passed to ``save_model`` is readable from the pyfunc-loaded model."""
    custom_metadata = {"metadata_key": "metadata_value"}
    mlflow.xgboost.save_model(xgb_model.model, path=model_path, metadata=custom_metadata)

    reloaded = mlflow.pyfunc.load_model(model_uri=model_path)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
575  
576  
def test_model_log_with_metadata(xgb_model):
    """Metadata passed to ``log_model`` is readable from the pyfunc-loaded model."""
    custom_metadata = {"metadata_key": "metadata_value"}
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            xgb_model.model, name="model", metadata=custom_metadata
        )

    reloaded = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    assert reloaded.metadata.metadata["metadata_key"] == "metadata_value"
587  
588  
def test_model_log_with_signature_inference(xgb_model, xgb_model_signature):
    """Logging with only an input example yields the expected signature on the stored model."""
    # A single-row DataFrame serves as the input example.
    single_row = xgb_model.inference_dataframe.iloc[[0]]

    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            xgb_model.model, name="model", input_example=single_row
        )

    stored_model = Model.load(model_info.model_uri)
    assert stored_model.signature == xgb_model_signature
601  
602  
def test_model_without_signature_predict(xgb_model):
    """A model logged without signature or example predicts on JSON-parsed data without error."""
    single_row = xgb_model.inference_dataframe.iloc[[0]]

    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(xgb_model.model, name="model")

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)
    # Convert the row to a split-oriented dict and back through the JSON parsing helper.
    split_payload = pd.DataFrame(single_row).to_dict(orient="split")
    parsed_frame = dataframe_from_parsed_json(split_payload, pandas_orient="split")
    # No assertion on the output: the call only has to succeed.
    loaded_model.predict(parsed_frame)
615  
616  
def test_get_raw_model(xgb_model):
    """``get_raw_model`` returns a model of the original type with matching predictions."""
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(
            xgb_model.model, name="model", input_example=xgb_model.inference_dataframe.head(3)
        )
    pyfunc_model = pyfunc.load_model(model_info.model_uri)
    raw_model = pyfunc_model.get_raw_model()
    # Exact-type check: compare the type objects by identity rather than with ``==``.
    assert type(raw_model) is type(xgb_model.model)
    np.testing.assert_array_almost_equal(
        raw_model.predict(xgb_model.inference_dmatrix),
        xgb_model.model.predict(xgb_model.inference_dmatrix),
    )
629  
630  
def test_xgbooster_predict_exclude_invalid_params(xgb_model):
    """Booster pyfunc predict warns once about an unknown param and still honours valid ones."""
    signature = infer_signature(
        xgb_model.inference_dataframe.head(3),
        params={"invalid_param": 1, "approx_contribs": True},
    )
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(xgb_model.model, name="model", signature=signature)
    pyfunc_model = pyfunc.load_model(model_info.model_uri)

    expected_warning = (
        "Params {'invalid_param'} are not accepted by the xgboost model, "
        "ignoring them during predict."
    )
    with mock.patch("mlflow.xgboost._logger.warning") as mock_warning:
        pyfunc_preds = pyfunc_model.predict(
            xgb_model.inference_dataframe,
            params={"invalid_param": 2, "approx_contribs": True},
        )
        direct_preds = xgb_model.model.predict(xgb_model.inference_dmatrix, approx_contribs=True)
        np.testing.assert_array_almost_equal(pyfunc_preds, direct_preds)
        mock_warning.assert_called_once_with(expected_warning)
649  
650  
def test_xgbmodel_predict_exclude_invalid_params(xgb_sklearn_model):
    """Sklearn-API pyfunc predict warns once about an unknown param and honours valid ones."""
    estimator = xgb_sklearn_model.model
    frame = xgb_sklearn_model.inference_dataframe

    signature = infer_signature(
        frame.head(3),
        params={"invalid_param": 1, "output_margin": True},
    )
    with mlflow.start_run():
        model_info = mlflow.xgboost.log_model(estimator, name="model", signature=signature)
    pyfunc_model = pyfunc.load_model(model_info.model_uri)

    expected_warning = (
        "Params {'invalid_param'} are not accepted by the xgboost model, "
        "ignoring them during predict."
    )
    with mock.patch("mlflow.xgboost._logger.warning") as mock_warning:
        pyfunc_preds = pyfunc_model.predict(
            frame,
            params={"invalid_param": 2, "output_margin": True},
        )
        direct_preds = estimator.predict(frame, output_margin=True)
        np.testing.assert_array_almost_equal(pyfunc_preds, direct_preds)
        mock_warning.assert_called_once_with(expected_warning)
675  
676  
def test_exclude_unrecognized_kwargs():
    """Check ``_exclude_unrecognized_kwargs`` against three differently shaped signatures."""

    def takes_var_kwargs(*args, **kwargs):
        return [1, 2, 3]

    def takes_data_and_var_kwargs(data, **kwargs):
        return [2, 3, 4]

    def takes_only_x_and_y(x, y):
        return x + y

    params = {"data": 1, "x": 1, "y": 2, "z": 3}
    # (callable, expected result); functions accepting ``**kwargs`` keep every entry.
    expectations = [
        (takes_var_kwargs, params),
        (takes_data_and_var_kwargs, params),
        (takes_only_x_and_y, {"x": 1, "y": 2}),
    ]
    for func, expected in expectations:
        assert _exclude_unrecognized_kwargs(func, params) == expected