/ tests / projects / test_docker_projects.py
test_docker_projects.py
  1  import os
  2  from unittest import mock
  3  
  4  import docker
  5  import pytest
  6  
  7  import mlflow
  8  from mlflow import MlflowClient
  9  from mlflow.entities import ViewType
 10  from mlflow.environment_variables import MLFLOW_TRACKING_URI
 11  from mlflow.exceptions import MlflowException
 12  from mlflow.legacy_databricks_cli.configure.provider import DatabricksConfig
 13  from mlflow.projects import ExecutionException, _project_spec
 14  from mlflow.projects.backend.local import _get_docker_command
 15  from mlflow.projects.docker import _get_docker_image_uri
 16  from mlflow.store.tracking import file_store
 17  from mlflow.utils.mlflow_tags import (
 18      MLFLOW_DOCKER_IMAGE_ID,
 19      MLFLOW_DOCKER_IMAGE_URI,
 20      MLFLOW_PROJECT_BACKEND,
 21      MLFLOW_PROJECT_ENV,
 22  )
 23  
 24  from tests.projects.utils import (
 25      TEST_DOCKER_PROJECT_DIR,
 26      docker_example_base_image,  # noqa: F401
 27  )
 28  
 29  
 30  def _build_uri(base_uri, subdirectory):
 31      if subdirectory != "":
 32          return f"{base_uri}#{subdirectory}"
 33      return base_uri
 34  
 35  
 36  @pytest.mark.parametrize("use_start_run", map(str, [0, 1]))
 37  def test_docker_project_execution(use_start_run, docker_example_base_image):
 38      expected_params = {"use_start_run": use_start_run}
 39      submitted_run = mlflow.projects.run(
 40          TEST_DOCKER_PROJECT_DIR,
 41          experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
 42          parameters=expected_params,
 43          entry_point="test_tracking",
 44          build_image=True,
 45          docker_args={"memory": "1g", "privileged": True},
 46      )
 47      # Validate run contents in the FileStore
 48      run_id = submitted_run.run_id
 49      mlflow_service = MlflowClient()
 50      runs = mlflow_service.search_runs(
 51          [file_store.FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
 52      )
 53      assert len(runs) == 1
 54      store_run_id = runs[0].info.run_id
 55      assert run_id == store_run_id
 56      run = mlflow_service.get_run(run_id)
 57      assert run.data.params == expected_params
 58      assert run.data.metrics == {"some_key": 3}
 59      exact_expected_tags = {
 60          MLFLOW_PROJECT_ENV: "docker",
 61          MLFLOW_PROJECT_BACKEND: "local",
 62      }
 63      approx_expected_tags = {
 64          MLFLOW_DOCKER_IMAGE_URI: "docker-example",
 65          MLFLOW_DOCKER_IMAGE_ID: "sha256:",
 66      }
 67      run_tags = run.data.tags
 68      for k, v in exact_expected_tags.items():
 69          assert run_tags[k] == v
 70      for k, v in approx_expected_tags.items():
 71          assert run_tags[k].startswith(v)
 72      artifacts = mlflow_service.list_artifacts(run_id=run_id)
 73      assert len(artifacts) == 1
 74      docker_cmd = submitted_run.command_proc.args[2]
 75      assert "--memory 1g" in docker_cmd
 76      assert "--privileged" in docker_cmd
 77  
 78  
 79  def test_docker_project_execution_async_docker_args(
 80      docker_example_base_image,
 81  ):
 82      submitted_run = mlflow.projects.run(
 83          TEST_DOCKER_PROJECT_DIR,
 84          experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
 85          parameters={"use_start_run": "0"},
 86          entry_point="test_tracking",
 87          docker_args={"memory": "1g", "privileged": True},
 88          synchronous=False,
 89      )
 90      submitted_run.wait()
 91  
 92      args = submitted_run.command_proc.args
 93      assert len([a for a in args if a == "--docker-args"]) == 2
 94      first_idx = args.index("--docker-args")
 95      second_idx = args.index("--docker-args", first_idx + 1)
 96      assert args[first_idx + 1] == "memory=1g"
 97      assert args[second_idx + 1] == "privileged"
 98  
 99  
@pytest.mark.skip(reason="FileStore is no longer supported.")
@pytest.mark.parametrize(
    ("tracking_uri", "expected_command_segment"),
    [
        (None, "-e MLFLOW_TRACKING_URI=/mlflow/tmp/mlruns"),
        ("http://some-tracking-uri", "-e MLFLOW_TRACKING_URI=http://some-tracking-uri"),
        ("databricks://some-profile", "-e MLFLOW_TRACKING_URI=databricks "),
    ],
)
def test_docker_project_tracking_uri_propagation(
    tmp_path,
    tracking_uri,
    expected_command_segment,
    docker_example_base_image,
):
    """Run the docker example project under several tracking URIs.

    Currently skipped because it depends on a FileStore-backed tracking store.
    The skip is a marker rather than a ``pytest.skip()`` call in the body, so
    pytest does not set up the ``docker_example_base_image`` fixture (which
    presumably builds a docker image) for a test that never runs, and the body
    below is not unreachable code.

    NOTE(review): ``expected_command_segment`` is never asserted. If this test
    is re-enabled, it presumably should be checked against the generated
    docker command; confirm intent.
    """
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        "host", "user", "pass", insecure=True
    )
    # Without an explicit tracking URI, fall back to a local directory under tmp_path.
    local_tracking_dir = os.path.join(tmp_path, "mlruns")
    if tracking_uri is None:
        tracking_uri = local_tracking_dir
    old_uri = mlflow.get_tracking_uri()
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
        ),
        mock.patch(
            "mlflow.tracking._tracking_service.utils._get_store",
            return_value=file_store.FileStore(local_tracking_dir),
        ),
    ):
        try:
            mlflow.set_tracking_uri(tracking_uri)
            mlflow.projects.run(
                TEST_DOCKER_PROJECT_DIR,
                experiment_id=file_store.FileStore.DEFAULT_EXPERIMENT_ID,
            )
        finally:
            # Restore the global tracking URI so later tests are unaffected.
            mlflow.set_tracking_uri(old_uri)
141  
142  
def test_docker_uri_mode_validation(docker_example_base_image):
    """Submitting the docker example project to the Databricks backend must raise."""
    expect_rejection = pytest.raises(ExecutionException, match="When running on Databricks")
    with expect_rejection:
        mlflow.projects.run(
            TEST_DOCKER_PROJECT_DIR,
            backend="databricks",
            backend_config={},
        )
146  
147  
def test_docker_image_uri_with_git():
    """When a git commit is available, the image tag is its first 7 characters."""
    commit_patch = mock.patch(
        "mlflow.projects.docker.get_git_commit", return_value="1234567890"
    )
    with commit_patch as fake_get_git_commit:
        uri = _get_docker_image_uri("my_project", "my_workdir")
        assert uri == "my_project:1234567"
        fake_get_git_commit.assert_called_with("my_workdir")
154  
155  
def test_docker_image_uri_no_git():
    """Without a git commit, the image URI is just the project name, with no tag."""
    with mock.patch("mlflow.projects.docker.get_git_commit") as fake_get_git_commit:
        fake_get_git_commit.return_value = None
        uri = _get_docker_image_uri("my_project", "my_workdir")
        assert uri == "my_project"
        fake_get_git_commit.assert_called_with("my_workdir")
161  
162  
def test_docker_valid_project_backend_local():
    """The bundled docker example project passes docker-env validation.

    The path is relative to the current working directory (presumably the
    repository root when tests run).
    """
    example_project = _project_spec.load_project("./examples/docker")
    mlflow.projects.docker.validate_docker_env(example_project)
167  
168  
def test_docker_invalid_project_backend_local():
    """Docker-env validation rejects a project whose name has been cleared."""
    example_project = _project_spec.load_project("./examples/docker")
    # Blank out the name to trigger the validation error.
    example_project.name = None
    expected_error = "Project name in MLProject must be specified"
    with pytest.raises(ExecutionException, match=expected_error):
        mlflow.projects.docker.validate_docker_env(example_project)
175  
176  
@pytest.mark.parametrize(
    ("artifact_uri", "host_artifact_uri", "container_artifact_uri", "should_mount"),
    [
        ("/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", True),
        ("s3://my_bucket", None, None, False),
        ("file:///tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", "/tmp/mlruns/artifacts", True),
        ("./mlruns", os.path.abspath("./mlruns"), "/mlflow/projects/code/mlruns", True),
    ],
)
def test_docker_mount_local_artifact_uri(
    artifact_uri, host_artifact_uri, container_artifact_uri, should_mount
):
    """Local artifact URIs get a ``-v host:container`` mount; remote ones do not."""
    run_info = mock.MagicMock(
        run_id="fake_run_id",
        experiment_id="fake_experiment_id",
        artifact_uri=artifact_uri,
    )
    active_run = mock.MagicMock(info=run_info)
    image = mock.MagicMock(tags=["image:tag"])

    command_line = " ".join(_get_docker_command(image, active_run))

    # For the remote case this is "-v None:None", which must simply be absent.
    mount_flag = f"-v {host_artifact_uri}:{container_artifact_uri}"
    assert (mount_flag in command_line) == should_mount
202  
203  
def test_docker_databricks_tracking_cmd_and_envs():
    """A ``databricks://`` tracking URI becomes Databricks env vars and no extra docker args."""
    provider = mock.MagicMock()
    provider.get_config.return_value = DatabricksConfig.from_password(
        "host", "user", "pass", insecure=True
    )
    expected_envs = {
        "DATABRICKS_HOST": "host",
        "DATABRICKS_USERNAME": "user",
        "DATABRICKS_PASSWORD": "pass",
        "DATABRICKS_INSECURE": "True",
        # The profile suffix is stripped; the container sees plain "databricks".
        MLFLOW_TRACKING_URI.name: "databricks",
    }
    provider_patch = mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=provider
    )
    with provider_patch:
        docker_cmds, docker_envs = mlflow.projects.docker.get_docker_tracking_cmd_and_envs(
            "databricks://some-profile"
        )
        assert docker_envs == expected_envs
        assert docker_cmds == []
223  
224  
@pytest.mark.parametrize(
    ("volumes", "environment", "os_environ", "expected"),
    [
        ([], ["VAR1"], {"VAR1": "value1"}, [("-e", "VAR1=value1")]),
        ([], ["VAR1"], {}, ["should_crash", ("-e", "VAR1=value1")]),
        ([], ["VAR1"], {"OTHER_VAR": "value1"}, ["should_crash", ("-e", "VAR1=value1")]),
        (
            [],
            ["VAR1", ["VAR2", "value2"]],
            {"VAR1": "value1"},
            [("-e", "VAR1=value1"), ("-e", "VAR2=value2")],
        ),
        ([], [["VAR2", "value2"]], {"VAR1": "value1"}, [("-e", "VAR2=value2")]),
        (
            ["/path:/path"],
            ["VAR1"],
            {"VAR1": "value1"},
            [("-e", "VAR1=value1"), ("-v", "/path:/path")],
        ),
        (
            ["/path:/path"],
            [["VAR2", "value2"]],
            {"VAR1": "value1"},
            [("-e", "VAR2=value2"), ("-v", "/path:/path")],
        ),
    ],
)
def test_docker_user_specified_env_vars(volumes, environment, expected, os_environ, monkeypatch):
    """User-declared env vars and volumes become ``-e``/``-v`` docker flags.

    ``environment`` entries are either a bare name (copied from the host
    environment) or a ``[name, value]`` pair. An ``expected`` list containing
    the ``"should_crash"`` sentinel marks cases where a required host variable
    is missing, and ``_get_docker_command`` must raise ``MlflowException``.
    Otherwise ``expected`` is a list of ``(flag, value)`` pairs that must
    appear adjacent in the command.
    """
    active_run = mock.MagicMock()
    run_info = mock.MagicMock()
    run_info.run_id = "fake_run_id"
    run_info.experiment_id = "fake_experiment_id"
    run_info.artifact_uri = "/tmp/mlruns/artifacts"
    active_run.info = run_info
    image = mock.MagicMock()
    image.tags = ["image:tag"]

    for name, value in os_environ.items():
        monkeypatch.setenv(name, value)

    # Check the sentinel without mutating ``expected``. The list object comes
    # straight from the parametrize decorator, so removing from it in place
    # would leak into any rerun of the same case and flip it into the
    # non-crashing branch.
    if "should_crash" in expected:
        with pytest.raises(MlflowException, match="This project expects"):
            _get_docker_command(image, active_run, None, volumes, environment)
        return

    docker_command = _get_docker_command(image, active_run, None, volumes, environment)
    for flag, flag_value in expected:
        assert flag_value in docker_command
        # The value must directly follow its flag (e.g. "-e", "VAR1=value1").
        assert docker_command[docker_command.index(flag_value) - 1] == flag
273  
274  
@pytest.mark.parametrize("docker_args", [{}, {"ARG": "VAL"}, {"ARG1": "VAL1", "ARG2": "VAL2"}])
def test_docker_run_args(docker_args):
    """Each ``docker_args`` entry appears as ``--<key> <value>`` in the docker command."""
    run_info = mock.MagicMock(
        run_id="fake_run_id",
        experiment_id="fake_experiment_id",
        artifact_uri="/tmp/mlruns/artifacts",
    )
    active_run = mock.MagicMock(info=run_info)
    image = mock.MagicMock(tags=["image:tag"])

    command = _get_docker_command(image, active_run, docker_args, None, None)

    for key in docker_args:
        value_position = command.index(docker_args[key])
        assert command[value_position - 1] == f"--{key}"
290  
291  
def test_docker_build_image_local(tmp_path):
    """A locally built image named in ``docker_env.image`` is used as-is.

    Builds ``my-python:latest`` from a throwaway Dockerfile, then runs a
    project that references ``my-python`` and checks the recorded image URI.
    """
    docker_client = docker.from_env()
    dockerfile_path = tmp_path / "Dockerfile"
    dockerfile_path.write_text(
        """
FROM python:3.10
RUN pip --version
"""
    )
    docker_client.images.build(
        path=str(tmp_path),
        dockerfile=str(dockerfile_path),
        tag="my-python:latest",
    )

    mlproject_path = tmp_path / "MLproject"
    mlproject_path.write_text(
        """
name: test
docker_env:
  image: my-python
entry_points:
  main:
    command: python --version
"""
    )
    run_id = mlflow.projects.run(str(tmp_path)).run_id
    assert mlflow.get_run(run_id).data.tags[MLFLOW_DOCKER_IMAGE_URI] == "my-python"
315  
316  
def test_docker_build_image_remote(tmp_path):
    """A registry image (``python:3.9``) named in ``docker_env.image`` is recorded verbatim."""
    mlproject_path = tmp_path / "MLproject"
    mlproject_path.write_text(
        """
name: test
docker_env:
  image: python:3.9
entry_points:
  main:
    command: python --version
"""
    )
    run_id = mlflow.projects.run(str(tmp_path)).run_id
    recorded_tags = mlflow.get_run(run_id).data.tags
    assert recorded_tags[MLFLOW_DOCKER_IMAGE_URI] == "python:3.9"