/ tests / projects / test_projects_cli.py
test_projects_cli.py
  1  import hashlib
  2  import json
  3  import logging
  4  import os
  5  import shutil
  6  from pathlib import Path
  7  from unittest import mock
  8  
  9  import pytest
 10  from click.testing import CliRunner
 11  
 12  from mlflow import MlflowClient, cli
 13  from mlflow.utils import process
 14  from mlflow.utils.environment import _PythonEnv
 15  from mlflow.utils.virtualenv import _get_mlflow_virtualenv_root, _get_virtualenv_name
 16  
 17  from tests.integration.utils import invoke_cli_runner
 18  from tests.projects.utils import (
 19      GIT_PROJECT_URI,
 20      SSH_PROJECT_URI,
 21      TEST_DOCKER_PROJECT_DIR,
 22      TEST_PROJECT_DIR,
 23      TEST_VIRTUALENV_PROJECT_DIR,
 24      docker_example_base_image,  # noqa: F401
 25  )
 26  
 27  _logger = logging.getLogger(__name__)
 28  
 29  skip_if_skinny = pytest.mark.skipif(
 30      "MLFLOW_SKINNY" in os.environ,
 31      reason="MLflow skinny does not have dependencies to run this test",
 32  )
 33  
 34  
 35  @pytest.mark.parametrize("name", ["friend", "friend=you", "='friend'"])
 36  def test_run_local_params(name):
 37      excitement_arg = 2
 38      invoke_cli_runner(
 39          cli.run,
 40          [
 41              TEST_PROJECT_DIR,
 42              "-e",
 43              "greeter",
 44              "-P",
 45              "greeting=hi",
 46              "-P",
 47              f"name={name}",
 48              "-P",
 49              f"excitement={excitement_arg}",
 50          ],
 51      )
 52  
 53  
 54  @skip_if_skinny
 55  def test_run_local_with_docker_args(docker_example_base_image):
 56      # Verify that Docker project execution is successful when Docker flag and string
 57      # commandline arguments are supplied (`tty` and `name`, respectively)
 58      invoke_cli_runner(cli.run, [TEST_DOCKER_PROJECT_DIR, "-A", "tty", "-A", "name=mycontainer"])
 59  
 60  
 61  @pytest.mark.parametrize("experiment_name", [b"test-experiment".decode("utf-8"), "test-experiment"])
 62  def test_run_local_experiment_specification(experiment_name):
 63      invoke_cli_runner(
 64          cli.run,
 65          [
 66              TEST_PROJECT_DIR,
 67              "-e",
 68              "greeter",
 69              "-P",
 70              "name=test",
 71              "--experiment-name",
 72              experiment_name,
 73          ],
 74      )
 75  
 76      client = MlflowClient()
 77      experiment_id = client.get_experiment_by_name(experiment_name).experiment_id
 78  
 79      invoke_cli_runner(
 80          cli.run,
 81          [TEST_PROJECT_DIR, "-e", "greeter", "-P", "name=test", "--experiment-id", experiment_id],
 82      )
 83  
 84  
 85  @pytest.fixture(scope="module", autouse=True)
 86  def clean_mlruns_dir():
 87      yield
 88      dir_path = os.path.join(TEST_PROJECT_DIR, "mlruns")
 89      if os.path.exists(dir_path):
 90          shutil.rmtree(dir_path)
 91  
 92  
 93  @skip_if_skinny
 94  def test_run_local_conda_env():
 95      with open(os.path.join(TEST_PROJECT_DIR, "conda.yaml")) as handle:
 96          conda_env_contents = handle.read()
 97      expected_env_name = "mlflow-{}".format(
 98          hashlib.sha1(conda_env_contents.encode("utf-8"), usedforsecurity=False).hexdigest()
 99      )
100      try:
101          process._exec_cmd(cmd=["conda", "env", "remove", "--name", expected_env_name])
102      except process.ShellCommandException:
103          _logger.error(
104              "Unable to remove conda environment %s. The environment may not have been present, "
105              "continuing with running the test.",
106              expected_env_name,
107          )
108      invoke_cli_runner(
109          cli.run,
110          [TEST_PROJECT_DIR, "-e", "check_conda_env", "-P", f"conda_env_name={expected_env_name}"],
111      )
112  
113  
@skip_if_skinny
def test_run_uv_python_env():
    """Run the virtualenv test project's ``test`` entry point with ``--env-manager uv``.

    Any existing environment directory for this project is deleted beforehand.
    """
    project_dir = Path(TEST_VIRTUALENV_PROJECT_DIR)
    python_env = _PythonEnv.from_yaml(os.path.join(TEST_VIRTUALENV_PROJECT_DIR, "python_env.yaml"))

    # Directory in which the environment for this project/env spec is expected to live.
    env_dir = Path(_get_mlflow_virtualenv_root()) / _get_virtualenv_name(python_env, project_dir)
    if env_dir.exists():
        shutil.rmtree(env_dir)

    run_args = [TEST_VIRTUALENV_PROJECT_DIR, "-e", "test", "--env-manager", "uv"]
    invoke_cli_runner(cli.run, run_args, env={"UV_PRERELEASE": "allow"})
132  
133  
@skip_if_skinny
def test_run_git_https():
    """Run the HTTPS Git project twice using the local environment manager."""
    assert GIT_PROJECT_URI.startswith("https")
    run_args = [GIT_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"]
    # Invoke the command twice to ensure Git state is set up in an isolated manner
    # (e.g. no attempt to create a git repo in the same directory twice).
    for _ in range(2):
        invoke_cli_runner(cli.run, list(run_args))
141  
142  
@pytest.mark.skipif(
    "GITHUB_ACTIONS" in os.environ, reason="SSH keys are unavailable in GitHub Actions"
)
def test_run_git_ssh():
    """Run the SSH Git project twice using the local environment manager.

    This test needs SSH authentication to GitHub, so it is skipped in GitHub Actions
    where SSH keys are unavailable. Run it locally whenever logic related to running
    Git projects is modified.
    """
    assert SSH_PROJECT_URI.startswith("git@")
    run_args = [SSH_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"]
    for _ in range(2):
        invoke_cli_runner(cli.run, list(run_args))
153  
154  
@pytest.mark.skipif(
    "GITHUB_ACTIONS" in os.environ, reason="SSH keys are unavailable in GitHub Actions"
)
def test_run_git_ssh_from_release_version():
    """Run the SSH Git project twice at the ``version_testing`` version (``-v``).

    This test needs SSH authentication to GitHub, so it is skipped in GitHub Actions
    where SSH keys are unavailable. Run it locally whenever logic related to running
    Git projects is modified.
    """
    assert SSH_PROJECT_URI.startswith("git@")
    # Use `--env-manager local`, matching the other Git tests in this module, instead of
    # the legacy `--no-conda` flag.
    run_args = [
        SSH_PROJECT_URI,
        "--env-manager",
        "local",
        "-P",
        "alpha=0.5",
        "-v",
        "version_testing",
    ]
    # Run twice, as in the other Git tests, so the second run starts from whatever
    # state the first one left behind.
    for _ in range(2):
        invoke_cli_runner(cli.run, list(run_args))
169  
170  
@pytest.mark.notrackingurimock
def test_run_databricks_cluster_spec(tmp_path):
    """Check how ``--backend-config`` is handled for the ``databricks`` backend.

    With ``mlflow.projects._run`` mocked out, the test verifies that:

    * a cluster spec given as an inline JSON string or as a path to a ``.json`` file
      reaches ``_run`` as the parsed dict in the ``backend_config`` keyword argument;
    * an inline spec that is not valid JSON makes the command exit with a non-zero code.
    """
    cluster_spec = {
        "spark_version": "5.0.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    # Keep the path as a plain string: it is passed below as a command-line argument.
    cluster_spec_path = str(tmp_path.joinpath("cluster-spec.json"))
    with open(cluster_spec_path, "w") as handle:
        json.dump(cluster_spec, handle)

    databricks_env = {"MLFLOW_TRACKING_URI": "databricks://profile"}

    with mock.patch("mlflow.projects._run") as run_mock:
        for cluster_spec_arg in [json.dumps(cluster_spec), cluster_spec_path]:
            invoke_cli_runner(
                cli.run,
                [
                    TEST_PROJECT_DIR,
                    "-b",
                    "databricks",
                    "--backend-config",
                    cluster_spec_arg,
                    "-e",
                    "greeter",
                    "-P",
                    "name=hi",
                ],
                env=databricks_env,
            )
            assert run_mock.call_count == 1
            _, run_kwargs = run_mock.call_args_list[0]
            assert run_kwargs["backend_config"] == cluster_spec
            # Start the next iteration from a clean call count.
            run_mock.reset_mock()

        # Negative case: same options as above, but the inline spec has trailing junk.
        # Using the same `-b` / `--backend-config` options ensures the failure comes from
        # the malformed JSON rather than from an unrecognized command-line option.
        res = CliRunner().invoke(
            cli.run,
            [
                TEST_PROJECT_DIR,
                "-b",
                "databricks",
                "--backend-config",
                json.dumps(cluster_spec) + "JUNK",
                "-e",
                "greeter",
                "-P",
                "name=hi",
            ],
            env=databricks_env,
        )
        assert res.exit_code != 0
219  
220  
def test_mlflow_run():
    """Check argument validation of ``cli.run`` with ``mlflow.cli.projects`` mocked out.

    Covers a missing URI, three valid argument combinations, and the conflicting use of
    both ``--experiment-id`` and ``--experiment-name``.
    """
    # Missing URI: `projects.run` must not be reached and the error must be reported.
    with mock.patch("mlflow.cli.projects") as mock_projects:
        result = CliRunner().invoke(cli.run)
        mock_projects.run.assert_not_called()
        assert "Missing argument 'URI'" in result.output

    # Valid combinations: each one must call `projects.run` exactly once.
    valid_invocations = (
        ["project_uri"],
        ["--experiment-id", "5", "project_uri"],
        ["--experiment-name", "random name", "project_uri"],
    )
    for cli_args in valid_invocations:
        with mock.patch("mlflow.cli.projects") as mock_projects:
            CliRunner().invoke(cli.run, cli_args)
            mock_projects.run.assert_called_once()

    # Conflicting options: experiment id and name together are rejected.
    conflicting_args = ["--experiment-id", "51", "--experiment-name", "name blah", "uri"]
    with mock.patch("mlflow.cli.projects") as mock_projects:
        result = CliRunner().invoke(cli.run, conflicting_args)
        mock_projects.run.assert_not_called()
        assert "Specify only one of 'experiment-name' or 'experiment-id' options." in result.output