test_databricks.py
"""Tests for running MLflow Projects on the Databricks backend.

All Databricks REST interactions (Jobs Runs Submit/Get/Cancel, DBFS existence
checks and uploads) are mocked out via the fixtures below, so these tests
exercise only the client-side orchestration logic in
``mlflow.projects.databricks``.
"""

import filecmp
import json
import os
import shutil
from unittest import mock

import pytest

import mlflow
from mlflow import MlflowClient, cli
from mlflow.entities import RunStatus
from mlflow.environment_variables import MLFLOW_TRACKING_URI
from mlflow.exceptions import MlflowException
from mlflow.legacy_databricks_cli.configure.provider import DatabricksConfig
from mlflow.projects import ExecutionException, databricks
from mlflow.projects.databricks import DatabricksJobRunner, _get_cluster_mlflow_run_cmd
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.store.tracking.file_store import FileStore
from mlflow.tracking.request_header.default_request_header_provider import (
    DefaultRequestHeaderProvider,
)
from mlflow.utils import file_utils
from mlflow.utils.mlflow_tags import (
    MLFLOW_DATABRICKS_RUN_URL,
    MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID,
    MLFLOW_DATABRICKS_WEBAPP_URL,
)
from mlflow.utils.rest_utils import MlflowHostCreds
from mlflow.utils.uri import construct_db_uri_from_profile

from tests import helper_functions
from tests.integration.utils import invoke_cli_runner
from tests.projects.utils import TEST_PROJECT_DIR, validate_exit_status


@pytest.fixture
def runs_cancel_mock():
    """Mocks the Jobs Runs Cancel API request."""
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner.jobs_runs_cancel"
    ) as runs_cancel_mock:
        runs_cancel_mock.return_value = None
        yield runs_cancel_mock


@pytest.fixture
def runs_submit_mock():
    """Mocks the Jobs Runs Submit API request."""
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._jobs_runs_submit",
        return_value={"run_id": "-1"},
    ) as runs_submit_mock:
        yield runs_submit_mock


@pytest.fixture
def runs_get_mock():
    """Mocks the Jobs Runs Get API request."""
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner.jobs_runs_get"
    ) as runs_get_mock:
        yield runs_get_mock


@pytest.fixture
def databricks_cluster_mlflow_run_cmd_mock():
    """Mocks the helper that builds the `mlflow run` command executed on the cluster."""
    with mock.patch(
        "mlflow.projects.databricks._get_cluster_mlflow_run_cmd"
    ) as mlflow_run_cmd_mock:
        yield mlflow_run_cmd_mock


@pytest.fixture
def cluster_spec_mock(tmp_path):
    # Path to a minimal (empty-JSON) cluster spec file, accepted as a backend config.
    cluster_spec_handle = tmp_path.joinpath("cluster_spec.json")
    cluster_spec_handle.write_text("{}")
    return str(cluster_spec_handle)


@pytest.fixture
def dbfs_root_mock(tmp_path):
    # Local directory standing in for the DBFS root in upload tests.
    return str(tmp_path.joinpath("dbfs-root"))


@pytest.fixture
def upload_to_dbfs_mock(dbfs_root_mock):
    # Redirects DBFS uploads into the local mock DBFS root so that uploaded
    # artifacts can be inspected on disk.
    def upload_mock_fn(_, src_path, dbfs_uri):
        mock_dbfs_dst = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
        os.makedirs(os.path.dirname(mock_dbfs_dst))
        shutil.copy(src_path, mock_dbfs_dst)

    with mock.patch.object(
        mlflow.projects.databricks.DatabricksJobRunner, "_upload_to_dbfs", new=upload_mock_fn
    ) as upload_mock:
        yield upload_mock


@pytest.fixture
def dbfs_path_exists_mock(dbfs_root_mock):
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._dbfs_path_exists"
    ) as path_exists_mock:
        yield path_exists_mock


@pytest.fixture
def dbfs_mocks(dbfs_path_exists_mock, upload_to_dbfs_mock):
    # Aggregate fixture: requesting it activates both DBFS mocks; it yields nothing.
    return


@pytest.fixture
def before_run_validations_mock():
    with mock.patch("mlflow.projects.databricks.before_run_validations"):
        yield


@pytest.fixture
def set_tag_mock():
    # Wraps a real MlflowClient so tag values can be asserted via the mock's
    # recorded call args while other client behavior is preserved.
    with mock.patch("mlflow.projects.databricks.tracking.MlflowClient") as m:
        mlflow_service_mock = mock.Mock(wraps=MlflowClient())
        m.return_value = mlflow_service_mock
        yield mlflow_service_mock.set_tag


def _get_mock_run_state(succeeded):
    """Return a fake Jobs Runs Get "state" dict.

    ``succeeded`` is tri-state: None -> still running, True -> terminated
    successfully, False -> terminated with failure.
    """
    if succeeded is None:
        return {"life_cycle_state": "RUNNING", "state_message": ""}
    run_result_state = "SUCCESS" if succeeded else "FAILED"
    return {"life_cycle_state": "TERMINATED", "state_message": "", "result_state": run_result_state}


def mock_runs_get_result(succeeded):
    """Return a fake Jobs Runs Get response body for the given outcome."""
    run_state = _get_mock_run_state(succeeded)
    return {"state": run_state, "run_page_url": "test_url"}


def run_databricks_project(cluster_spec, **kwargs):
    """Launch the shared test project against the Databricks backend."""
    return mlflow.projects.run(
        uri=TEST_PROJECT_DIR,
        backend="databricks",
        backend_config=cluster_spec,
        parameters={"alpha": "0.4"},
        **kwargs,
    )


def test_upload_project_to_dbfs(
    dbfs_root_mock, tmp_path, dbfs_path_exists_mock, upload_to_dbfs_mock
):
    # Upload project to a mock directory
    dbfs_path_exists_mock.return_value = False
    runner = DatabricksJobRunner(databricks_profile_uri=construct_db_uri_from_profile("DEFAULT"))
    dbfs_uri = runner._upload_project_to_dbfs(
        project_dir=TEST_PROJECT_DIR, experiment_id=FileStore.DEFAULT_EXPERIMENT_ID
    )
    # Get expected tar
    local_tar_path = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
    expected_tar_path = str(tmp_path.joinpath("expected.tar.gz"))
    file_utils.make_tarfile(
        output_filename=expected_tar_path,
        source_dir=TEST_PROJECT_DIR,
        archive_name=databricks.DB_TARFILE_ARCHIVE_NAME,
    )
    # Extract the tarred project, verify its contents
    assert filecmp.cmp(local_tar_path, expected_tar_path, shallow=False)


def test_upload_existing_project_to_dbfs(dbfs_path_exists_mock):
    # Check that we don't upload the project if it already exists on DBFS
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._upload_to_dbfs"
    ) as upload_to_dbfs_mock:
        dbfs_path_exists_mock.return_value = True
        runner = DatabricksJobRunner(
            databricks_profile_uri=construct_db_uri_from_profile("DEFAULT")
        )
        runner._upload_project_to_dbfs(
            project_dir=TEST_PROJECT_DIR, experiment_id=FileStore.DEFAULT_EXPERIMENT_ID
        )
        assert upload_to_dbfs_mock.call_count == 0


@pytest.mark.parametrize(
    "response_mock",
    [
        helper_functions.create_mock_response(400, "Error message but not a JSON string"),
        helper_functions.create_mock_response(400, ""),
        helper_functions.create_mock_response(400, None),
    ],
)
def test_dbfs_path_exists_error_response_handling(response_mock):
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_host_creds"
        ) as get_databricks_host_creds_mock,
        mock.patch("mlflow.utils.rest_utils.http_request") as http_request_mock,
    ):
        # given a well formed DatabricksJobRunner
        # note: databricks_profile is None needed because clients using profile are mocked
        job_runner = DatabricksJobRunner(databricks_profile_uri=None)

        # when the http request to validate the dbfs path returns a 400 response with an
        # error message that is either well-formed JSON or not
        get_databricks_host_creds_mock.return_value = None
        http_request_mock.return_value = response_mock

        # then _dbfs_path_exists should return a MlflowException
        with pytest.raises(MlflowException, match="API request to check existence of file at DBFS"):
            job_runner._dbfs_path_exists("some/path")


def test_run_databricks_validations(
    tmp_path,
    monkeypatch,
    cluster_spec_mock,
    dbfs_mocks,
    set_tag_mock,
):
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request"
    ) as db_api_req_mock:
        # Test bad tracking URI
        mlflow.set_tracking_uri(f"sqlite:///{tmp_path / 'mlflow.db'}")
        with pytest.raises(ExecutionException, match="MLflow tracking URI must be of"):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        mlflow_service = MlflowClient()
        assert len(mlflow_service.search_runs([FileStore.DEFAULT_EXPERIMENT_ID])) == 0
        mlflow.set_tracking_uri("databricks")
        # Test misspecified parameters
        with pytest.raises(
            ExecutionException, match="No value given for missing parameters: 'name'"
        ):
            mlflow.projects.run(
                TEST_PROJECT_DIR,
                backend="databricks",
                entry_point="greeter",
                backend_config=cluster_spec_mock,
            )
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec
        with pytest.raises(ExecutionException, match="Backend spec must be provided"):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", synchronous=True, backend_config=None
            )
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)


@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "databricks_cluster_mlflow_run_cmd_mock",
)
def test_run_databricks(
    runs_submit_mock,
    runs_get_mock,
    cluster_spec_mock,
    set_tag_mock,
    databricks_cluster_mlflow_run_cmd_mock,
    monkeypatch,
):
    monkeypatch.setenv("DATABRICKS_HOST", "https://test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    mlflow.set_tracking_uri("databricks")
    # Test that MLflow gets the correct run status when performing a Databricks run
    for run_succeeded, expect_status in [(True, RunStatus.FINISHED), (False, RunStatus.FAILED)]:
        runs_get_mock.return_value = mock_runs_get_result(succeeded=run_succeeded)
        submitted_run = run_databricks_project(cluster_spec_mock, synchronous=False)
        assert submitted_run.wait() == run_succeeded
        assert submitted_run.run_id is not None
        assert runs_submit_mock.call_count == 1
        assert databricks_cluster_mlflow_run_cmd_mock.call_count == 1
        # set_tag is invoked positionally as (run_id, key, value); collect key->value.
        tags = {}
        for call_args, _ in set_tag_mock.call_args_list:
            tags[call_args[1]] = call_args[2]
        assert tags[MLFLOW_DATABRICKS_RUN_URL] == "test_url"
        assert tags[MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID] == "-1"
        assert tags[MLFLOW_DATABRICKS_WEBAPP_URL] == "https://test-host"
        set_tag_mock.reset_mock()
        runs_submit_mock.reset_mock()
        databricks_cluster_mlflow_run_cmd_mock.reset_mock()
        validate_exit_status(submitted_run.get_status(), expect_status)


@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_cluster_spec_json(runs_submit_mock, runs_get_mock, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    cluster_spec = {
        "spark_version": "5.0.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    # Run project synchronously, verify that it succeeds (doesn't throw)
    run_databricks_project(cluster_spec=cluster_spec, synchronous=True)
    assert runs_submit_mock.call_count == 1
    runs_submit_args, _ = runs_submit_mock.call_args_list[0]
    req_body = runs_submit_args[0]
    assert req_body["new_cluster"] == cluster_spec


@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_extended_cluster_spec_json(runs_submit_mock, runs_get_mock, monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    new_cluster_spec = {
        "spark_version": "6.5.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    extra_library = {"pypi": {"package": "tensorflow"}}

    cluster_spec = {"new_cluster": new_cluster_spec, "libraries": [extra_library]}

    # Run project synchronously, verify that it succeeds (doesn't throw)
    run_databricks_project(cluster_spec=cluster_spec, synchronous=True)
    assert runs_submit_mock.call_count == 1
    runs_submit_args, _ = runs_submit_mock.call_args_list[0]
    req_body = runs_submit_args[0]
    assert req_body["new_cluster"] == new_cluster_spec
    # This does test deep object equivalence
    assert extra_library in req_body["libraries"]


@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_extended_cluster_spec_json_without_libraries(
    runs_submit_mock, runs_get_mock, monkeypatch
):
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    new_cluster_spec = {
        "spark_version": "6.5.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }

    cluster_spec = {
        "new_cluster": new_cluster_spec,
    }

    # Run project synchronously, verify that it succeeds (doesn't throw)
    run_databricks_project(cluster_spec=cluster_spec, synchronous=True)
    assert runs_submit_mock.call_count == 1
    runs_submit_args, _ = runs_submit_mock.call_args_list[0]
    req_body = runs_submit_args[0]
    assert req_body["new_cluster"] == new_cluster_spec


def test_run_databricks_throws_exception_when_spec_uses_existing_cluster(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    existing_cluster_spec = {
        "existing_cluster_id": "1000-123456-clust1",
    }
    with pytest.raises(
        MlflowException, match="execution against existing clusters is not currently supported"
    ) as exc:
        run_databricks_project(cluster_spec=existing_cluster_spec)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_run_databricks_cancel(
    before_run_validations_mock,
    runs_submit_mock,
    dbfs_mocks,
    set_tag_mock,
    runs_cancel_mock,
    runs_get_mock,
    cluster_spec_mock,
    monkeypatch,
):
    # Test that MLflow properly handles Databricks run cancellation. We mock the result of
    # the runs-get API to indicate run failure so that cancel() exits instead of blocking while
    # waiting for run status.
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=False)
    submitted_run = run_databricks_project(cluster_spec_mock, synchronous=False)
    submitted_run.cancel()
    validate_exit_status(submitted_run.get_status(), RunStatus.FAILED)
    assert runs_cancel_mock.call_count == 1
    # Test that we raise an exception when a blocking Databricks run fails
    runs_get_mock.return_value = mock_runs_get_result(succeeded=False)
    with pytest.raises(mlflow.projects.ExecutionException, match=r"Run \(ID '.+'\) failed"):
        run_databricks_project(cluster_spec_mock, synchronous=True)


def test_get_tracking_uri_for_run(monkeypatch):
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    # A profile-qualified databricks URI is normalized to plain "databricks"
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    mlflow.set_tracking_uri(None)
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "http://some-uri")
    assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"


class MockProfileConfigProvider:
    """Config provider stub that asserts which profile was requested and
    returns fixed password-based credentials."""

    def __init__(self, profile):
        assert profile == "my-profile"

    def get_config(self):
        return DatabricksConfig.from_password("host", "user", "pass", insecure=False)


def test_databricks_http_request_integration():
    def confirm_request_params(*args, **kwargs):
        # "Basic dXNlcjpwYXNz" is base64("user:pass"), matching the mocked creds below.
        headers = DefaultRequestHeaderProvider().request_headers()
        headers["Authorization"] = "Basic dXNlcjpwYXNz"
        assert args == ("PUT", "host/clusters/list")
        assert kwargs == {
            "allow_redirects": True,
            "headers": headers,
            "verify": True,
            "json": {"a": "b"},
            "timeout": 120,
        }
        http_response = mock.MagicMock()
        http_response.status_code = 200
        http_response.text = '{"OK": "woo"}'
        return http_response

    with (
        mock.patch("requests.Session.request", side_effect=confirm_request_params),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_host_creds",
            return_value=MlflowHostCreds(
                host="host", username="user", password="pass", ignore_tls_verification=False
            ),
        ),
    ):
        response = DatabricksJobRunner(databricks_profile_uri=None)._databricks_api_request(
            "/clusters/list", "PUT", json={"a": "b"}
        )
        assert json.loads(response.text) == {"OK": "woo"}


def test_run_databricks_failed():
    text = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with (
        mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
        mock.patch(
            "mlflow.utils.rest_utils.http_request",
            return_value=mock.Mock(text=text, status_code=400),
        ),
    ):
        runner = DatabricksJobRunner(construct_db_uri_from_profile("profile"))
        with pytest.raises(
            MlflowException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"
        ):
            runner._run_shell_command_job("/project", "command", {}, {})


def test_run_databricks_generates_valid_mlflow_run_cmd():
    cmd = _get_cluster_mlflow_run_cmd(
        project_dir="my_project_dir",
        run_id="hi",
        entry_point="main",
        parameters={"a": "b"},
        env_manager="conda",
    )
    assert cmd[0] == "mlflow"
    # Verify the generated argv parses under the real MLflow CLI (run itself is mocked).
    with mock.patch("mlflow.projects.run"):
        invoke_cli_runner(cli.cli, cmd[1:])