# test_pyfunc_model_config.py
"""Tests for saving and loading pyfunc models together with a ``model_config``."""

import os

import pytest
import yaml

import mlflow
from mlflow.models import Model

# The same sample config file addressed two ways: as an absolute path and as a
# non-normalized relative path. Both spellings are relative to the current working
# directory at import time (presumably the repository root -- confirm with the test runner).
_MODEL_CONFIG_PATHS = [
    os.path.abspath("tests/pyfunc/sample_code/config.yml"),
    "tests/pyfunc/../pyfunc/sample_code/config.yml",
]


@pytest.fixture
def model_path(tmp_path):
    """Return a ``model`` path (as a string) inside the per-test temporary directory."""
    return os.path.join(tmp_path, "model")


@pytest.fixture
def model_config():
    """Return the model configuration dictionary used as the saved baseline in these tests."""
    return {
        "use_gpu": True,
        "temperature": 0.9,
        "timeout": 300,
    }


def _load_pyfunc(path):
    """Return a new ``TestModel``; ``path`` is ignored.

    This module passes itself as ``loader_module`` in
    ``test_pyfunc_loader_without_model_config``.
    """
    return TestModel()


class TestModel(mlflow.pyfunc.PythonModel):
    """Minimal python model whose ``predict`` echoes its input unchanged."""

    # The class name matches pytest's default ``Test*`` collection pattern even though
    # this is a helper rather than a test case; opt out of collection explicitly.
    __test__ = False

    def predict(self, context, model_input, params=None):
        """Return ``model_input`` as-is; ``context`` and ``params`` are unused."""
        return model_input


class InferenceContextModel(mlflow.pyfunc.PythonModel):
    """Python model whose ``predict`` exposes the config visible through ``context``."""

    def predict(self, context, model_input, params=None):
        """Return the ``(key, value)`` items of ``context.model_config``."""
        # This mock class returns the internal inference configuration keys and values available
        return context.model_config.items()


def test_save_with_model_config(model_path, model_config):
    """A config dict given at save time is available on the loaded model and in predict."""
    model = InferenceContextModel()
    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)

    assert loaded_model.model_config
    assert set(model_config) == set(loaded_model.model_config)
    assert all(loaded_model.model_config[k] == v for k, v in model_config.items())
    # predict() yields the items seen by the model's context; they must match as well.
    assert all(loaded_model.model_config[k] == v for k, v in loaded_model.predict([[0]]))


@pytest.mark.parametrize("model_config_path", _MODEL_CONFIG_PATHS)
def test_save_with_model_config_path(model_path, model_config, model_config_path):
    """A config given as a file path at save time is available on the loaded model.

    The assertions expect the file's contents to equal the ``model_config`` fixture.
    """
    model = InferenceContextModel()
    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config_path)

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path)

    assert loaded_model.model_config
    assert set(model_config) == set(loaded_model.model_config)
    assert all(loaded_model.model_config[k] == v for k, v in model_config.items())
    # predict() yields the items seen by the model's context; they must match as well.
    assert all(loaded_model.model_config[k] == v for k, v in loaded_model.predict([[0]]))


def test_override_model_config(model_path, model_config):
    """A dict passed at load time overrides the matching key of the saved config."""
    model = TestModel()
    inference_override = {"timeout": 400}

    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)
    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=inference_override)

    assert loaded_model.model_config["timeout"] == 400
    assert all(loaded_model.model_config[k] == v for k, v in inference_override.items())


@pytest.mark.parametrize("model_config_path", _MODEL_CONFIG_PATHS)
def test_override_model_config_path(tmp_path, model_path, model_config_path):
    """An override supplied as a YAML file path at load time replaces the saved value."""
    model = TestModel()
    inference_override = {"timeout": 400}
    config_path = tmp_path / "config.yml"
    config_path.write_text(yaml.dump(inference_override))

    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config_path)
    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=str(config_path))

    assert loaded_model.model_config["timeout"] == 400
    assert all(loaded_model.model_config[k] == v for k, v in inference_override.items())


def test_override_model_config_ignore_invalid(model_path, model_config):
    """Override keys absent from the saved config do not show up on the loaded model."""
    model = TestModel()
    inference_override = {"invalid_key": 400}

    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config)
    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=inference_override)

    assert loaded_model.predict([[5]])
    assert all(k not in loaded_model.model_config for k in inference_override)


@pytest.mark.parametrize("model_config_path", _MODEL_CONFIG_PATHS)
def test_override_model_config_path_ignore_invalid(tmp_path, model_path, model_config_path):
    """Unknown keys in an override YAML file do not show up on the loaded model."""
    model = TestModel()
    inference_override = {"invalid_key": 400}
    config_path = tmp_path / "config.yml"
    config_path.write_text(yaml.dump(inference_override))

    mlflow.pyfunc.save_model(model_path, python_model=model, model_config=model_config_path)
    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=str(config_path))

    assert loaded_model.predict([[5]])
    assert all(k not in loaded_model.model_config for k in inference_override)


def test_pyfunc_without_model_config(model_path, model_config):
    """A model saved without a config stays config-less even if one is passed at load time."""
    model = TestModel()
    mlflow.pyfunc.save_model(model_path, python_model=model)

    loaded_model = mlflow.pyfunc.load_model(model_uri=model_path, model_config=model_config)

    assert loaded_model.predict([[5]])
    assert not loaded_model.model_config


def test_pyfunc_loader_without_model_config(model_path):
    """A loader-module model saved without a config ignores a load-time override."""
    mlflow.pyfunc.save_model(
        path=model_path,
        data_path=".",
        loader_module=__name__,
        code_paths=[__file__],
        mlflow_model=Model(run_id="test", artifact_path="testtest"),
    )

    inference_override = {"invalid_key": 400}
    pyfunc_model = mlflow.pyfunc.load_model(model_path, model_config=inference_override)

    assert not pyfunc_model.model_config