# test_docker.py
import difflib
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from unittest import mock

import pytest
import sklearn
import sklearn.neighbors
from packaging.version import Version

import mlflow
from mlflow.environment_variables import _MLFLOW_RUN_SLOW_TESTS
from mlflow.models import Model
from mlflow.models.docker_utils import build_image_from_context
from mlflow.models.flavor_backend_registry import get_flavor_backend
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.env_manager import CONDA, LOCAL, VIRTUALENV
from mlflow.version import VERSION

from tests.pyfunc.docker.conftest import RESOURCE_DIR, get_released_mlflow_version


def _get_mlflow_install_specifier():
    """Return the MLflow install target substituted into expected Dockerfiles.

    Development versions resolve to the GitHub master archive URL; every other
    version resolves to an exact ``mlflow==<VERSION>`` pin.
    """
    if Version(VERSION).is_devrelease:
        return "https://github.com/mlflow/mlflow/archive/refs/heads/master.zip"
    return f"mlflow=={VERSION}"


def assert_dockerfiles_equal(actual_dockerfile_path: Path, expected_dockerfile_path: Path):
    """Assert that a generated Dockerfile matches an expected template.

    The expected file may contain the placeholders ``${{ MLFLOW_INSTALL }}`` and
    ``${{ PYTHON_VERSION }}``; both are substituted before the comparison.

    Args:
        actual_dockerfile_path: Path to the Dockerfile produced by the code under test.
        expected_dockerfile_path: Path to the expected Dockerfile template.

    Raises:
        AssertionError: If the contents differ. The message carries a unified diff
            from the expected to the actual content.
    """
    actual_dockerfile = actual_dockerfile_path.read_text()
    expected_dockerfile = (
        expected_dockerfile_path
        .read_text()
        .replace("${{ MLFLOW_INSTALL }}", _get_mlflow_install_specifier())
        .replace("${{ PYTHON_VERSION }}", PYTHON_VERSION)
    )
    assert actual_dockerfile == expected_dockerfile, (
        "Generated Dockerfile does not match expected one. Diff:\n"
        + "\n".join(
            difflib.unified_diff(
                expected_dockerfile.splitlines(),
                actual_dockerfile.splitlines(),
                fromfile="expected",
                tofile="actual",
                # The inputs carry no trailing newlines, so the control lines must
                # not either; otherwise the joined diff gets stray blank lines.
                lineterm="",
            )
        )
    )


def save_model(tmp_path):
    """Save an unfitted scikit-learn KNN classifier under ``tmp_path`` and return its path.

    Args:
        tmp_path: Directory in which the ``model`` sub-directory is created.

    Returns:
        The model directory path, built with ``os.path.join``.
    """
    knn_model = sklearn.neighbors.KNeighborsClassifier()
    model_path = os.path.join(tmp_path, "model")
    mlflow.sklearn.save_model(
        knn_model,
        path=model_path,
        pip_requirements=[
            f"mlflow=={get_released_mlflow_version()}",
            f"scikit-learn=={sklearn.__version__}",
        ],  # Skip requirements inference for speed up
    )
    return model_path


def add_spark_flavor_to_model(model_path):
    """Add a ``spark`` flavor entry to the ``MLmodel`` file inside ``model_path``.

    The file is loaded, a ``spark`` flavor with ``spark_version="3.5.0"`` is added,
    and the result is written back to the same file.
    """
    model_config_path = os.path.join(model_path, "MLmodel")
    model = Model.load(model_config_path)
    model.add_flavor("spark", spark_version="3.5.0")
    model.save(model_config_path)


@dataclass
class Param:
    """One parametrized scenario for ``test_build_image``."""

    # File name, relative to RESOURCE_DIR, of the Dockerfile expected for this scenario
    expected_dockerfile: str
    # Passed to get_flavor_backend as env_manager
    env_manager: str | None = None
    # The three fields below are forwarded unchanged to backend.build_image
    mlflow_home: str | None = None
    install_mlflow: bool = False
    enable_mlserver: bool = False
    # If True, image is built with --model-uri param
    specify_model_uri: bool = True


@pytest.mark.parametrize(
    "params",
    [
        Param(expected_dockerfile="Dockerfile_default"),
        Param(expected_dockerfile="Dockerfile_default", env_manager=LOCAL),
        Param(expected_dockerfile="Dockerfile_java_flavor", env_manager=VIRTUALENV),
        Param(expected_dockerfile="Dockerfile_conda", env_manager=CONDA),
        Param(install_mlflow=True, expected_dockerfile="Dockerfile_install_mlflow"),
        Param(enable_mlserver=True, expected_dockerfile="Dockerfile_enable_mlserver"),
        Param(mlflow_home=".", expected_dockerfile="Dockerfile_with_mlflow_home"),
        Param(specify_model_uri=False, expected_dockerfile="Dockerfile_no_model_uri"),
    ],
)
def test_build_image(tmp_path, params):
    """Check the Dockerfile that ``build_image`` writes into its build context.

    The real image build is replaced by a hook that copies the build context to a
    temporary directory. The Docker build itself only runs when the slow-tests
    flag is enabled.
    """
    model_uri = save_model(tmp_path) if params.specify_model_uri else None

    backend = get_flavor_backend(model_uri, docker_build=True, env_manager=params.env_manager)

    dst_dir = tmp_path / "context"

    # Copy the context dir to a temp dir so we can verify the generated Dockerfile
    def _build_image_with_copy(context_dir, image_name):
        shutil.copytree(context_dir, dst_dir)
        # Build the image if the slow-tests flag is enabled
        if _MLFLOW_RUN_SLOW_TESTS.get():
            last_error = None
            for _ in range(3):
                try:
                    # Docker image build is unstable on GitHub Actions, retry up to 3 times.
                    # `build_image_from_context` here is the name imported at the top of
                    # this module, so it is not the attribute replaced by the patch below.
                    build_image_from_context(context_dir, image_name)
                    break
                except RuntimeError as e:
                    # Keep the most recent failure so the final error can report its cause
                    last_error = e
            else:
                raise RuntimeError("Docker image build failed.") from last_error

    with mock.patch(
        "mlflow.models.docker_utils.build_image_from_context",
        side_effect=_build_image_with_copy,
    ):
        backend.build_image(
            model_uri=model_uri,
            image_name="test_image",
            mlflow_home=params.mlflow_home,
            install_mlflow=params.install_mlflow,
            enable_mlserver=params.enable_mlserver,
        )

    actual = dst_dir / "Dockerfile"
    expected = Path(RESOURCE_DIR) / params.expected_dockerfile
    assert_dockerfiles_equal(actual, expected)


def test_generate_dockerfile_for_java_flavor(tmp_path):
    """A model with an added ``spark`` flavor yields the ``Dockerfile_java_flavor`` fixture."""
    model_path = save_model(tmp_path)
    add_spark_flavor_to_model(model_path)

    backend = get_flavor_backend(model_path, docker_build=True, env_manager=None)

    backend.generate_dockerfile(
        model_uri=model_path,
        output_dir=tmp_path,
    )

    actual = tmp_path / "Dockerfile"
    expected = Path(RESOURCE_DIR) / "Dockerfile_java_flavor"
    assert_dockerfiles_equal(actual, expected)


def test_generate_dockerfile_for_custom_image(tmp_path):
    """Passing an explicit ``base_image`` yields the ``Dockerfile_custom_scipy`` fixture."""
    model_path = save_model(tmp_path)
    add_spark_flavor_to_model(model_path)

    backend = get_flavor_backend(model_path, docker_build=True, env_manager=None)

    backend.generate_dockerfile(
        base_image="quay.io/jupyter/scipy-notebook:latest",
        model_uri=model_path,
        output_dir=tmp_path,
    )

    actual = tmp_path / "Dockerfile"
    expected = Path(RESOURCE_DIR) / "Dockerfile_custom_scipy"
    assert_dockerfiles_equal(actual, expected)