test_projects_cli.py
import hashlib
import json
import logging
import os
import shutil
from pathlib import Path
from unittest import mock

import pytest
from click.testing import CliRunner

from mlflow import MlflowClient, cli
from mlflow.utils import process
from mlflow.utils.environment import _PythonEnv
from mlflow.utils.virtualenv import _get_mlflow_virtualenv_root, _get_virtualenv_name

from tests.integration.utils import invoke_cli_runner
from tests.projects.utils import (
    GIT_PROJECT_URI,
    SSH_PROJECT_URI,
    TEST_DOCKER_PROJECT_DIR,
    TEST_PROJECT_DIR,
    TEST_VIRTUALENV_PROJECT_DIR,
    docker_example_base_image,  # noqa: F401
)

_logger = logging.getLogger(__name__)

# MLflow skinny lacks the heavyweight dependencies (conda, docker, virtualenv)
# that these environment-manager tests exercise.
skip_if_skinny = pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="MLflow skinny does not have dependencies to run this test",
)


@pytest.mark.parametrize("name", ["friend", "friend=you", "='friend'"])
def test_run_local_params(name):
    """Run a local project entry point, passing -P params whose values contain
    characters ('=', quotes) that stress the CLI's param parsing."""
    excitement_arg = 2
    invoke_cli_runner(
        cli.run,
        [
            TEST_PROJECT_DIR,
            "-e",
            "greeter",
            "-P",
            "greeting=hi",
            "-P",
            f"name={name}",
            "-P",
            f"excitement={excitement_arg}",
        ],
    )


@skip_if_skinny
def test_run_local_with_docker_args(docker_example_base_image):
    # Verify that Docker project execution is successful when Docker flag and string
    # commandline arguments are supplied (`tty` and `name`, respectively)
    invoke_cli_runner(cli.run, [TEST_DOCKER_PROJECT_DIR, "-A", "tty", "-A", "name=mycontainer"])


# NOTE(review): both parametrized values decode to the identical string, so the two
# cases are duplicates; presumably this guards against bytes-vs-str regressions.
@pytest.mark.parametrize("experiment_name", [b"test-experiment".decode("utf-8"), "test-experiment"])
def test_run_local_experiment_specification(experiment_name):
    """Run a project under --experiment-name, then resolve the created experiment's
    ID and run again under --experiment-id to cover both targeting options."""
    invoke_cli_runner(
        cli.run,
        [
            TEST_PROJECT_DIR,
            "-e",
            "greeter",
            "-P",
            "name=test",
            "--experiment-name",
            experiment_name,
        ],
    )

    client = MlflowClient()
    experiment_id = client.get_experiment_by_name(experiment_name).experiment_id

    invoke_cli_runner(
        cli.run,
        [TEST_PROJECT_DIR, "-e", "greeter", "-P", "name=test", "--experiment-id", experiment_id],
    )


@pytest.fixture(scope="module", autouse=True)
def clean_mlruns_dir():
    """Remove the mlruns directory created inside the test project after the module
    finishes, so repeated local runs start from a clean state."""
    yield
    dir_path = os.path.join(TEST_PROJECT_DIR, "mlruns")
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)


@skip_if_skinny
def test_run_local_conda_env():
    """Run a project with the conda env manager and verify MLflow materializes the
    environment under the name derived from the conda.yaml content hash."""
    with open(os.path.join(TEST_PROJECT_DIR, "conda.yaml")) as handle:
        conda_env_contents = handle.read()
    # MLflow names conda envs "mlflow-<sha1 of conda.yaml contents>".
    expected_env_name = "mlflow-{}".format(
        hashlib.sha1(conda_env_contents.encode("utf-8"), usedforsecurity=False).hexdigest()
    )
    # Best-effort removal so the run below must re-create the environment; a failure
    # most likely just means the env wasn't present.
    try:
        process._exec_cmd(cmd=["conda", "env", "remove", "--name", expected_env_name])
    except process.ShellCommandException:
        _logger.error(
            "Unable to remove conda environment %s. The environment may not have been present, "
            "continuing with running the test.",
            expected_env_name,
        )
    invoke_cli_runner(
        cli.run,
        [TEST_PROJECT_DIR, "-e", "check_conda_env", "-P", f"conda_env_name={expected_env_name}"],
    )


@skip_if_skinny
def test_run_uv_python_env():
    """Run a project with the uv env manager, first deleting any cached virtualenv so
    the environment is built from python_env.yaml during the run."""
    python_env_path = os.path.join(TEST_VIRTUALENV_PROJECT_DIR, "python_env.yaml")
    python_env_contents = _PythonEnv.from_yaml(python_env_path)

    work_dir_path = Path(TEST_VIRTUALENV_PROJECT_DIR)
    virtualenv_root = Path(_get_mlflow_virtualenv_root())
    env_name = _get_virtualenv_name(python_env_contents, work_dir_path)
    env_dir = virtualenv_root / env_name

    # Drop any cached env so this run exercises environment creation, not reuse.
    if env_dir.exists():
        shutil.rmtree(env_dir)

    invoke_cli_runner(
        cli.run,
        [TEST_VIRTUALENV_PROJECT_DIR, "-e", "test", "--env-manager", "uv"],
        env={"UV_PRERELEASE": "allow"},
    )


@skip_if_skinny
def test_run_git_https():
    # Invoke command twice to ensure we set Git state in an isolated manner (e.g. don't attempt to
    # create a git repo in the same directory twice, etc)
    assert GIT_PROJECT_URI.startswith("https")
    invoke_cli_runner(cli.run, [GIT_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"])
    invoke_cli_runner(cli.run, [GIT_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"])


@pytest.mark.skipif(
    "GITHUB_ACTIONS" in os.environ, reason="SSH keys are unavailable in GitHub Actions"
)
def test_run_git_ssh():
    # Note: this test requires SSH authentication to GitHub, and so is disabled in GitHub Actions,
    # where SSH keys are unavailable. However it should be run locally whenever logic related to
    # running Git projects is modified.
    assert SSH_PROJECT_URI.startswith("git@")
    invoke_cli_runner(cli.run, [SSH_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"])
    invoke_cli_runner(cli.run, [SSH_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5"])


@pytest.mark.skipif(
    "GITHUB_ACTIONS" in os.environ, reason="SSH keys are unavailable in GitHub Actions"
)
def test_run_git_ssh_from_release_version():
    # Note: this test requires SSH authentication to GitHub, and so is disabled in GitHub Actions,
    # where SSH keys are unavailable. However it should be run locally whenever logic related to
    # running Git projects is modified.
    assert SSH_PROJECT_URI.startswith("git@")
    # FIX: use "--env-manager local" like the sibling Git tests; the legacy
    # "--no-conda" flag is no longer accepted by `mlflow run`.
    invoke_cli_runner(
        cli.run,
        [SSH_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5", "-v", "version_testing"],
    )
    invoke_cli_runner(
        cli.run,
        [SSH_PROJECT_URI, "--env-manager", "local", "-P", "alpha=0.5", "-v", "version_testing"],
    )


@pytest.mark.notrackingurimock
def test_run_databricks_cluster_spec(tmp_path):
    """Verify the Databricks backend accepts --backend-config as inline JSON or a
    file path, and rejects malformed backend-config JSON with a nonzero exit."""
    cluster_spec = {
        "spark_version": "5.0.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    cluster_spec_path = tmp_path.joinpath("cluster-spec.json")
    with open(cluster_spec_path, "w") as handle:
        json.dump(cluster_spec, handle)

    with mock.patch("mlflow.projects._run") as run_mock:
        # Both forms — a JSON string and a path to a JSON file — must parse to the
        # same backend_config dict passed through to mlflow.projects._run.
        for cluster_spec_arg in [json.dumps(cluster_spec), cluster_spec_path]:
            invoke_cli_runner(
                cli.run,
                [
                    TEST_PROJECT_DIR,
                    "-b",
                    "databricks",
                    "--backend-config",
                    cluster_spec_arg,
                    "-e",
                    "greeter",
                    "-P",
                    "name=hi",
                ],
                env={"MLFLOW_TRACKING_URI": "databricks://profile"},
            )
            assert run_mock.call_count == 1
            _, run_kwargs = run_mock.call_args_list[0]
            assert run_kwargs["backend_config"] == cluster_spec
            run_mock.reset_mock()
        # FIX: exercise the intended failure (malformed backend-config JSON) via the
        # real "-b"/"--backend-config" options. The previous "-m"/"--cluster-spec"
        # options do not exist, so Click failed on the unknown option and the
        # malformed-JSON path was never actually tested.
        res = CliRunner().invoke(
            cli.run,
            [
                TEST_PROJECT_DIR,
                "-b",
                "databricks",
                "--backend-config",
                json.dumps(cluster_spec) + "JUNK",
                "-e",
                "greeter",
                "-P",
                "name=hi",
            ],
            env={"MLFLOW_TRACKING_URI": "databricks://profile"},
        )
        assert res.exit_code != 0


def test_mlflow_run():
    """Exercise `mlflow run` argument validation: URI is required, experiment may be
    targeted by id or name but not both."""
    with mock.patch("mlflow.cli.projects") as mock_projects:
        result = CliRunner().invoke(cli.run)
        mock_projects.run.assert_not_called()
        assert "Missing argument 'URI'" in result.output

    with mock.patch("mlflow.cli.projects") as mock_projects:
        CliRunner().invoke(cli.run, ["project_uri"])
        mock_projects.run.assert_called_once()

    with mock.patch("mlflow.cli.projects") as mock_projects:
        CliRunner().invoke(cli.run, ["--experiment-id", "5", "project_uri"])
        mock_projects.run.assert_called_once()

    with mock.patch("mlflow.cli.projects") as mock_projects:
        CliRunner().invoke(cli.run, ["--experiment-name", "random name", "project_uri"])
        mock_projects.run.assert_called_once()

    with mock.patch("mlflow.cli.projects") as mock_projects:
        result = CliRunner().invoke(
            cli.run, ["--experiment-id", "51", "--experiment-name", "name blah", "uri"]
        )
        mock_projects.run.assert_not_called()
        assert "Specify only one of 'experiment-name' or 'experiment-id' options." in result.output