# tests/projects/test_projects.py
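"""Tests for mlflow.projects: running projects locally and from local git repos,
conda/mamba environment resolution, Kubernetes backend config parsing, and
Databricks credential propagation."""
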
import json
import os
import shutil
import subprocess
import uuid
from unittest import mock

import git
import pytest
import yaml

import mlflow
from mlflow import MlflowClient
from mlflow.entities import RunStatus, SourceType, ViewType
from mlflow.environment_variables import MLFLOW_CONDA_CREATE_ENV_CMD, MLFLOW_CONDA_HOME
from mlflow.exceptions import ExecutionException, MlflowException
from mlflow.projects import _parse_kubernetes_config, _resolve_experiment_id
from mlflow.store.tracking.file_store import FileStore
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.conda import CONDA_EXE, get_or_create_conda_env
from mlflow.utils.mlflow_tags import (
    MLFLOW_GIT_BRANCH,
    MLFLOW_GIT_REPO_URL,
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_PROJECT_BACKEND,
    MLFLOW_PROJECT_ENTRY_POINT,
    MLFLOW_PROJECT_ENV,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.process import ShellCommandException

from tests.projects.utils import TEST_PROJECT_DIR, TEST_PROJECT_NAME, validate_exit_status

MOCK_USER = "janebloggs"


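# Patch user resolution so tests can assert a deterministic MLFLOW_USER tag.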
@pytest.fixture
def patch_user():
    with mock.patch("mlflow.projects.utils._get_user", return_value=MOCK_USER):
        yield


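# Return the HEAD commit hash of a local git repo, used to run a project at a
# specific commit.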
def _get_version_local_git_repo(local_git_repo):
    repo = git.Repo(local_git_repo, search_parent_directories=True)
    return repo.git.rev_parse("HEAD")


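# Clean up any mlruns directory created inside the test project once the
# module's tests have finished.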
@pytest.fixture(scope="module", autouse=True)
def clean_mlruns_dir():
    yield
    dir_path = os.path.join(TEST_PROJECT_DIR, "mlruns")
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)


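# Resolution semantics exercised below: an explicit id is returned as a string,
# a name that does not yet exist is expected to be created (hence
# "add an experiment" resolves to the newly assigned id "1"), and with neither
# argument the default experiment "0" is used.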
@pytest.mark.parametrize(
    ("experiment_name", "experiment_id", "expected"),
    [
        ("Default", None, "0"),
        ("add an experiment", None, "1"),
        (None, 2, "2"),
        (None, "2", "2"),
        (None, None, "0"),
    ],
)
def test_resolve_experiment_id(experiment_name, experiment_id, expected):
    assert expected == _resolve_experiment_id(
        experiment_name=experiment_name, experiment_id=experiment_id
    )


def test_resolve_experiment_id_should_not_allow_both_name_and_id_in_use():
    with pytest.raises(
        MlflowException, match="Specify only one of 'experiment_name' or 'experiment_id'."
    ):
        _resolve_experiment_id(experiment_name="experiment_named", experiment_id="44")


def test_invalid_run_mode():
    with pytest.raises(
        ExecutionException, match="Got unsupported execution mode some unsupported mode"
    ):
        mlflow.projects.run(uri=TEST_PROJECT_DIR, backend="some unsupported mode")


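# The conda run may fail on machines without conda; the try/finally ensures we
# still assert that the backend and environment tags were set before any failure.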
def test_expected_tags_logged_when_using_conda():
    with mock.patch.object(MlflowClient, "set_tag") as tag_mock:
        try:
            mlflow.projects.run(TEST_PROJECT_DIR, env_manager="conda")
        finally:
            tag_mock.assert_has_calls(
                [
                    mock.call(mock.ANY, MLFLOW_PROJECT_BACKEND, "local"),
                    mock.call(mock.ANY, MLFLOW_PROJECT_ENV, "conda"),
                ],
                any_order=True,
            )


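# Run a project from a local git repo, addressed either by filesystem path or by
# git URI with a "#subdirectory" suffix, at an optional version (branch name or
# commit hash), and validate the resulting run's status, params, metrics, and tags.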
@pytest.mark.usefixtures("patch_user")
@pytest.mark.parametrize("use_start_run", ["0", "1"])
@pytest.mark.parametrize("version", [None, "master", "git-commit"])
def test_run_local_git_repo(
    local_git_repo, local_git_repo_uri, use_start_run, version, monkeypatch
):
    monkeypatch.setenv("DATABRICKS_HOST", "my-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-token")
    if version is not None:
        uri = local_git_repo_uri + "#" + TEST_PROJECT_NAME
    else:
        uri = os.path.join(local_git_repo, TEST_PROJECT_NAME)
    if version == "git-commit":
        version = _get_version_local_git_repo(local_git_repo)
    submitted_run = mlflow.projects.run(
        uri,
        entry_point="test_tracking",
        version=version,
        parameters={"use_start_run": use_start_run},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
    )

    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    mlflow_service = MlflowClient()
    runs = mlflow_service.search_runs(
        [FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
    )
    assert len(runs) == 1
    store_run_id = runs[0].info.run_id
    assert run_id == store_run_id
    run = mlflow_service.get_run(run_id)

    assert run.info.status == RunStatus.to_string(RunStatus.FINISHED)

    assert run.data.params == {
        "use_start_run": use_start_run,
    }
    assert run.data.metrics == {"some_key": 3}

    tags = run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"
    assert tags[MLFLOW_PROJECT_BACKEND] == "local"

    if version == "master":
        assert tags[MLFLOW_GIT_BRANCH] == "master"
        assert tags[MLFLOW_GIT_REPO_URL] == local_git_repo_uri


def test_invalid_version_local_git_repo(local_git_repo_uri):
    # Run the project with an invalid commit hash
    with pytest.raises(ExecutionException, match="Unable to checkout version 'badc0de'"):
        mlflow.projects.run(
            local_git_repo_uri + "#" + TEST_PROJECT_NAME,
            entry_point="test_tracking",
            version="badc0de",
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )


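# Same run validation as the git-based test above, but running the project
# directly from the local test directory.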
@pytest.mark.parametrize("use_start_run", ["0", "1"])
@pytest.mark.usefixtures("patch_user")
def test_run(use_start_run):
    submitted_run = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="test_tracking",
        parameters={"use_start_run": use_start_run},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
    )
    assert submitted_run.run_id is not None
    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    mlflow_service = MlflowClient()

    runs = mlflow_service.search_runs(
        [FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
    )
    assert len(runs) == 1
    store_run_id = runs[0].info.run_id
    assert run_id == store_run_id
    run = mlflow_service.get_run(run_id)

    assert run.info.status == RunStatus.to_string(RunStatus.FINISHED)

    assert run.data.params == {
        "use_start_run": use_start_run,
    }
    assert run.data.metrics == {"some_key": 3}

    tags = run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"


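# A project launched inside an active run should be tagged with the parent run id.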
def test_run_with_parent():
    with mlflow.start_run():
        parent_run_id = mlflow.active_run().info.run_id
        submitted_run = mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="test_tracking",
            parameters={"use_start_run": "1"},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )
    assert submitted_run.run_id is not None
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    run_id = submitted_run.run_id
    run = MlflowClient().get_run(run_id)
    assert run.data.tags[MLFLOW_PARENT_RUN_ID] == parent_run_id


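# A "runs:/<run_id>/<path>" URI passed as a project parameter is expected to be
# resolved to the logged artifact before the entry point runs.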
def test_run_with_artifact_path(tmp_path):
    artifact_file = tmp_path.joinpath("model.pkl")
    artifact_file.write_text("Hello world")
    with mlflow.start_run() as run:
        mlflow.log_artifact(artifact_file)
        submitted_run = mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="test_artifact_path",
            parameters={"model": f"runs:/{run.info.run_id}/model.pkl"},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )
        validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)


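# With synchronous=False, mlflow.projects.run returns immediately with a RUNNING
# run and wait() blocks until completion; a run given an invalid parameter
# should end up FAILED.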
def test_run_async():
    submitted_run0 = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="sleep",
        parameters={"duration": 2},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        synchronous=False,
    )
    validate_exit_status(submitted_run0.get_status(), RunStatus.RUNNING)
    submitted_run0.wait()
    validate_exit_status(submitted_run0.get_status(), RunStatus.FINISHED)
    submitted_run1 = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="sleep",
        parameters={"duration": -1, "invalid-param": 30},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        synchronous=False,
    )
    submitted_run1.wait()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FAILED)


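# Conda executable resolution: CONDA_EXE points directly at the conda binary
# (with "activate" assumed alongside it), while MLFLOW_CONDA_HOME points at an
# install prefix whose bin/ directory contains both executables.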
@pytest.mark.parametrize(
    ("mock_env", "expected_conda", "expected_activate"),
    [
        (
            {CONDA_EXE: "/abc/conda"},
            "/abc/conda",
            "/abc/activate",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/"},
            "/some/dir/bin/conda",
            "/some/dir/bin/activate",
        ),
    ],
)
def test_conda_path(mock_env, expected_conda, expected_activate, monkeypatch):
    for name in [CONDA_EXE, MLFLOW_CONDA_HOME.name]:
        monkeypatch.delenv(name, raising=False)
    for name, value in mock_env.items():
        monkeypatch.setenv(name, value)
    assert mlflow.utils.conda.get_conda_bin_executable("conda") == expected_conda
    assert mlflow.utils.conda.get_conda_bin_executable("activate") == expected_activate


@pytest.mark.parametrize(
    ("mock_env", "expected_conda_env_create_path"),
    [
        (
            {CONDA_EXE: "/abc/conda"},
            "/abc/conda",
        ),
        (
            {CONDA_EXE: "/abc/conda", MLFLOW_CONDA_CREATE_ENV_CMD.name: "mamba"},
            "/abc/mamba",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/"},
            "/some/dir/bin/conda",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/", MLFLOW_CONDA_CREATE_ENV_CMD.name: "mamba"},
            "/some/dir/bin/mamba",
        ),
    ],
)
def test_find_conda_executables(mock_env, expected_conda_env_create_path, monkeypatch):
    """
    Verify that we correctly determine the path of the executable used to create
    environments (for example, mamba instead of conda when
    MLFLOW_CONDA_CREATE_ENV_CMD is set).
    """
    monkeypatch.delenv(CONDA_EXE, raising=False)
    monkeypatch.delenv(MLFLOW_CONDA_HOME.name, raising=False)
    monkeypatch.delenv(MLFLOW_CONDA_CREATE_ENV_CMD.name, raising=False)
    for name, value in mock_env.items():
        monkeypatch.setenv(name, value)
    conda_env_create_path = mlflow.utils.conda._get_conda_executable_for_create_env()
    assert conda_env_create_path == expected_conda_env_create_path


def test_create_env_with_mamba(monkeypatch):
    """
    Test that mamba is called when MLFLOW_CONDA_CREATE_ENV_CMD is set to "mamba",
    and that we fail cleanly when mamba is unavailable or broken. The calls are
    mocked so we never actually execute mamba (which is not installed in the
    test environment anyway).
    """

    def exec_cmd_mock(cmd, *args, **kwargs):
        if cmd[-1] == "--json":
            # We are supposed to list environments in JSON format
            return subprocess.CompletedProcess(
                cmd, 0, json.dumps({"envs": ["mlflow-mock-environment"]}), None
            )
        else:
            # Here we are creating the environment; nothing needs to be returned
            return subprocess.CompletedProcess(cmd, 0)

    def exec_cmd_mock_raise(cmd, *args, **kwargs):
        if os.path.basename(cmd[0]) == "mamba":
            raise OSError()

    conda_env_path = os.path.join(TEST_PROJECT_DIR, "conda.yaml")

    monkeypatch.setenv(MLFLOW_CONDA_CREATE_ENV_CMD.name, "mamba")
    # Simulate success
    with mock.patch("mlflow.utils.process._exec_cmd", side_effect=exec_cmd_mock):
        mlflow.utils.conda.get_or_create_conda_env(conda_env_path)

    # Simulate a non-working or non-existent mamba
    with mock.patch("mlflow.utils.process._exec_cmd", side_effect=exec_cmd_mock_raise):
        with pytest.raises(
            ExecutionException,
            match="You have set the env variable MLFLOW_CONDA_CREATE_ENV_CMD",
        ):
            mlflow.utils.conda.get_or_create_conda_env(conda_env_path)


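# If pip installation fails while creating a conda environment, the partially
# created environment should be removed rather than left behind.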
def test_conda_environment_cleaned_up_when_pip_fails(tmp_path):
    conda_yaml = tmp_path / "conda.yaml"
    content = f"""
name: {uuid.uuid4().hex}
channels:
  - conda-forge
dependencies:
  - python={PYTHON_VERSION}
  - pip
  - pip:
      - mlflow==999.999.999
"""
    conda_yaml.write_text(content)
    envs_before = mlflow.utils.conda._list_conda_environments()

    # Environment creation should fail because the pip dependency
    # mlflow==999.999.999 doesn't exist
    with pytest.raises(ShellCommandException, match=r"No matching distribution found"):
        mlflow.utils.conda.get_or_create_conda_env(conda_yaml, capture_output=True)

    # Ensure the environment is cleaned up
    envs_after = mlflow.utils.conda._list_conda_environments()
    assert envs_before == envs_after


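# cancel() on a running asynchronous run should mark it FAILED; cancelling a run
# that has already finished leaves its status untouched.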
def test_cancel_run():
    submitted_run0, submitted_run1 = (
        mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="sleep",
            parameters={"duration": 2},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
            synchronous=False,
        )
        for _ in range(2)
    )
    submitted_run0.cancel()
    validate_exit_status(submitted_run0.get_status(), RunStatus.FAILED)
    # Sanity check: cancelling one run has no effect on the other
    assert submitted_run1.wait()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FINISHED)
    # Try cancelling after calling wait()
    submitted_run1.cancel()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FINISHED)


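# _parse_kubernetes_config should pass through the kube context, job template
# path, and repository URI, and load the job template YAML into the config.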
def test_parse_kubernetes_config():
    work_dir = "./examples/docker"
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "kube-job-template-path": os.path.join(work_dir, "kubernetes_job_template.yaml"),
        "repository-uri": "dockerhub_account/mlflow-kubernetes-example",
    }
    with open(kubernetes_config["kube-job-template-path"]) as job_template:
        yaml_obj = yaml.safe_load(job_template.read())
    kube_config = _parse_kubernetes_config(kubernetes_config)
    assert kube_config["kube-context"] == kubernetes_config["kube-context"]
    assert kube_config["kube-job-template-path"] == kubernetes_config["kube-job-template-path"]
    assert kube_config["repository-uri"] == kubernetes_config["repository-uri"]
    assert kube_config["kube-job-template"] == yaml_obj


@pytest.fixture
def mock_kubernetes_job_template(tmp_path):
    k8s_yaml = tmp_path.joinpath("kubernetes_job_template.yaml")
    k8s_yaml.write_text(
        """
apiVersion: batch/v1
kind: Job
metadata:
  name: "{replaced with MLflow Project name}"
  namespace: mlflow
spec:
  ttlSecondsAfterFinished: 100
  backoffLimit: 0
  template:
    spec:
      containers:
      - name: "{replaced with MLflow Project name}"
        image: "{replaced with URI of Docker image created during Project execution}"
        command: ["{replaced with MLflow Project entry point command}"]
        resources:
          limits:
            memory: 512Mi
          requests:
            memory: 256Mi
      restartPolicy: Never
""".lstrip()
    )
    return str(k8s_yaml)


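# Equality matcher for mock assertions: compares equal to any string that starts
# with the given prefix.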
class StartsWithMatcher:
    def __init__(self, prefix):
        self.prefix = prefix

    def __eq__(self, other):
        return isinstance(other, str) and other.startswith(self.prefix)


def test_parse_kubernetes_config_without_context(mock_kubernetes_job_template):
    with mock.patch("mlflow.projects._logger.debug") as mock_debug:
        kubernetes_config = {
            "repository-uri": "dockerhub_account/mlflow-kubernetes-example",
            "kube-job-template-path": mock_kubernetes_job_template,
        }
        _parse_kubernetes_config(kubernetes_config)
        mock_debug.assert_called_once_with(
            StartsWithMatcher("Could not find kube-context in backend_config")
        )


def test_parse_kubernetes_config_without_image_uri(mock_kubernetes_job_template):
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "kube-job-template-path": mock_kubernetes_job_template,
    }
    with pytest.raises(ExecutionException, match="Could not find 'repository-uri'"):
        _parse_kubernetes_config(kubernetes_config)


def test_parse_kubernetes_config_invalid_template_job_file():
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "repository-uri": "username/mlflow-kubernetes-example",
        "kube-job-template-path": "file_not_found.yaml",
    }
    with pytest.raises(ExecutionException, match="Could not find 'kube-job-template-path'"):
        _parse_kubernetes_config(kubernetes_config)


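# When the tracking URI is a Databricks URI, DATABRICKS_HOST and DATABRICKS_TOKEN
# should be propagated into the environment of the subprocess running the project.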
@pytest.mark.parametrize("synchronous", [True, False])
def test_credential_propagation(synchronous, monkeypatch):
    class DummyProcess:
        def wait(self):
            return 0

        def poll(self):
            return 0

        def communicate(self, _):
            return "", ""

    monkeypatch.setenv("DATABRICKS_HOST", "host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "mytoken")
    with (
        mock.patch("subprocess.Popen", return_value=DummyProcess()) as popen_mock,
        mock.patch("mlflow.utils.uri.is_databricks_uri", return_value=True),
    ):
        mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="sleep",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
            parameters={"duration": 2},
            env_manager="local",
            synchronous=synchronous,
        )
        _, kwargs = popen_mock.call_args
        env = kwargs["env"]
        assert env["DATABRICKS_HOST"] == "host"
        assert env["DATABRICKS_TOKEN"] == "mytoken"


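# With capture_output=True, the underlying pip error message should surface in
# the ShellCommandException raised by get_or_create_conda_env.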
def test_get_or_create_conda_env_capture_output_mode(tmp_path):
    conda_yaml_file = tmp_path / "conda.yaml"
    conda_yaml_file.write_text(
        """
channels:
- conda-forge
dependencies:
- pip:
  - scikit-learn==99.99.99
"""
    )
    with pytest.raises(
        ShellCommandException,
        match="Could not find a version that satisfies the requirement scikit-learn==99.99.99",
    ):
        get_or_create_conda_env(str(conda_yaml_file), capture_output=True)