test_projects.py
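"""Tests for ``mlflow.projects``: running projects locally and from git repositories
(synchronously, asynchronously, and with cancellation), conda/mamba environment resolution
and creation in ``mlflow.utils.conda``, Kubernetes backend configuration parsing, and
propagation of Databricks credentials to project subprocesses."""
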
import json
import os
import shutil
import subprocess
import uuid
from unittest import mock

import git
import pytest
import yaml

import mlflow
from mlflow import MlflowClient
from mlflow.entities import RunStatus, SourceType, ViewType
from mlflow.environment_variables import MLFLOW_CONDA_CREATE_ENV_CMD, MLFLOW_CONDA_HOME
from mlflow.exceptions import ExecutionException, MlflowException
from mlflow.projects import _parse_kubernetes_config, _resolve_experiment_id
from mlflow.store.tracking.file_store import FileStore
from mlflow.utils import PYTHON_VERSION
from mlflow.utils.conda import CONDA_EXE, get_or_create_conda_env
from mlflow.utils.mlflow_tags import (
    MLFLOW_GIT_BRANCH,
    MLFLOW_GIT_REPO_URL,
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_PROJECT_BACKEND,
    MLFLOW_PROJECT_ENTRY_POINT,
    MLFLOW_PROJECT_ENV,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.process import ShellCommandException

from tests.projects.utils import TEST_PROJECT_DIR, TEST_PROJECT_NAME, validate_exit_status

MOCK_USER = "janebloggs"


@pytest.fixture
def patch_user():
    with mock.patch("mlflow.projects.utils._get_user", return_value=MOCK_USER):
        yield


def _get_version_local_git_repo(local_git_repo):
    repo = git.Repo(local_git_repo, search_parent_directories=True)
    return repo.git.rev_parse("HEAD")


@pytest.fixture(scope="module", autouse=True)
def clean_mlruns_dir():
    yield
    dir_path = os.path.join(TEST_PROJECT_DIR, "mlruns")
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)


@pytest.mark.parametrize(
    ("experiment_name", "experiment_id", "expected"),
    [
        ("Default", None, "0"),
        ("add an experiment", None, "1"),
        (None, 2, "2"),
        (None, "2", "2"),
        (None, None, "0"),
    ],
)
def test_resolve_experiment_id(experiment_name, experiment_id, expected):
    assert expected == _resolve_experiment_id(
        experiment_name=experiment_name, experiment_id=experiment_id
    )
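
# Note on the "add an experiment" case above: the name does not exist beforehand, and
# _resolve_experiment_id is expected to create the experiment rather than fail, which is why the
# name maps to the next available id ("1") in a fresh file store.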


def test_resolve_experiment_id_should_not_allow_both_name_and_id_in_use():
    with pytest.raises(
        MlflowException, match="Specify only one of 'experiment_name' or 'experiment_id'."
    ):
        _resolve_experiment_id(experiment_name="experiment_named", experiment_id="44")


def test_invalid_run_mode():
    with pytest.raises(
        ExecutionException, match="Got unsupported execution mode some unsupported mode"
    ):
        mlflow.projects.run(uri=TEST_PROJECT_DIR, backend="some unsupported mode")


def test_expected_tags_logged_when_using_conda():
    with mock.patch.object(MlflowClient, "set_tag") as tag_mock:
        try:
            mlflow.projects.run(TEST_PROJECT_DIR, env_manager="conda")
        finally:
            tag_mock.assert_has_calls(
                [
                    mock.call(mock.ANY, MLFLOW_PROJECT_BACKEND, "local"),
                    mock.call(mock.ANY, MLFLOW_PROJECT_ENV, "conda"),
                ],
                any_order=True,
            )


@pytest.mark.usefixtures("patch_user")
@pytest.mark.parametrize("use_start_run", map(str, [0, 1]))
@pytest.mark.parametrize("version", [None, "master", "git-commit"])
def test_run_local_git_repo(
    local_git_repo, local_git_repo_uri, use_start_run, version, monkeypatch
):
    monkeypatch.setenv("DATABRICKS_HOST", "my-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "my-token")
    if version is not None:
        uri = local_git_repo_uri + "#" + TEST_PROJECT_NAME
    else:
        uri = os.path.join(f"{local_git_repo}/", TEST_PROJECT_NAME)
    if version == "git-commit":
        version = _get_version_local_git_repo(local_git_repo)
    submitted_run = mlflow.projects.run(
        uri,
        entry_point="test_tracking",
        version=version,
        parameters={"use_start_run": use_start_run},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
    )

    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    mlflow_service = MlflowClient()
    runs = mlflow_service.search_runs(
        [FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
    )
    assert len(runs) == 1
    store_run_id = runs[0].info.run_id
    assert run_id == store_run_id
    run = mlflow_service.get_run(run_id)

    assert run.info.status == RunStatus.to_string(RunStatus.FINISHED)

    assert run.data.params == {
        "use_start_run": use_start_run,
    }
    assert run.data.metrics == {"some_key": 3}

    tags = run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"
    assert tags[MLFLOW_PROJECT_BACKEND] == "local"

    if version == "master":
        assert tags[MLFLOW_GIT_BRANCH] == "master"
        assert tags[MLFLOW_GIT_REPO_URL] == local_git_repo_uri


def test_invalid_version_local_git_repo(local_git_repo_uri):
    # Run project with invalid commit hash
    with pytest.raises(ExecutionException, match=r"Unable to checkout version \'badc0de\'"):
        mlflow.projects.run(
            local_git_repo_uri + "#" + TEST_PROJECT_NAME,
            entry_point="test_tracking",
            version="badc0de",
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )
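
# In the git-based URIs above, the "#" suffix follows MLflow's "<repository URI>#<subdirectory>"
# convention: it selects the subdirectory of the repository that contains the MLproject file.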


@pytest.mark.parametrize("use_start_run", map(str, [0, 1]))
@pytest.mark.usefixtures("patch_user")
def test_run(use_start_run):
    submitted_run = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="test_tracking",
        parameters={"use_start_run": use_start_run},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
    )
    assert submitted_run.run_id is not None
    # Blocking runs should be finished when they return
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Test that we can call wait() on a synchronous run & that the run has the correct
    # status after calling wait().
    submitted_run.wait()
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    # Validate run contents in the FileStore
    run_id = submitted_run.run_id
    mlflow_service = MlflowClient()

    runs = mlflow_service.search_runs(
        [FileStore.DEFAULT_EXPERIMENT_ID], run_view_type=ViewType.ACTIVE_ONLY
    )
    assert len(runs) == 1
    store_run_id = runs[0].info.run_id
    assert run_id == store_run_id
    run = mlflow_service.get_run(run_id)

    assert run.info.status == RunStatus.to_string(RunStatus.FINISHED)

    assert run.data.params == {
        "use_start_run": use_start_run,
    }
    assert run.data.metrics == {"some_key": 3}

    tags = run.data.tags
    assert tags[MLFLOW_USER] == MOCK_USER
    assert "file:" in tags[MLFLOW_SOURCE_NAME]
    assert tags[MLFLOW_SOURCE_TYPE] == SourceType.to_string(SourceType.PROJECT)
    assert tags[MLFLOW_PROJECT_ENTRY_POINT] == "test_tracking"


def test_run_with_parent():
    with mlflow.start_run():
        parent_run_id = mlflow.active_run().info.run_id
        submitted_run = mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="test_tracking",
            parameters={"use_start_run": "1"},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )
    assert submitted_run.run_id is not None
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)
    run_id = submitted_run.run_id
    run = MlflowClient().get_run(run_id)
    assert run.data.tags[MLFLOW_PARENT_RUN_ID] == parent_run_id


def test_run_with_artifact_path(tmp_path):
    artifact_file = tmp_path.joinpath("model.pkl")
    artifact_file.write_text("Hello world")
    with mlflow.start_run() as run:
        mlflow.log_artifact(artifact_file)
        submitted_run = mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="test_artifact_path",
            parameters={"model": f"runs:/{run.info.run_id}/model.pkl"},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        )
    validate_exit_status(submitted_run.get_status(), RunStatus.FINISHED)


def test_run_async():
    submitted_run0 = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="sleep",
        parameters={"duration": 2},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        synchronous=False,
    )
    validate_exit_status(submitted_run0.get_status(), RunStatus.RUNNING)
    submitted_run0.wait()
    validate_exit_status(submitted_run0.get_status(), RunStatus.FINISHED)
    submitted_run1 = mlflow.projects.run(
        TEST_PROJECT_DIR,
        entry_point="sleep",
        parameters={"duration": -1, "invalid-param": 30},
        env_manager="local",
        experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
        synchronous=False,
    )
    submitted_run1.wait()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FAILED)


@pytest.mark.parametrize(
    ("mock_env", "expected_conda", "expected_activate"),
    [
        (
            {CONDA_EXE: "/abc/conda"},
            "/abc/conda",
            "/abc/activate",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/"},
            "/some/dir/bin/conda",
            "/some/dir/bin/activate",
        ),
    ],
)
def test_conda_path(mock_env, expected_conda, expected_activate, monkeypatch):
    for name in [CONDA_EXE, MLFLOW_CONDA_HOME.name]:
        monkeypatch.delenv(name, raising=False)
    for name, value in mock_env.items():
        monkeypatch.setenv(name, value)
    assert mlflow.utils.conda.get_conda_bin_executable("conda") == expected_conda
    assert mlflow.utils.conda.get_conda_bin_executable("activate") == expected_activate
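
# For reference (the paths in the parametrizations are illustrative): with CONDA_EXE set,
# "activate" is expected to resolve next to the conda executable, while MLFLOW_CONDA_HOME points
# at a conda installation whose executables live under "<MLFLOW_CONDA_HOME>/bin/". The next test
# additionally covers MLFLOW_CONDA_CREATE_ENV_CMD, which swaps the executable used to create
# environments (e.g. mamba) while reusing the same directory resolution.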


@pytest.mark.parametrize(
    ("mock_env", "expected_conda_env_create_path"),
    [
        (
            {CONDA_EXE: "/abc/conda"},
            "/abc/conda",
        ),
        (
            {CONDA_EXE: "/abc/conda", MLFLOW_CONDA_CREATE_ENV_CMD.name: "mamba"},
            "/abc/mamba",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/"},
            "/some/dir/bin/conda",
        ),
        (
            {MLFLOW_CONDA_HOME.name: "/some/dir/", MLFLOW_CONDA_CREATE_ENV_CMD.name: "mamba"},
            "/some/dir/bin/mamba",
        ),
    ],
)
def test_find_conda_executables(mock_env, expected_conda_env_create_path, monkeypatch):
    """
    Verify that we correctly determine the path to executables to be used to
    create environments (for example, it could be mamba instead of conda)
    """
    monkeypatch.delenv(CONDA_EXE, raising=False)
    monkeypatch.delenv(MLFLOW_CONDA_HOME.name, raising=False)
    monkeypatch.delenv(MLFLOW_CONDA_CREATE_ENV_CMD.name, raising=False)
    for name, value in mock_env.items():
        monkeypatch.setenv(name, value)
    conda_env_create_path = mlflow.utils.conda._get_conda_executable_for_create_env()
    assert conda_env_create_path == expected_conda_env_create_path


def test_create_env_with_mamba(monkeypatch):
    """
    Test that mamba is called when set, and that we fail when mamba is not available or is
    not working. We mock the calls so we do not actually execute mamba (which is not
    installed in the test environment anyway)
    """

    def exec_cmd_mock(cmd, *args, **kwargs):
        if cmd[-1] == "--json":
            # We are supposed to list environments in JSON format
            return subprocess.CompletedProcess(
                cmd, 0, json.dumps({"envs": ["mlflow-mock-environment"]}), None
            )
        else:
            # Here we are creating the environment, no need to return
            # anything
            return subprocess.CompletedProcess(cmd, 0)

    def exec_cmd_mock_raise(cmd, *args, **kwargs):
        if os.path.basename(cmd[0]) == "mamba":
            raise OSError()

    conda_env_path = os.path.join(TEST_PROJECT_DIR, "conda.yaml")

    monkeypatch.setenv(MLFLOW_CONDA_CREATE_ENV_CMD.name, "mamba")
    # Simulate success
    with mock.patch("mlflow.utils.process._exec_cmd", side_effect=exec_cmd_mock):
        mlflow.utils.conda.get_or_create_conda_env(conda_env_path)

    # Simulate a non-working or non-existent mamba
    with mock.patch("mlflow.utils.process._exec_cmd", side_effect=exec_cmd_mock_raise):
        with pytest.raises(
            ExecutionException,
            match="You have set the env variable MLFLOW_CONDA_CREATE_ENV_CMD",
        ):
            mlflow.utils.conda.get_or_create_conda_env(conda_env_path)


def test_conda_environment_cleaned_up_when_pip_fails(tmp_path):
    conda_yaml = tmp_path / "conda.yaml"
    content = f"""
name: {uuid.uuid4().hex}
channels:
  - conda-forge
dependencies:
  - python={PYTHON_VERSION}
  - pip
  - pip:
    - mlflow==999.999.999
"""
    conda_yaml.write_text(content)
    envs_before = mlflow.utils.conda._list_conda_environments()

    # `conda create` should fail because mlflow 999.999.999 doesn't exist
    with pytest.raises(ShellCommandException, match=r"No matching distribution found"):
        mlflow.utils.conda.get_or_create_conda_env(conda_yaml, capture_output=True)

    # Ensure the environment is cleaned up
    envs_after = mlflow.utils.conda._list_conda_environments()
    assert envs_before == envs_after


def test_cancel_run():
    submitted_run0, submitted_run1 = (
        mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="sleep",
            parameters={"duration": 2},
            env_manager="local",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
            synchronous=False,
        )
        for _ in range(2)
    )
    submitted_run0.cancel()
    validate_exit_status(submitted_run0.get_status(), RunStatus.FAILED)
    # Sanity check: cancelling one run has no effect on the other
    assert submitted_run1.wait()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FINISHED)
    # Try cancelling after calling wait()
    submitted_run1.cancel()
    validate_exit_status(submitted_run1.get_status(), RunStatus.FINISHED)


def test_parse_kubernetes_config():
    work_dir = "./examples/docker"
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "kube-job-template-path": os.path.join(work_dir, "kubernetes_job_template.yaml"),
        "repository-uri": "dockerhub_account/mlflow-kubernetes-example",
    }
    yaml_obj = None
    with open(kubernetes_config["kube-job-template-path"]) as job_template:
        yaml_obj = yaml.safe_load(job_template.read())
    kube_config = _parse_kubernetes_config(kubernetes_config)
    assert kube_config["kube-context"] == kubernetes_config["kube-context"]
    assert kube_config["kube-job-template-path"] == kubernetes_config["kube-job-template-path"]
    assert kube_config["repository-uri"] == kubernetes_config["repository-uri"]
    assert kube_config["kube-job-template"] == yaml_obj


@pytest.fixture
def mock_kubernetes_job_template(tmp_path):
    k8s_yaml = tmp_path.joinpath("kubernetes_job_template.yaml")
    k8s_yaml.write_text(
        """
apiVersion: batch/v1
kind: Job
metadata:
  name: "{replaced with MLflow Project name}"
  namespace: mlflow
spec:
  ttlSecondsAfterFinished: 100
  backoffLimit: 0
  template:
    spec:
      containers:
        - name: "{replaced with MLflow Project name}"
          image: "{replaced with URI of Docker image created during Project execution}"
          command: ["{replaced with MLflow Project entry point command}"]
          resources:
            limits:
              memory: 512Mi
            requests:
              memory: 256Mi
      restartPolicy: Never
""".lstrip()
    )
    return str(k8s_yaml)


class StartsWithMatcher:
    def __init__(self, prefix):
        self.prefix = prefix

    def __eq__(self, other):
        return isinstance(other, str) and other.startswith(self.prefix)


def test_parse_kubernetes_config_without_context(mock_kubernetes_job_template):
    with mock.patch("mlflow.projects._logger.debug") as mock_debug:
        kubernetes_config = {
            "repository-uri": "dockerhub_account/mlflow-kubernetes-example",
            "kube-job-template-path": mock_kubernetes_job_template,
        }
        _parse_kubernetes_config(kubernetes_config)
        mock_debug.assert_called_once_with(
            StartsWithMatcher("Could not find kube-context in backend_config")
        )


def test_parse_kubernetes_config_without_image_uri(mock_kubernetes_job_template):
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "kube-job-template-path": mock_kubernetes_job_template,
    }
    with pytest.raises(ExecutionException, match="Could not find 'repository-uri'"):
        _parse_kubernetes_config(kubernetes_config)
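
# Taken together, these tests exercise _parse_kubernetes_config's expectations: "repository-uri"
# is required, "kube-context" is optional (its absence is only logged at debug level), and the
# job template referenced by "kube-job-template-path" is loaded into the "kube-job-template" key.
# The next test checks that a non-existent template path is rejected as well.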


def test_parse_kubernetes_config_invalid_template_job_file():
    kubernetes_config = {
        "kube-context": "docker-for-desktop",
        "repository-uri": "username/mlflow-kubernetes-example",
        "kube-job-template-path": "file_not_found.yaml",
    }
    with pytest.raises(ExecutionException, match="Could not find 'kube-job-template-path'"):
        _parse_kubernetes_config(kubernetes_config)


@pytest.mark.parametrize("synchronous", [True, False])
def test_credential_propagation(synchronous, monkeypatch):
    class DummyProcess:
        def wait(self):
            return 0

        def poll(self):
            return 0

        def communicate(self, _):
            return "", ""

    monkeypatch.setenv("DATABRICKS_HOST", "host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "mytoken")
    with (
        mock.patch("subprocess.Popen", return_value=DummyProcess()) as popen_mock,
        mock.patch("mlflow.utils.uri.is_databricks_uri", return_value=True),
    ):
        mlflow.projects.run(
            TEST_PROJECT_DIR,
            entry_point="sleep",
            experiment_id=FileStore.DEFAULT_EXPERIMENT_ID,
            parameters={"duration": 2},
            env_manager="local",
            synchronous=synchronous,
        )
    _, kwargs = popen_mock.call_args
    env = kwargs["env"]
    assert env["DATABRICKS_HOST"] == "host"
    assert env["DATABRICKS_TOKEN"] == "mytoken"


def test_get_or_create_conda_env_capture_output_mode(tmp_path):
    conda_yaml_file = tmp_path / "conda.yaml"
    conda_yaml_file.write_text(
        """
channels:
  - conda-forge
dependencies:
  - pip:
    - scikit-learn==99.99.99
"""
    )
    with pytest.raises(
        ShellCommandException,
        match="Could not find a version that satisfies the requirement scikit-learn==99.99.99",
    ):
        get_or_create_conda_env(str(conda_yaml_file), capture_output=True)