test_model_export_with_loader_module_and_data_path.py
"""Tests for exporting pyfunc models via the ``loader_module`` + ``data_path``
mechanism, where this very module serves as the loader (``_load_pyfunc``)."""

import os
import pickle
import types
from unittest import mock

import cloudpickle
import numpy as np
import pytest
import sklearn.datasets
import sklearn.neighbors
import yaml

import mlflow
import mlflow.pyfunc
from mlflow.exceptions import MlflowException
from mlflow.models import Model, infer_signature
from mlflow.models.utils import _read_example
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import _mlflow_conda_env
from mlflow.utils.file_utils import TempDir
from mlflow.utils.model_utils import _get_flavor_configuration

from tests.helper_functions import _assert_pip_requirements


def _load_pyfunc(path):
    """Loader-module hook used by the tests below: unpickle the artifact at ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f, encoding="latin1")


@pytest.fixture
def pyfunc_custom_env_file(tmp_path):
    """Write a custom conda environment spec to a YAML file and return its path."""
    conda_env = os.path.join(tmp_path, "conda_env.yml")
    _mlflow_conda_env(
        conda_env,
        additional_pip_deps=[
            "scikit-learn",
            "pytest",
            "cloudpickle",
            "-e " + os.path.dirname(mlflow.__path__[0]),
        ],
    )
    return conda_env


@pytest.fixture
def pyfunc_custom_env_dict():
    """Return a custom conda environment spec as an in-memory dict."""
    return _mlflow_conda_env(
        additional_pip_deps=[
            "scikit-learn",
            "pytest",
            "cloudpickle",
            "-e " + os.path.dirname(mlflow.__path__[0]),
        ],
    )


@pytest.fixture(scope="module")
def iris_data():
    """Return (features, labels) from the iris dataset, first two columns only."""
    iris = sklearn.datasets.load_iris()
    x = iris.data[:, :2]
    y = iris.target
    return x, y


@pytest.fixture(scope="module")
def sklearn_knn_model(iris_data):
    """Fit and return a KNN classifier on the iris fixture data."""
    x, y = iris_data
    knn_model = sklearn.neighbors.KNeighborsClassifier()
    knn_model.fit(x, y)
    return knn_model


@pytest.fixture
def model_path(tmp_path):
    """Return a fresh (not yet existing) path for saving a model."""
    return os.path.join(tmp_path, "model")


def test_model_save_load(sklearn_knn_model, iris_data, tmp_path, model_path):
    """Round-trip: save a pickled sklearn model via loader_module, reload, compare."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    model_config = Model(run_id="test", artifact_path="testtest")
    mlflow.pyfunc.save_model(
        path=model_path,
        data_path=sk_model_path,
        loader_module=__name__,
        code_paths=[__file__],
        mlflow_model=model_config,
    )

    # The persisted MLmodel config must match what we passed in, and must
    # advertise the pyfunc flavor with a recorded Python version.
    reloaded_model_config = Model.load(os.path.join(model_path, "MLmodel"))
    assert model_config.__dict__ == reloaded_model_config.__dict__
    assert mlflow.pyfunc.FLAVOR_NAME in reloaded_model_config.flavors
    assert mlflow.pyfunc.PY_VERSION in reloaded_model_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
    reloaded_model = mlflow.pyfunc.load_model(model_path)
    np.testing.assert_array_equal(
        sklearn_knn_model.predict(iris_data[0]), reloaded_model.predict(iris_data[0])
    )


def test_signature_and_examples_are_saved_correctly(sklearn_knn_model, iris_data):
    """Every combination of (signature, input_example) persists as expected."""
    data = iris_data
    signature_ = infer_signature(*data)
    example_ = data[0][:3]
    for signature in (None, signature_):
        for example in (None, example_):
            with TempDir() as tmp:
                with open(tmp.path("skmodel"), "wb") as f:
                    pickle.dump(sklearn_knn_model, f)
                path = tmp.path("model")
                mlflow.pyfunc.save_model(
                    path=path,
                    data_path=tmp.path("skmodel"),
                    loader_module=__name__,
                    code_paths=[__file__],
                    signature=signature,
                    input_example=example,
                )
                mlflow_model = Model.load(path)
                assert signature == mlflow_model.signature
                if example is None:
                    assert mlflow_model.saved_input_example_info is None
                else:
                    np.testing.assert_array_equal(_read_example(mlflow_model, path), example)


def test_model_log_load(sklearn_knn_model, iris_data, tmp_path):
    """log_model inside a run, then download and reload the logged artifact."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
        )
        pyfunc_model_path = _download_artifact_from_uri(
            f"runs:/{mlflow.active_run().info.run_id}/{pyfunc_artifact_path}"
        )

    model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert mlflow.pyfunc.FLAVOR_NAME in model_config.flavors
    assert mlflow.pyfunc.PY_VERSION in model_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
    reloaded_model = mlflow.pyfunc.load_model(pyfunc_model_path)
    assert model_config.to_yaml() == reloaded_model.metadata.to_yaml()
    np.testing.assert_array_equal(
        sklearn_knn_model.predict(iris_data[0]), reloaded_model.predict(iris_data[0])
    )


@pytest.mark.skip(
    reason="In MLflow 3.0, `log_model` does not start a run. Consider removing this test."
)
def test_model_log_load_no_active_run(sklearn_knn_model, iris_data, tmp_path):
    """log_model without an active run should still produce a loadable artifact."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    assert mlflow.active_run() is None
    mlflow.pyfunc.log_model(
        name=pyfunc_artifact_path,
        data_path=sk_model_path,
        loader_module=__name__,
        code_paths=[__file__],
    )
    pyfunc_model_path = _download_artifact_from_uri(
        f"runs:/{mlflow.active_run().info.run_id}/{pyfunc_artifact_path}"
    )

    model_config = Model.load(os.path.join(pyfunc_model_path, "MLmodel"))
    assert mlflow.pyfunc.FLAVOR_NAME in model_config.flavors
    assert mlflow.pyfunc.PY_VERSION in model_config.flavors[mlflow.pyfunc.FLAVOR_NAME]
    reloaded_model = mlflow.pyfunc.load_model(pyfunc_model_path)
    np.testing.assert_array_equal(
        sklearn_knn_model.predict(iris_data[0]), reloaded_model.predict(iris_data[0])
    )
    mlflow.end_run()


def test_save_model_with_unsupported_argument_combinations_throws_exception(model_path):
    """save_model with only data_path (no loader/python_model) must raise."""
    with pytest.raises(
        MlflowException, match="Either `loader_module` or `python_model` must be specified"
    ):
        mlflow.pyfunc.save_model(path=model_path, data_path="/path/to/data")


def test_log_model_with_unsupported_argument_combinations_throws_exception():
    """log_model with only data_path (no loader/python_model) must raise."""
    with (
        mlflow.start_run(),
        pytest.raises(
            MlflowException, match="Either `loader_module` or `python_model` must be specified"
        ),
    ):
        mlflow.pyfunc.log_model(name="pyfunc_model", data_path="/path/to/data")


def test_log_model_persists_specified_conda_env_file_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_file
):
    """A conda env given as a file is copied (not referenced) into the model dir."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_file,
        )
        run_id = mlflow.active_run().info.run_id

    pyfunc_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{pyfunc_artifact_path}")

    pyfunc_conf = _get_flavor_configuration(
        model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)
    # Must be a copy inside the model directory, not the original file.
    assert saved_conda_env_path != pyfunc_custom_env_file

    with open(pyfunc_custom_env_file) as f:
        pyfunc_custom_env_parsed = yaml.safe_load(f)
    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == pyfunc_custom_env_parsed


def test_log_model_persists_specified_conda_env_dict_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_dict
):
    """A conda env given as a dict is serialized into the model directory."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_dict,
        )
        run_id = mlflow.active_run().info.run_id

    pyfunc_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{pyfunc_artifact_path}")

    pyfunc_conf = _get_flavor_configuration(
        model_path=pyfunc_model_path, flavor_name=mlflow.pyfunc.FLAVOR_NAME
    )
    saved_conda_env_path = os.path.join(pyfunc_model_path, pyfunc_conf[mlflow.pyfunc.ENV]["conda"])
    assert os.path.exists(saved_conda_env_path)

    with open(saved_conda_env_path) as f:
        saved_conda_env_parsed = yaml.safe_load(f)
    assert saved_conda_env_parsed == pyfunc_custom_env_dict


def test_log_model_persists_requirements_in_mlflow_model_directory(
    sklearn_knn_model, tmp_path, pyfunc_custom_env_dict
):
    """requirements.txt is written and matches the conda env's pip dependencies."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
            conda_env=pyfunc_custom_env_dict,
        )
        run_id = mlflow.active_run().info.run_id

    pyfunc_model_path = _download_artifact_from_uri(f"runs:/{run_id}/{pyfunc_artifact_path}")

    saved_pip_req_path = os.path.join(pyfunc_model_path, "requirements.txt")
    assert os.path.exists(saved_pip_req_path)

    with open(saved_pip_req_path) as f:
        requirements = f.read().split("\n")

    # The conda env dict stores pip deps as the last entry of "dependencies".
    assert pyfunc_custom_env_dict["dependencies"][-1]["pip"] == requirements


def test_log_model_without_specified_conda_env_uses_default_env_with_expected_dependencies(
    sklearn_knn_model, tmp_path
):
    """Omitting conda_env falls back to the default pyfunc pip requirements."""
    sk_model_path = os.path.join(tmp_path, "knn.pkl")
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    pyfunc_artifact_path = "pyfunc_model"
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name=pyfunc_artifact_path,
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
        )
    _assert_pip_requirements(model_info.model_uri, mlflow.pyfunc.get_default_pip_requirements())


def test_streamable_model_save_load(tmp_path, model_path):
    """A model saved with streamable=True exposes a generator predict_stream."""

    class StreamableModel:
        def __init__(self):
            pass

        def predict(self, model_input, params=None):
            pass

        def predict_stream(self, model_input, params=None):
            yield "test1"
            yield "test2"

    custom_model = StreamableModel()

    custom_model_path = os.path.join(tmp_path, "model.pkl")
    with open(custom_model_path, "wb") as f:
        cloudpickle.dump(custom_model, f)

    model_config = Model(run_id="test", artifact_path="testtest")
    mlflow.pyfunc.save_model(
        path=model_path,
        data_path=custom_model_path,
        loader_module=__name__,
        code_paths=[__file__],
        mlflow_model=model_config,
        streamable=True,
    )
    loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri=model_path)

    stream_result = loaded_pyfunc_model.predict_stream("single-input")
    assert isinstance(stream_result, types.GeneratorType)

    assert list(stream_result) == ["test1", "test2"]


def test_log_loader_module_model_does_not_emit_pickle_warning(sklearn_knn_model, tmp_path):
    """The CloudPickle warning is reserved for python_model; loader_module must not trigger it."""
    sk_model_path = tmp_path / "knn.pkl"
    with open(sk_model_path, "wb") as f:
        pickle.dump(sklearn_knn_model, f)

    with mlflow.start_run(), mock.patch("mlflow.pyfunc._logger.warning") as mock_log_warning:
        mlflow.pyfunc.log_model(
            name="pyfunc_model",
            data_path=sk_model_path,
            loader_module=__name__,
            code_paths=[__file__],
        )

    warning_messages = [args[0] for args, _ in mock_log_warning.call_args_list if args]
    assert not any(
        "Passing a Python object as `python_model` causes it to be serialized using CloudPickle"
        in msg
        for msg in warning_messages
    )