/ tests / pyfunc / test_pyfunc_model_config.py
test_pyfunc_model_config.py
  1  import os
  2  
  3  import pytest
  4  import yaml
  5  
  6  import mlflow
  7  from mlflow.models import Model
  8  
  9  
 10  @pytest.fixture
 11  def model_path(tmp_path):
 12      return os.path.join(tmp_path, "model")
 13  
 14  
 15  @pytest.fixture
 16  def model_config():
 17      return {
 18          "use_gpu": True,
 19          "temperature": 0.9,
 20          "timeout": 300,
 21      }
 22  
 23  
 24  def _load_pyfunc(path):
 25      return TestModel()
 26  
 27  
 28  class TestModel(mlflow.pyfunc.PythonModel):
 29      def predict(self, context, model_input, params=None):
 30          return model_input
 31  
 32  
 33  class InferenceContextModel(mlflow.pyfunc.PythonModel):
 34      def predict(self, context, model_input, params=None):
 35          # This mock class returns the internal inference configuration keys and values available
 36          return context.model_config.items()
 37  
 38  
 39  def test_save_with_model_config(model_path, model_config):
 40      model = InferenceContextModel()
 41      mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)
 42  
 43      loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
 44  
 45      assert loaded_model.model_config
 46      assert set(model_config.keys()) == set(loaded_model.model_config)
 47      assert all(loaded_model.model_config[k] == v for k, v in model_config.items())
 48      assert all(loaded_model.model_config[k] == v for k, v in loaded_model.predict([[0]]))
 49  
 50  
 51  @pytest.mark.parametrize(
 52      "model_config_path",
 53      [
 54          os.path.abspath("tests/pyfunc/sample_code/config.yml"),
 55          "tests/pyfunc/../pyfunc/sample_code/config.yml",
 56      ],
 57  )
 58  def test_save_with_model_config_path(model_path, model_config, model_config_path):
 59      model = InferenceContextModel()
 60      mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config_path)
 61  
 62      loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)
 63  
 64      assert loaded_model.model_config
 65      assert set(model_config.keys()) == set(loaded_model.model_config)
 66      assert all(loaded_model.model_config[k] == v for k, v in model_config.items())
 67      assert all(loaded_model.model_config[k] == v for k, v in loaded_model.predict([[0]]))
 68  
 69  
 70  def test_override_model_config(model_path, model_config):
 71      model = TestModel()
 72      inference_override = {"timeout": 400}
 73  
 74      mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)
 75      loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=inference_override)
 76  
 77      assert loaded_model.model_config["timeout"] == 400
 78      assert all(loaded_model.model_config[k] == v for k, v in inference_override.items())
 79  
 80  
 81  @pytest.mark.parametrize(
 82      "model_config_path",
 83      [
 84          os.path.abspath("tests/pyfunc/sample_code/config.yml"),
 85          "tests/pyfunc/../pyfunc/sample_code/config.yml",
 86      ],
 87  )
 88  def test_override_model_config_path(tmp_path, model_path, model_config_path):
 89      model = TestModel()
 90      inference_override = {"timeout": 400}
 91      config_path = tmp_path / "config.yml"
 92      config_path.write_text(yaml.dump(inference_override))
 93  
 94      mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config_path)
 95      loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=str(config_path))
 96  
 97      assert loaded_model.model_config["timeout"] == 400
 98      assert all(loaded_model.model_config[k] == v for k, v in inference_override.items())
 99  
100  
def test_override_model_config_ignore_invalid(model_path, model_config):
    """Override keys missing from the saved config are dropped at load time.

    The saved config itself must come through untouched.
    """
    model = TestModel()
    inference_override = {"invalid_key": 400}

    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)
    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=inference_override)

    assert loaded_model.predict([[5]])
    assert all(k not in loaded_model.model_config for k in inference_override)
    # Ignoring the invalid key must not disturb the values saved with the model.
    assert loaded_model.model_config == model_config
110  
111  
@pytest.mark.parametrize(
    "model_config_path",
    [
        os.path.abspath("tests/pyfunc/sample_code/config.yml"),
        "tests/pyfunc/../pyfunc/sample_code/config.yml",
    ],
)
def test_override_model_config_path_ignore_invalid(tmp_path, model_path, model_config_path):
    """Unknown keys in a load-time YAML override are dropped rather than added."""
    override = {"invalid_key": 400}
    override_file = tmp_path / "config.yml"
    override_file.write_text(yaml.dump(override))

    mlflow.pyfunc.save_model(
        model_path, python_model=TestModel(), model_config=model_config_path
    )
    loaded_model = mlflow.pyfunc.load_model(
        model_uri=model_path, model_config=str(override_file)
    )

    assert loaded_model.predict([[5]])
    assert not any(key in loaded_model.model_config for key in override)
130  
131  
def test_pyfunc_without_model_config(model_path, model_config):
    """A load-time config has no effect on a model saved without one."""
    mlflow.pyfunc.save_model(model_path, python_model=TestModel())

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=model_config)
    prediction = loaded_model.predict([[5]])

    assert prediction
    assert not loaded_model.model_config
140  
141  
def test_pyfunc_loader_without_model_config(tmp_path, model_path):
    """A loader-module pyfunc saved without a config ignores load-time overrides.

    This module's ``_load_pyfunc`` ignores its ``path`` argument, so an empty
    scratch directory is sufficient as ``data_path``.
    """
    # Use a dedicated empty directory: ``data_path="."`` would hand the entire
    # current working directory (typically the repo root) to ``save_model`` as
    # model data to be stored with the model.
    data_dir = tmp_path / "data"
    data_dir.mkdir()

    mlflow.pyfunc.save_model(
        path=model_path,
        data_path=str(data_dir),
        loader_module=__name__,
        code_paths=[__file__],
        mlflow_model=Model(run_id="test", artifact_path="testtest"),
    )

    inference_override = {"invalid_key": 400}
    pyfunc_model = mlflow.pyfunc.load_model(model_path, model_config=inference_override)

    assert not pyfunc_model.model_config