test_wheeled_model.py
"""Tests for ``mlflow.models.wheeled_model.WheeledModel``.

Covers re-logging/saving models packaged with wheel dependencies, validation
of the rewritten MLmodel / conda.yaml / requirements.txt files, serving a
wheeled model, and the security checks around pip download options.
"""

import os
import random
import re
from io import BytesIO
from typing import Any, NamedTuple
from unittest import mock

import numpy as np
import pandas as pd
import pytest
import sklearn.neighbors as knn
import yaml
from sklearn import datasets

import mlflow
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
from mlflow.exceptions import MlflowException
from mlflow.models.model import METADATA_FILES
from mlflow.models.utils import load_serving_example
from mlflow.models.wheeled_model import _ORIGINAL_REQ_FILE_NAME, _WHEELS_FOLDER_NAME, WheeledModel
from mlflow.pyfunc.model import MLMODEL_FILE_NAME, Model
from mlflow.store.artifact.utils.models import _improper_model_uri_msg
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.utils.environment import (
    _CONDA_ENV_FILE_NAME,
    _REQUIREMENTS_FILE_NAME,
    _is_pip_deps,
    _mlflow_conda_env,
)

from tests.helper_functions import (
    _is_available_on_pypi,
    _mlflow_major_version_string,
    pyfunc_serve_and_score_model,
)

# Serve with the local env manager when scikit-learn is not installable from
# PyPI (e.g. a dev build), so scoring does not try to pip-install it.
EXTRA_PYFUNC_SERVING_TEST_ARGS = (
    [] if _is_available_on_pypi("scikit-learn", module="sklearn") else ["--env-manager", "local"]
)


class ModelWithData(NamedTuple):
    """A trained model bundled with the data used to exercise inference."""

    model: Any
    inference_data: Any


@pytest.fixture(scope="module")
def sklearn_knn_model():
    """Return a KNN classifier fitted on the first two iris features."""
    iris = datasets.load_iris()
    X = iris.data[:, :2]  # we only take the first two features.
    y = iris.target
    knn_model = knn.KNeighborsClassifier()
    knn_model.fit(X, y)
    return ModelWithData(model=knn_model, inference_data=X)


def random_int(lo=1, hi=1000000000):
    """Return a random integer in the inclusive range [lo, hi]."""
    return random.randint(int(lo), int(hi))


def _get_list_from_file(path):
    """Read ``path`` and return its contents as a list of lines."""
    with open(path) as file:
        return file.read().splitlines()


def _get_pip_requirements_list(path):
    """Return the requirement strings from a requirements.txt file."""
    return _get_list_from_file(path)


def get_pip_requirements_from_conda_file(conda_env_path):
    """Extract the pip dependency list from a conda environment YAML file.

    Returns an empty list if the environment declares no pip dependencies.
    """
    with open(conda_env_path) as f:
        conda_env = yaml.safe_load(f)

    conda_pip_requirements_list = []
    dependencies = conda_env.get("dependencies")

    for dependency in dependencies:
        if _is_pip_deps(dependency):
            conda_pip_requirements_list = dependency["pip"]

    return conda_pip_requirements_list


def validate_updated_model_file(original_model_config, wheeled_model_config):
    """Assert the wheeled MLmodel config correctly mirrors the original.

    Identity-style keys (run id, timestamps, uuid, artifact path) must differ;
    all other shared keys must be equal, and the wheels marker must exist only
    in the wheeled config.
    """
    differing_keys = {"run_id", "utc_time_created", "model_uuid", "artifact_path"}
    ignore_keys = {"model_id"}

    # Compare wheeled model configs with original model config (MLModel files)
    for key in original_model_config.keys() - ignore_keys:
        if key not in differing_keys:
            assert wheeled_model_config[key] == original_model_config[key]
        else:
            assert wheeled_model_config[key] != original_model_config[key]

    # Wheeled model key should only exist in wheeled_model_config
    assert wheeled_model_config.get(_WHEELS_FOLDER_NAME, None)
    assert not original_model_config.get(_WHEELS_FOLDER_NAME, None)

    # Every key in the original config should also exist in the wheeled config.
    for key in original_model_config:
        assert key in wheeled_model_config


def validate_updated_conda_dependencies(original_model_path, wheeled_model_path):
    """Assert the wheeled conda.yaml matches the original except for deps.

    The "dependencies" section must have been rewritten (to point at wheels);
    every other section must be unchanged.
    """
    # NOTE: use distinct names for the joined paths, the file handles, and the
    # parsed YAML — the original shadowed the parameter and the handles.
    wheeled_conda_env_path = os.path.join(wheeled_model_path, _CONDA_ENV_FILE_NAME)
    original_conda_env_path = os.path.join(original_model_path, _CONDA_ENV_FILE_NAME)

    with (
        open(wheeled_conda_env_path) as wheeled_file,
        open(original_conda_env_path) as original_file,
    ):
        wheeled_conda_env = yaml.safe_load(wheeled_file)
        original_conda_env = yaml.safe_load(original_file)

    for key in wheeled_conda_env:
        if key != "dependencies":
            assert wheeled_conda_env[key] == original_conda_env[key]
        else:
            assert wheeled_conda_env[key] != original_conda_env[key]


def validate_wheeled_dependencies(wheeled_model_path):
    """Assert conda.yaml, requirements.txt, and the wheels dir are consistent."""
    # Check if conda.yaml and requirements.txt are consistent
    pip_requirements_path = os.path.join(wheeled_model_path, _REQUIREMENTS_FILE_NAME)
    pip_requirements_list = _get_pip_requirements_list(pip_requirements_path)
    conda_pip_requirements_list = get_pip_requirements_from_conda_file(
        os.path.join(wheeled_model_path, _CONDA_ENV_FILE_NAME)
    )

    pip_requirements_list.sort()
    conda_pip_requirements_list.sort()
    assert pip_requirements_list == conda_pip_requirements_list

    # Check if requirements.txt and wheels directory are consistent: every
    # requirement should be a relative path into the wheels folder.
    wheels_dir = os.path.join(wheeled_model_path, _WHEELS_FOLDER_NAME)
    wheels_list = []
    for wheel_file in os.listdir(wheels_dir):
        if wheel_file.endswith(".whl"):
            relative_wheel_path = os.path.join(_WHEELS_FOLDER_NAME, wheel_file)
            wheels_list.append(relative_wheel_path)

    wheels_list.sort()
    assert wheels_list == pip_requirements_list


def test_model_log_load(tmp_path, sklearn_knn_model):
    """Logging a wheeled model rewrites MLmodel, conda.yaml, and wheels."""
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"
    artifact_path = "model"

    # Log a model
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )
        model_path = _download_artifact_from_uri(model_uri, tmp_path)
        original_model_config = Model.load(os.path.join(model_path, MLMODEL_FILE_NAME)).__dict__

    # Re-log with wheels
    with mlflow.start_run():
        WheeledModel.log_model(model_uri=model_uri)
        wheeled_model_path = _download_artifact_from_uri(wheeled_model_uri)
        wheeled_model_run_id = mlflow.tracking.fluent._get_or_start_run().info.run_id
        wheeled_model_config = Model.load(
            os.path.join(wheeled_model_path, MLMODEL_FILE_NAME)
        ).__dict__

    validate_updated_model_file(original_model_config, wheeled_model_config)
    # Assert correct run_id
    assert wheeled_model_config["run_id"] == wheeled_model_run_id

    validate_updated_conda_dependencies(model_path, wheeled_model_path)

    validate_wheeled_dependencies(wheeled_model_path)


def test_model_save_load(tmp_path, sklearn_knn_model):
    """Saving a wheeled model to a path produces a valid wheeled package."""
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    artifact_path = "model"
    model_download_path = os.path.join(tmp_path, "m")
    wheeled_model_path = os.path.join(tmp_path, "wm")

    os.mkdir(model_download_path)
    # Log a model
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )
        model_path = _download_artifact_from_uri(model_uri, model_download_path)
        original_model_config = Model.load(os.path.join(model_path, MLMODEL_FILE_NAME)).__dict__

    # Save with wheels
    with mlflow.start_run():
        wheeled_model = WheeledModel(model_uri=model_uri)
        wheeled_model_data = wheeled_model.save_model(path=wheeled_model_path)
        wheeled_model_config = Model.load(os.path.join(wheeled_model_path, MLMODEL_FILE_NAME))
        wheeled_model_config_dict = wheeled_model_config.__dict__

        # Check to see if python model returned is the same as the MLModel file
        assert wheeled_model_config == wheeled_model_data

    validate_updated_model_file(original_model_config, wheeled_model_config_dict)
    validate_updated_conda_dependencies(model_path, wheeled_model_path)
    validate_wheeled_dependencies(wheeled_model_path)


def test_logging_and_saving_wheeled_model_throws(tmp_path, sklearn_knn_model):
    """Re-wheeling an already-wheeled model raises MlflowException."""
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"
    artifact_path = "model"

    # Log a model
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name=artifact_path,
            registered_model_name=model_name,
        )

    # Re-log with wheels
    with mlflow.start_run():
        WheeledModel.log_model(
            model_uri=model_uri,
        )

    match = "Model libraries are already added"

    # Log wheeled model
    with pytest.raises(MlflowException, match=re.escape(match)):
        with mlflow.start_run():
            WheeledModel.log_model(
                model_uri=wheeled_model_uri,
            )

    # Saved a wheeled model
    saved_model_path = os.path.join(tmp_path, "test")
    with pytest.raises(MlflowException, match=re.escape(match)):
        with mlflow.start_run():
            WheeledModel(wheeled_model_uri).save_model(saved_model_path)


def test_log_model_with_non_model_uri():
    """Non-``models:/`` URIs are rejected for both logging and saving."""
    model_uri = "runs:/beefe0b6b5bd4acf9938244cdc006b64/model"

    # Log with wheels
    with pytest.raises(MlflowException, match=_improper_model_uri_msg(model_uri)):
        with mlflow.start_run():
            WheeledModel.log_model(
                model_uri=model_uri,
            )

    # Save with wheels
    with pytest.raises(MlflowException, match=_improper_model_uri_msg(model_uri)):
        with mlflow.start_run():
            WheeledModel(model_uri)


def test_create_pip_requirement(tmp_path):
    """_create_pip_requirement writes the conda pip deps to requirements.txt."""
    expected_mlflow_version = _mlflow_major_version_string()
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    conda_env_path = os.path.join(tmp_path, "conda.yaml")
    pip_reqs_path = os.path.join(tmp_path, "requirements.txt")

    wm = WheeledModel(model_uri)

    expected_pip_deps = [expected_mlflow_version, "cloudpickle==2.1.0", "psutil==5.8.0"]
    _mlflow_conda_env(
        path=conda_env_path, additional_pip_deps=expected_pip_deps, install_mlflow=False
    )
    wm._create_pip_requirement(conda_env_path, pip_reqs_path)
    with open(pip_reqs_path) as f:
        pip_reqs = [x.strip() for x in f]
    # BUG FIX: the original asserted `expected_pip_deps.sort() == pip_reqs.sort()`.
    # list.sort() returns None, so that compared None == None and always passed.
    # Compare sorted copies so the assertion actually checks the contents.
    assert sorted(pip_reqs) == sorted(expected_pip_deps)


def test_update_conda_env_only_updates_pip_deps(tmp_path):
    """_update_conda_env replaces only the pip deps inside conda.yaml."""
    expected_mlflow_version = _mlflow_major_version_string()
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    conda_env_path = os.path.join(tmp_path, "conda.yaml")
    pip_deps = [expected_mlflow_version, "cloudpickle==2.1.0", "psutil==5.8.0"]
    new_pip_deps = ["wheels/mlflow", "wheels/cloudpickle", "wheels/psutil"]

    wm = WheeledModel(model_uri)
    additional_conda_deps = ["add_conda_deps"]
    additional_conda_channels = ["add_conda_channels"]

    _mlflow_conda_env(
        conda_env_path,
        additional_conda_deps,
        pip_deps,
        additional_conda_channels,
        install_mlflow=False,
    )
    with open(conda_env_path) as f:
        old_conda_yaml = yaml.safe_load(f)
    wm._update_conda_env(new_pip_deps, conda_env_path)
    with open(conda_env_path) as f:
        new_conda_yaml = yaml.safe_load(f)
    assert old_conda_yaml.get("name") == new_conda_yaml.get("name")
    assert old_conda_yaml.get("channels") == new_conda_yaml.get("channels")
    for old_item, new_item in zip(
        old_conda_yaml.get("dependencies"), new_conda_yaml.get("dependencies")
    ):
        # Plain (non-pip) conda deps must be untouched; the pip dict must have
        # been swapped from the original requirements to the wheel paths.
        if isinstance(old_item, str):
            assert old_item == new_item
        if isinstance(old_item, dict):
            assert old_item.get("pip") == pip_deps
        if isinstance(new_item, dict):
            assert new_item.get("pip") == new_pip_deps


def test_serving_wheeled_model(sklearn_knn_model):
    """A wheeled model can be served and returns the model's predictions."""
    model_name = f"wheels-test-{random_int()}"
    model_uri = f"models:/{model_name}/1"
    wheeled_model_uri = f"models:/{model_name}/2"
    artifact_path = "model"
    (model, inference_data) = sklearn_knn_model

    # Log a model
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(
            model,
            name=artifact_path,
            registered_model_name=model_name,
            input_example=pd.DataFrame(inference_data),
        )

    # Re-log with wheels
    with mlflow.start_run():
        WheeledModel.log_model(model_uri=model_uri)

    inference_payload = load_serving_example(model_info.model_uri)
    resp = pyfunc_serve_and_score_model(
        wheeled_model_uri,
        data=inference_payload,
        content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON,
        extra_args=EXTRA_PYFUNC_SERVING_TEST_ARGS,
    )
    scores = pd.read_json(BytesIO(resp.content), orient="records").values.squeeze()
    np.testing.assert_array_almost_equal(scores, model.predict(inference_data))


def test_wheel_download_works(tmp_path):
    """_download_wheels fetches exactly one wheel for a single dependency."""
    simple_dependency = "cloudpickle"
    requirements_file = os.path.join(tmp_path, "req.txt")
    wheel_dir = os.path.join(tmp_path, "wheels")
    with open(requirements_file, "w") as req_file:
        req_file.write(simple_dependency)

    WheeledModel._download_wheels(requirements_file, wheel_dir)
    wheels = os.listdir(wheel_dir)
    assert len(wheels) == 1  # Only a single wheel is downloaded
    assert wheels[0].endswith(".whl")  # Type is wheel
    assert simple_dependency in wheels[0]  # Cloudpickle wheel downloaded


def test_wheel_download_override_option_works(tmp_path, monkeypatch):
    """The pip download options env var lets otherwise-failing downloads succeed."""
    dependency = "pyspark"
    requirements_file = os.path.join(tmp_path, "req.txt")
    wheel_dir = os.path.join(tmp_path, "wheels")
    with open(requirements_file, "w") as req_file:
        req_file.write(dependency)

    # Default option fails to download wheel
    with pytest.raises(
        MlflowException, match="An error occurred while downloading the dependency wheels"
    ):
        WheeledModel._download_wheels(requirements_file, wheel_dir)

    # Set option override
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")
    WheeledModel._download_wheels(requirements_file, wheel_dir)
    assert len(os.listdir(wheel_dir))  # Wheel dir is not empty


def test_wheel_download_dependency_conflicts(tmp_path):
    """Conflicting requirement pins surface pip's conflict details."""
    reqs_file = tmp_path / "requirements.txt"
    reqs_file.write_text("mlflow==2.15.0\nmlflow==2.16.0")
    with pytest.raises(
        MlflowException,
        # Ensure the error message contains conflict details
        match=r"Cannot install mlflow==2\.15\.0 and mlflow==2\.16\.0.+The conflict is caused by",
    ):
        WheeledModel._download_wheels(reqs_file, tmp_path / "wheels")


def test_copy_metadata(mock_is_in_databricks, sklearn_knn_model):
    """Metadata files are copied only when running inside Databricks."""
    with mlflow.start_run():
        mlflow.sklearn.log_model(
            sklearn_knn_model.model,
            name="model",
            registered_model_name="sklearn_knn_model",
        )

    with mlflow.start_run():
        model_info = WheeledModel.log_model(model_uri="models:/sklearn_knn_model/1")

    artifact_path = mlflow.artifacts.download_artifacts(model_info.model_uri)
    metadata_path = os.path.join(artifact_path, "metadata")
    if mock_is_in_databricks.return_value:
        assert set(os.listdir(metadata_path)) == set(METADATA_FILES + [_ORIGINAL_REQ_FILE_NAME])
    else:
        assert not os.path.exists(metadata_path)
    assert mock_is_in_databricks.call_count == 2


def test_wheel_download_prevents_command_injection(tmp_path, monkeypatch):
    """Shell-injection / unsafe pip options in the env var are rejected."""
    malicious_attempts = [
        "--only-binary=:all: && echo pwned",
        "--prefer-binary; rm -rf /",
        "--no-binary=:none: | cat /etc/passwd",
        "../../../etc/passwd",
        "--extra-index-url http://evil.com",
        "--find-links /tmp",
        "--index-url http://malicious.com",
        "--trusted-host evil.com",
        "--only-binary=package`rm -rf /`",
        "--config-settings malicious=value",
    ]

    for malicious_option in malicious_attempts:
        monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", malicious_option)
        with pytest.raises(MlflowException, match="Invalid pip wheel option"):
            WheeledModel._download_wheels(tmp_path / "req.txt", tmp_path / "wheels")


def test_wheel_download_allowed_options(tmp_path, monkeypatch):
    """Each allow-listed pip option is passed through to the pip command."""
    allowed_options = [
        "--only-binary=:all:",
        "--only-binary=:none:",
        "--no-binary=:all:",
        "--no-binary=:none:",
        "--prefer-binary",
        "--no-build-isolation",
        "--use-pep517",
        "--check-build-dependencies",
        "--ignore-requires-python",
        "--no-deps",
        "--no-verify",
        "--pre",
        "--require-hashes",
        "--no-clean",
    ]

    for option in allowed_options:
        monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", option)
        with mock.patch("subprocess.run") as mock_run:
            WheeledModel._download_wheels(tmp_path / "req.txt", tmp_path / "wheels")
            mock_run.assert_called_once()
            assert option in mock_run.call_args[0][0]

    # test combination of options
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary --no-clean")
    with mock.patch("subprocess.run") as mock_run:
        WheeledModel._download_wheels(tmp_path / "req.txt", tmp_path / "wheels")
        mock_run.assert_called_once()
        call_args = mock_run.call_args
        assert "--prefer-binary --no-clean" in call_args[0][0]


def test_wheel_download_extra_envs(tmp_path, monkeypatch):
    """extra_envs entries are merged into the subprocess environment."""
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")
    extra_envs = {
        "PIP_INDEX_URL": "https://test.pypi.org/simple/",
        "PIP_TRUSTED_HOST": "test.pypi.org",
        "CUSTOM_VAR": "test_value",
    }

    with mock.patch("subprocess.run") as mock_run:
        mock_run.return_value = mock.Mock(returncode=0)

        WheeledModel._download_wheels(
            tmp_path / "req.txt", tmp_path / "wheels", extra_envs=extra_envs
        )

        mock_run.assert_called_once()
        call_args = mock_run.call_args
        assert "--prefer-binary" in call_args[0][0]
        passed_env = call_args[1]["env"]
        assert passed_env["PIP_INDEX_URL"] == "https://test.pypi.org/simple/"
        assert passed_env["PIP_TRUSTED_HOST"] == "test.pypi.org"
        assert passed_env["CUSTOM_VAR"] == "test_value"

        # Verify original environment variables are preserved
        assert passed_env["PATH"] == os.environ["PATH"]


def test_wheel_download_no_extra_envs(tmp_path, monkeypatch):
    """Without extra_envs, the subprocess inherits the parent env (env=None)."""
    monkeypatch.setenv("MLFLOW_WHEELED_MODEL_PIP_DOWNLOAD_OPTIONS", "--prefer-binary")

    with mock.patch("subprocess.run") as mock_run:
        mock_run.return_value = mock.Mock(returncode=0)

        WheeledModel._download_wheels(tmp_path / "req.txt", tmp_path / "wheels", extra_envs=None)
        mock_run.assert_called_once()
        call_args = mock_run.call_args
        assert call_args[1]["env"] is None