# test_docker_projects.py
"""Tests for running MLflow Projects with the Docker backend.

Covers end-to-end project execution inside Docker containers, propagation of
tracking URIs and Databricks credentials into the container environment, image
URI construction, artifact-volume mounting, and user-specified docker
arguments / environment variables.
"""

import os
from unittest import mock

import docker
import pytest

import mlflow
from mlflow import MlflowClient
from mlflow.entities import ViewType
from mlflow.environment_variables import MLFLOW_TRACKING_URI
from mlflow.exceptions import MlflowException
from mlflow.legacy_databricks_cli.configure.provider import DatabricksConfig
from mlflow.projects import ExecutionException, _project_spec
from mlflow.projects.backend.local import _get_docker_command
from mlflow.projects.docker import _get_docker_image_uri
from mlflow.store.tracking import file_store
from mlflow.utils.mlflow_tags import (
    MLFLOW_DOCKER_IMAGE_ID,
    MLFLOW_DOCKER_IMAGE_URI,
    MLFLOW_PROJECT_BACKEND,
    MLFLOW_PROJECT_ENV,
)

from tests.projects.utils import (
    TEST_DOCKER_PROJECT_DIR,
    docker_example_base_image,  # noqa: F401
)


def _build_uri(base_uri, subdirectory):
    """Return ``base_uri`` with ``#subdirectory`` appended when one is given."""
    if subdirectory != "":
        return f"{base_uri}#{subdirectory}"
    return base_uri


@pytest.mark.parametrize("use_start_run", map(str, [0, 1]))
def test_docker_project_execution(use_start_run, docker_example_base_image):
    """Run a Docker project end-to-end and validate the recorded run contents."""
    expected_params = {"use_start_run": use_start_run}
    submitted_run = mlflow.projects.run(
        TEST_DOCKER_PROJECT_DIR,
        experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
        parameters=expected_params,
        entry_point="test_tracking",
        build_image=True,
        docker_args={"memory": "1g", "privileged": True},
    )
    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    mlflow_service = MlflowClient()
    runs = mlflow_service.search_runs(
        [file_store.FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
    )
    assert len(runs) == 1
    store_run_id = runs[0].info.run_id
    assert run_id == store_run_id
    run = mlflow_service.get_run(run_id)
    assert run.data.params == expected_params
    assert run.data.metrics == {"some_key": 3}
    # Tags that must match exactly vs. tags we only prefix-match (image ids
    # contain a content digest that varies between builds).
    exact_expected_tags = {
        MLFLOW_PROJECT_ENV: "docker",
        MLFLOW_PROJECT_BACKEND: "local",
    }
    approx_expected_tags = {
        MLFLOW_DOCKER_IMAGE_URI: "docker-example",
        MLFLOW_DOCKER_IMAGE_ID: "sha256:",
    }
    run_tags = run.data.tags
    for k, v in exact_expected_tags.items():
        assert run_tags[k] == v
    for k, v in approx_expected_tags.items():
        assert run_tags[k].startswith(v)
    artifacts = mlflow_service.list_artifacts(run_id=run_id)
    assert len(artifacts) == 1
    # The docker invocation is the third element of the spawned command line;
    # check user-supplied docker_args were forwarded.
    docker_cmd = submitted_run.command_proc.args[2]
    assert "--memory 1g" in docker_cmd
    assert "--privileged" in docker_cmd


def test_docker_project_execution_async_docker_args(
    docker_example_base_image,
):
    """Async runs must forward each docker arg as a separate --docker-args pair."""
    submitted_run = mlflow.projects.run(
        TEST_DOCKER_PROJECT_DIR,
        experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
        parameters={"use_start_run": "0"},
        entry_point="test_tracking",
        docker_args={"memory": "1g", "privileged": True},
        synchronous=False,
    )
    submitted_run.wait()

    args = submitted_run.command_proc.args
    # One --docker-args flag per entry in docker_args, each followed by its value.
    assert len([a for a in args if a == "--docker-args"]) == 2
    first_idx = args.index("--docker-args")
    second_idx = args.index("--docker-args", first_idx + 1)
    assert args[first_idx + 1] == "memory=1g"
    assert args[second_idx + 1] == "privileged"


@pytest.mark.parametrize(
    ("tracking_uri", "expected_command_segment"),
    [
        (None, "-e MLFLOW_TRACKING_URI=/mlflow/tmp/mlruns"),
        ("http://some-tracking-uri", "-e MLFLOW_TRACKING_URI=http://some-tracking-uri"),
        ("databricks://some-profile", "-e MLFLOW_TRACKING_URI=databricks "),
    ],
)
def test_docker_project_tracking_uri_propagation(
    tmp_path,
    tracking_uri,
    expected_command_segment,
    docker_example_base_image,
):
    """Check that the active tracking URI is propagated into the container env."""
    # NOTE(review): the body below is currently dead code behind this skip;
    # kept so the test can be revived if FileStore support returns.
    pytest.skip("FileStore is no longer supported.")
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        "host", "user", "pass", insecure=True
    )
    # Create and mock local tracking directory
    local_tracking_dir = os.path.join(tmp_path, "mlruns")
    if tracking_uri is None:
        tracking_uri = local_tracking_dir
    old_uri = mlflow.get_tracking_uri()
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
        ),
        mock.patch(
            "mlflow.tracking._tracking_service.utils._get_store",
            return_value=file_store.FileStore(local_tracking_dir),
        ),
    ):
        try:
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.projects.run(
                TEST_DOCKER_PROJECT_DIR,
                experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
            )
        finally:
            # Always restore the caller's tracking URI, even on failure.
            mlflow.set_tracking_uri(old_uri)


def test_docker_uri_mode_validation(docker_example_base_image):
    """Running a Docker project on the Databricks backend must be rejected."""
    with pytest.raises(ExecutionException, match="When running on Databricks"):
        mlflow.projects.run(TEST_DOCKER_PROJECT_DIR, backend="databricks", backend_config={})


def test_docker_image_uri_with_git():
    """Image URI is tagged with the first 7 chars of the git commit when available."""
    with mock.patch("mlflow.projects.docker.get_git_commit") as get_git_commit_mock:
        get_git_commit_mock.return_value = "1234567890"
        image_uri = _get_docker_image_uri("my_project", "my_workdir")
        assert image_uri == "my_project:1234567"
        get_git_commit_mock.assert_called_with("my_workdir")


def test_docker_image_uri_no_git():
    """Without a git commit the image URI is the bare repository name."""
    with mock.patch("mlflow.projects.docker.get_git_commit", return_value=None) as mock_commit:
        image_uri = _get_docker_image_uri("my_project", "my_workdir")
        assert image_uri == "my_project"
        mock_commit.assert_called_with("my_workdir")


def test_docker_valid_project_backend_local():
    """A well-formed docker project passes environment validation."""
    work_dir = "./examples/docker"
    project = _project_spec.load_project(work_dir)
    mlflow.projects.docker.validate_docker_env(project)


def test_docker_invalid_project_backend_local():
    """Validation fails when the MLproject file has no project name."""
    work_dir = "./examples/docker"
    project = _project_spec.load_project(work_dir)
    project.name = None
    with pytest.raises(ExecutionException, match="Project name in MLProject must be specified"):
        mlflow.projects.docker.validate_docker_env(project)


@pytest.mark.parametrize(
    ("artifact_uri", "host_artifact_uri", "container_artifact_uri", "should_mount"),
    [
        ("/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", True),
        ("s3://my_bucket", None, None, False),
        ("file:///tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", True),
        ("./mlruns", os.path.abspath("./mlruns"), "/mlflow/projects/code/mlruns", True),
    ],
)
def test_docker_mount_local_artifact_uri(
    artifact_uri, host_artifact_uri, container_artifact_uri, should_mount
):
    """Local artifact URIs get a -v bind mount; remote (s3) URIs do not."""
    active_run = mock.MagicMock()
    run_info = mock.MagicMock()
    run_info.run_id = "fake_run_id"
    run_info.experiment_id = "fake_experiment_id"
    run_info.artifact_uri = artifact_uri
    active_run.info = run_info
    image = mock.MagicMock()
    image.tags = ["image:tag"]

    docker_command = _get_docker_command(image, active_run)

    docker_volume_expected = f"-v {host_artifact_uri}:{container_artifact_uri}"
    assert (docker_volume_expected in " ".join(docker_command)) == should_mount


def test_docker_databricks_tracking_cmd_and_envs():
    """Databricks profile credentials are exported as container env vars."""
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        "host", "user", "pass", insecure=True
    )
    with mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
    ):
        cmds, envs = mlflow.projects.docker.get_docker_tracking_cmd_and_envs(
            "databricks://some-profile"
        )
        assert envs == {
            "DATABRICKS_HOST": "host",
            "DATABRICKS_USERNAME": "user",
            "DATABRICKS_PASSWORD": "pass",
            "DATABRICKS_INSECURE": "True",
            MLFLOW_TRACKING_URI.name: "databricks",
        }
        assert cmds == []


@pytest.mark.parametrize(
    ("volumes", "environment", "os_environ", "expected"),
    [
        ([], ["VAR1"], {"VAR1": "value1"}, [("-e", "VAR1=value1")]),
        ([], ["VAR1"], {}, ["should_crash", ("-e", "VAR1=value1")]),
        ([], ["VAR1"], {"OTHER_VAR": "value1"}, ["should_crash", ("-e", "VAR1=value1")]),
        (
            [],
            ["VAR1", ["VAR2", "value2"]],
            {"VAR1": "value1"},
            [("-e", "VAR1=value1"), ("-e", "VAR2=value2")],
        ),
        ([], [["VAR2", "value2"]], {"VAR1": "value1"}, [("-e", "VAR2=value2")]),
        (
            ["/path:/path"],
            ["VAR1"],
            {"VAR1": "value1"},
            [("-e", "VAR1=value1"), ("-v", "/path:/path")],
        ),
        (
            ["/path:/path"],
            [["VAR2", "value2"]],
            {"VAR1": "value1"},
            [("-e", "VAR2=value2"), ("-v", "/path:/path")],
        ),
    ],
)
def test_docker_user_specified_env_vars(volumes, environment, expected, os_environ, monkeypatch):
    """Project-declared env vars/volumes appear as -e/-v flag pairs; a declared
    env var missing from the host environment raises MlflowException (the
    ``"should_crash"`` sentinel in ``expected`` marks those cases)."""
    active_run = mock.MagicMock()
    run_info = mock.MagicMock()
    run_info.run_id = "fake_run_id"
    run_info.experiment_id = "fake_experiment_id"
    run_info.artifact_uri = "/tmp/mlruns/artifacts"
    active_run.info = run_info
    image = mock.MagicMock()
    image.tags = ["image:tag"]

    for name, value in os_environ.items():
        monkeypatch.setenv(name, value)
    # Membership test instead of list.remove(): avoids mutating the shared
    # parametrize list in place (the crash branch never reads `expected`).
    if "should_crash" in expected:
        with pytest.raises(MlflowException, match="This project expects"):
            _get_docker_command(image, active_run, None, volumes, environment)
    else:
        docker_command = _get_docker_command(image, active_run, None, volumes, environment)
        for flag, flag_value in expected:
            assert flag_value in docker_command
            # Each value must be immediately preceded by its flag (-e or -v).
            assert docker_command[docker_command.index(flag_value) - 1] == flag


@pytest.mark.parametrize("docker_args", [{}, {"ARG": "VAL"}, {"ARG1": "VAL1", "ARG2": "VAL2"}])
def test_docker_run_args(docker_args):
    """Arbitrary docker_args become --<flag> <value> pairs in the command."""
    active_run = mock.MagicMock()
    run_info = mock.MagicMock()
    run_info.run_id = "fake_run_id"
    run_info.experiment_id = "fake_experiment_id"
    run_info.artifact_uri = "/tmp/mlruns/artifacts"
    active_run.info = run_info
    image = mock.MagicMock()
    image.tags = ["image:tag"]

    docker_command = _get_docker_command(image, active_run, docker_args, None, None)

    for flag, value in docker_args.items():
        assert docker_command[docker_command.index(value) - 1] == f"--{flag}"


def test_docker_build_image_local(tmp_path):
    """A locally-built base image is used as-is and recorded in the run tags."""
    client = docker.from_env()
    dockerfile = tmp_path.joinpath("Dockerfile")
    dockerfile.write_text(
        """
FROM python:3.10
RUN pip --version
"""
    )
    client.images.build(path=str(tmp_path), dockerfile=str(dockerfile), tag="my-python:latest")
    tmp_path.joinpath("MLproject").write_text(
        """
name: test
docker_env:
  image: my-python
entry_points:
  main:
    command: python --version
"""
    )
    submitted_run = mlflow.projects.run(str(tmp_path))
    run = mlflow.get_run(submitted_run.run_id)
    assert run.data.tags[MLFLOW_DOCKER_IMAGE_URI] == "my-python"


def test_docker_build_image_remote(tmp_path):
    """A registry image referenced in MLproject is pulled and recorded in tags."""
    tmp_path.joinpath("MLproject").write_text(
        """
name: test
docker_env:
  image: python:3.9
entry_points:
  main:
    command: python --version
"""
    )
    submitted_run = mlflow.projects.run(str(tmp_path))
    run = mlflow.get_run(submitted_run.run_id)
    assert run.data.tags[MLFLOW_DOCKER_IMAGE_URI] == "python:3.9"