# tests/projects/test_databricks.py
  1  import filecmp
  2  import json
  3  import os
  4  import shutil
  5  from unittest import mock
  6  
  7  import pytest
  8  
  9  import mlflow
 10  from mlflow import MlflowClient, cli
 11  from mlflow.entities import RunStatus
 12  from mlflow.environment_variables import MLFLOW_TRACKING_URI
 13  from mlflow.exceptions import MlflowException
 14  from mlflow.legacy_databricks_cli.configure.provider import DatabricksConfig
 15  from mlflow.projects import ExecutionException, databricks
 16  from mlflow.projects.databricks import DatabricksJobRunner, _get_cluster_mlflow_run_cmd
 17  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
 18  from mlflow.store.tracking.file_store import FileStore
 19  from mlflow.tracking.request_header.default_request_header_provider import (
 20      DefaultRequestHeaderProvider,
 21  )
 22  from mlflow.utils import file_utils
 23  from mlflow.utils.mlflow_tags import (
 24      MLFLOW_DATABRICKS_RUN_URL,
 25      MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID,
 26      MLFLOW_DATABRICKS_WEBAPP_URL,
 27  )
 28  from mlflow.utils.rest_utils import MlflowHostCreds
 29  from mlflow.utils.uri import construct_db_uri_from_profile
 30  
 31  from tests import helper_functions
 32  from tests.integration.utils import invoke_cli_runner
 33  from tests.projects.utils import TEST_PROJECT_DIR, validate_exit_status
 34  
 35  
 36  @pytest.fixture
 37  def runs_cancel_mock():
 38      """Mocks the Jobs Runs Cancel API request"""
 39      with mock.patch(
 40          "mlflow.projects.databricks.DatabricksJobRunner.jobs_runs_cancel"
 41      ) as runs_cancel_mock:
 42          runs_cancel_mock.return_value = None
 43          yield runs_cancel_mock
 44  
 45  
 46  @pytest.fixture
 47  def runs_submit_mock():
 48      """Mocks the Jobs Runs Submit API request"""
 49      with mock.patch(
 50          "mlflow.projects.databricks.DatabricksJobRunner._jobs_runs_submit",
 51          return_value={"run_id": "-1"},
 52      ) as runs_submit_mock:
 53          yield runs_submit_mock
 54  
 55  
 56  @pytest.fixture
 57  def runs_get_mock():
 58      """Mocks the Jobs Runs Get API request"""
 59      with mock.patch(
 60          "mlflow.projects.databricks.DatabricksJobRunner.jobs_runs_get"
 61      ) as runs_get_mock:
 62          yield runs_get_mock
 63  
 64  
 65  @pytest.fixture
 66  def databricks_cluster_mlflow_run_cmd_mock():
 67      """Mocks the Jobs Runs Get API request"""
 68      with mock.patch(
 69          "mlflow.projects.databricks._get_cluster_mlflow_run_cmd"
 70      ) as mlflow_run_cmd_mock:
 71          yield mlflow_run_cmd_mock
 72  
 73  
 74  @pytest.fixture
 75  def cluster_spec_mock(tmp_path):
 76      cluster_spec_handle = tmp_path.joinpath("cluster_spec.json")
 77      cluster_spec_handle.write_text("{}")
 78      return str(cluster_spec_handle)
 79  
 80  
 81  @pytest.fixture
 82  def dbfs_root_mock(tmp_path):
 83      return str(tmp_path.joinpath("dbfs-root"))
 84  
 85  
 86  @pytest.fixture
 87  def upload_to_dbfs_mock(dbfs_root_mock):
 88      def upload_mock_fn(_, src_path, dbfs_uri):
 89          mock_dbfs_dst = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
 90          os.makedirs(os.path.dirname(mock_dbfs_dst))
 91          shutil.copy(src_path, mock_dbfs_dst)
 92  
 93      with mock.patch.object(
 94          mlflow.projects.databricks.DatabricksJobRunner, "_upload_to_dbfs", new=upload_mock_fn
 95      ) as upload_mock:
 96          yield upload_mock
 97  
 98  
 99  @pytest.fixture
100  def dbfs_path_exists_mock(dbfs_root_mock):
101      with mock.patch(
102          "mlflow.projects.databricks.DatabricksJobRunner._dbfs_path_exists"
103      ) as path_exists_mock:
104          yield path_exists_mock
105  
106  
@pytest.fixture
def dbfs_mocks(dbfs_path_exists_mock, upload_to_dbfs_mock):
    """Bundles the DBFS path-existence and upload mocks into one fixture."""
110  
111  
@pytest.fixture
def before_run_validations_mock():
    """Disables pre-run validation so tests need no real workspace."""
    patcher = mock.patch("mlflow.projects.databricks.before_run_validations")
    with patcher:
        yield
116  
117  
@pytest.fixture
def set_tag_mock():
    """Wraps MlflowClient so set_tag calls can be inspected by tests."""
    with mock.patch("mlflow.projects.databricks.tracking.MlflowClient") as client_cls_mock:
        wrapped_client = mock.Mock(wraps=MlflowClient())
        client_cls_mock.return_value = wrapped_client
        yield wrapped_client.set_tag
124  
125  
126  def _get_mock_run_state(succeeded):
127      if succeeded is None:
128          return {"life_cycle_state": "RUNNING", "state_message": ""}
129      run_result_state = "SUCCESS" if succeeded else "FAILED"
130      return {"life_cycle_state": "TERMINATED", "state_message": "", "result_state": run_result_state}
131  
132  
def mock_runs_get_result(succeeded):
    """Returns a fake Jobs Runs Get API response for the given outcome."""
    return {"state": _get_mock_run_state(succeeded), "run_page_url": "test_url"}
136  
137  
def run_databricks_project(cluster_spec, **kwargs):
    """Launches the test project on the Databricks backend with alpha=0.4."""
    return mlflow.projects.run(
        TEST_PROJECT_DIR,
        backend="databricks",
        backend_config=cluster_spec,
        parameters={"alpha": "0.4"},
        **kwargs,
    )
146  
147  
def test_upload_project_to_dbfs(
    dbfs_root_mock, tmp_path, dbfs_path_exists_mock, upload_to_dbfs_mock
):
    """The uploaded tarball must be byte-identical to one built locally."""
    # Force the "project not yet on DBFS" path so an upload happens.
    dbfs_path_exists_mock.return_value = False
    runner = DatabricksJobRunner(databricks_profile_uri=construct_db_uri_from_profile("DEFAULT"))
    dbfs_uri = runner._upload_project_to_dbfs(
        project_dir=TEST_PROJECT_DIR, experiment_id=FileStore.DEFAULT_EXPERIMENT_ID
    )
    # Locate the tar the mocked upload wrote under the fake DBFS root.
    uploaded_tar = os.path.join(dbfs_root_mock, dbfs_uri.split("/dbfs/")[1])
    # Build the reference tarball directly from the project directory.
    reference_tar = str(tmp_path.joinpath("expected.tar.gz"))
    file_utils.make_tarfile(
        output_filename=reference_tar,
        source_dir=TEST_PROJECT_DIR,
        archive_name=databricks.DB_TARFILE_ARCHIVE_NAME,
    )
    assert filecmp.cmp(uploaded_tar, reference_tar, shallow=False)
167  
168  
def test_upload_existing_project_to_dbfs(dbfs_path_exists_mock):
    """No upload should occur when the project already exists on DBFS."""
    dbfs_path_exists_mock.return_value = True
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._upload_to_dbfs"
    ) as upload_mock:
        runner = DatabricksJobRunner(
            databricks_profile_uri=construct_db_uri_from_profile("DEFAULT")
        )
        runner._upload_project_to_dbfs(
            project_dir=TEST_PROJECT_DIR, experiment_id=FileStore.DEFAULT_EXPERIMENT_ID
        )
        upload_mock.assert_not_called()
182  
183  
@pytest.mark.parametrize(
    "response_mock",
    [
        helper_functions.create_mock_response(400, "Error message but not a JSON string"),
        helper_functions.create_mock_response(400, ""),
        helper_functions.create_mock_response(400, None),
    ],
)
def test_dbfs_path_exists_error_response_handling(response_mock):
    """A 400 from the DBFS status call raises MlflowException regardless of
    whether the response body is well-formed JSON."""
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_host_creds"
        ) as host_creds_mock,
        mock.patch("mlflow.utils.rest_utils.http_request") as http_request_mock,
    ):
        # Credentials and the HTTP layer are both mocked out, so the runner is
        # built with databricks_profile_uri=None.
        host_creds_mock.return_value = None
        http_request_mock.return_value = response_mock
        runner = DatabricksJobRunner(databricks_profile_uri=None)
        with pytest.raises(MlflowException, match="API request to check existence of file at DBFS"):
            runner._dbfs_path_exists("some/path")
211  
212  
def test_run_databricks_validations(
    tmp_path,
    monkeypatch,
    cluster_spec_mock,
    dbfs_mocks,
    set_tag_mock,
):
    """
    Tests that running on Databricks fails before making any API requests if validations fail.
    """
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    with mock.patch(
        "mlflow.projects.databricks.DatabricksJobRunner._databricks_api_request"
    ) as db_api_req_mock:
        # Test bad tracking URI: a local sqlite-backed URI cannot be used for a
        # Databricks-backed run, so validation must fail up front.
        mlflow.set_tracking_uri(f"sqlite:///{tmp_path / 'mlflow.db'}")
        with pytest.raises(ExecutionException, match="MLflow tracking URI must be of"):
            run_databricks_project(cluster_spec_mock, synchronous=True)
        # No Databricks API request may have been issued before the failure.
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        mlflow_service = MlflowClient()
        # Validation failure must not have created any tracked runs either.
        assert len(mlflow_service.search_runs([FileStore.DEFAULT_EXPERIMENT_ID])) == 0
        mlflow.set_tracking_uri("databricks")
        # Test misspecified parameters: the "greeter" entry point requires a
        # 'name' parameter which is deliberately omitted here.
        with pytest.raises(
            ExecutionException, match="No value given for missing parameters: 'name'"
        ):
            mlflow.projects.run(
                TEST_PROJECT_DIR,
                backend="databricks",
                entry_point="greeter",
                backend_config=cluster_spec_mock,
            )
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test bad cluster spec: a missing backend_config must be rejected.
        with pytest.raises(ExecutionException, match="Backend spec must be provided"):
            mlflow.projects.run(
                TEST_PROJECT_DIR, backend="databricks", synchronous=True, backend_config=None
            )
        assert db_api_req_mock.call_count == 0
        db_api_req_mock.reset_mock()
        # Test that validations pass with good tracking URIs (no exception raised)
        databricks.before_run_validations("http://", cluster_spec_mock)
        databricks.before_run_validations("databricks", cluster_spec_mock)
259  
260  
@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "databricks_cluster_mlflow_run_cmd_mock",
)
def test_run_databricks(
    runs_submit_mock,
    runs_get_mock,
    cluster_spec_mock,
    set_tag_mock,
    databricks_cluster_mlflow_run_cmd_mock,
    monkeypatch,
):
    """Verifies run-status propagation and Databricks tag values for
    asynchronous runs, for both a succeeding and a failing run."""
    monkeypatch.setenv("DATABRICKS_HOST", "https://test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    mlflow.set_tracking_uri("databricks")
    # Test that MLflow gets the correct run status when performing a Databricks run
    for run_succeeded, expect_status in [(True, RunStatus.FINISHED), (False, RunStatus.FAILED)]:
        runs_get_mock.return_value = mock_runs_get_result(succeeded=run_succeeded)
        submitted_run = run_databricks_project(cluster_spec_mock, synchronous=False)
        # wait() returns True iff the mocked run terminated successfully
        assert submitted_run.wait() == run_succeeded
        assert submitted_run.run_id is not None
        assert runs_submit_mock.call_count == 1
        assert databricks_cluster_mlflow_run_cmd_mock.call_count == 1
        # Harvest (key, value) pairs from every set_tag(run_id, key, value) call
        tags = {}
        for call_args, _ in set_tag_mock.call_args_list:
            tags[call_args[1]] = call_args[2]
        assert tags[MLFLOW_DATABRICKS_RUN_URL] == "test_url"
        assert tags[MLFLOW_DATABRICKS_SHELL_JOB_RUN_ID] == "-1"
        assert tags[MLFLOW_DATABRICKS_WEBAPP_URL] == "https://test-host"
        # Reset mocks so the next iteration's call counts start from zero
        set_tag_mock.reset_mock()
        runs_submit_mock.reset_mock()
        databricks_cluster_mlflow_run_cmd_mock.reset_mock()
        validate_exit_status(submitted_run.get_status(), expect_status)
296  
297  
@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_cluster_spec_json(runs_submit_mock, runs_get_mock, monkeypatch):
    """A bare cluster-spec dict is forwarded verbatim as "new_cluster"."""
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    spec = {
        "spark_version": "5.0.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    # Synchronous run should complete without raising.
    run_databricks_project(cluster_spec=spec, synchronous=True)
    assert runs_submit_mock.call_count == 1
    submit_args, _ = runs_submit_mock.call_args_list[0]
    request_body = submit_args[0]
    assert request_body["new_cluster"] == spec
320  
321  
@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_extended_cluster_spec_json(runs_submit_mock, runs_get_mock, monkeypatch):
    """An extended spec forwards both the cluster config and the libraries."""
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    cluster_config = {
        "spark_version": "6.5.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    tensorflow_lib = {"pypi": {"package": "tensorflow"}}
    full_spec = {"new_cluster": cluster_config, "libraries": [tensorflow_lib]}

    # Synchronous run should complete without raising.
    run_databricks_project(cluster_spec=full_spec, synchronous=True)
    assert runs_submit_mock.call_count == 1
    submit_args, _ = runs_submit_mock.call_args_list[0]
    request_body = submit_args[0]
    assert request_body["new_cluster"] == cluster_config
    # Membership check performs deep equality on the library spec.
    assert tensorflow_lib in request_body["libraries"]
350  
351  
@pytest.mark.usefixtures(
    "before_run_validations_mock",
    "runs_cancel_mock",
    "dbfs_mocks",
    "cluster_spec_mock",
    "set_tag_mock",
)
def test_run_databricks_extended_cluster_spec_json_without_libraries(
    runs_submit_mock, runs_get_mock, monkeypatch
):
    """An extended spec with no "libraries" key still submits correctly."""
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    runs_get_mock.return_value = mock_runs_get_result(succeeded=True)
    cluster_config = {
        "spark_version": "6.5.x-scala2.11",
        "num_workers": 2,
        "node_type_id": "i3.xlarge",
    }
    spec_without_libraries = {
        "new_cluster": cluster_config,
    }

    # Synchronous run should complete without raising.
    run_databricks_project(cluster_spec=spec_without_libraries, synchronous=True)
    assert runs_submit_mock.call_count == 1
    submit_args, _ = runs_submit_mock.call_args_list[0]
    request_body = submit_args[0]
    assert request_body["new_cluster"] == cluster_config
381  
382  
def test_run_databricks_throws_exception_when_spec_uses_existing_cluster(monkeypatch):
    """Specs that target an existing cluster are rejected with
    INVALID_PARAMETER_VALUE."""
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    spec = {
        "existing_cluster_id": "1000-123456-clust1",
    }
    with pytest.raises(
        MlflowException, match="execution against existing clusters is not currently supported"
    ) as exc_info:
        run_databricks_project(cluster_spec=spec)
    assert exc_info.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
394  
395  
def test_run_databricks_cancel(
    before_run_validations_mock,
    runs_submit_mock,
    dbfs_mocks,
    set_tag_mock,
    runs_cancel_mock,
    runs_get_mock,
    cluster_spec_mock,
    monkeypatch,
):
    """Cancelling a Databricks run invokes the cancel API and the run ends
    FAILED; a failing synchronous run raises ExecutionException."""
    monkeypatch.setenv("DATABRICKS_HOST", "test-host")
    monkeypatch.setenv("DATABRICKS_TOKEN", "foo")
    # Report the run as failed so cancel() exits instead of blocking while
    # polling run status.
    runs_get_mock.return_value = mock_runs_get_result(succeeded=False)
    submitted_run = run_databricks_project(cluster_spec_mock, synchronous=False)
    submitted_run.cancel()
    validate_exit_status(submitted_run.get_status(), RunStatus.FAILED)
    assert runs_cancel_mock.call_count == 1
    # A blocking (synchronous) run that fails should raise.
    runs_get_mock.return_value = mock_runs_get_result(succeeded=False)
    with pytest.raises(mlflow.projects.ExecutionException, match=r"Run \(ID '.+'\) failed"):
        run_databricks_project(cluster_spec_mock, synchronous=True)
420  
421  
def test_get_tracking_uri_for_run(monkeypatch):
    """Checks how the tracking URI is resolved for Databricks runs."""
    # A plain HTTP URI is passed through unchanged.
    mlflow.set_tracking_uri("http://some-uri")
    assert databricks._get_tracking_uri_for_run() == "http://some-uri"
    # A profile-qualified databricks URI is collapsed to plain "databricks".
    mlflow.set_tracking_uri("databricks://profile")
    assert databricks._get_tracking_uri_for_run() == "databricks"
    # With no URI configured, the MLFLOW_TRACKING_URI env var supplies it.
    mlflow.set_tracking_uri(None)
    monkeypatch.setenv(MLFLOW_TRACKING_URI.name, "http://some-uri")
    assert mlflow.tracking._tracking_service.utils.get_tracking_uri() == "http://some-uri"
430  
431  
class MockProfileConfigProvider:
    """Stub Databricks config provider that only accepts the "my-profile" profile."""

    def __init__(self, profile):
        assert profile == "my-profile"

    def get_config(self):
        """Return basic-auth DatabricksConfig credentials for a fake host."""
        return DatabricksConfig.from_password("host", "user", "pass", insecure=False)
438  
439  
def test_databricks_http_request_integration():
    """End-to-end check that _databricks_api_request issues the expected
    HTTP call (method, URL, auth header, body) and parses the response."""

    def fake_request(*args, **kwargs):
        # Assert inside the mocked transport: exact method/URL and kwargs.
        expected_headers = DefaultRequestHeaderProvider().request_headers()
        expected_headers["Authorization"] = "Basic dXNlcjpwYXNz"
        assert args == ("PUT", "host/clusters/list")
        assert kwargs == {
            "allow_redirects": True,
            "headers": expected_headers,
            "verify": True,
            "json": {"a": "b"},
            "timeout": 120,
        }
        resp = mock.MagicMock()
        resp.status_code = 200
        resp.text = '{"OK": "woo"}'
        return resp

    host_creds = MlflowHostCreds(
        host="host", username="user", password="pass", ignore_tls_verification=False
    )
    with (
        mock.patch("requests.Session.request", side_effect=fake_request),
        mock.patch(
            "mlflow.utils.databricks_utils.get_databricks_host_creds",
            return_value=host_creds,
        ),
    ):
        response = DatabricksJobRunner(databricks_profile_uri=None)._databricks_api_request(
            "/clusters/list", "PUT", json={"a": "b"}
        )
        assert json.loads(response.text) == {"OK": "woo"}
470  
471  
def test_run_databricks_failed():
    """A non-2xx API response surfaces as MlflowException carrying the
    Databricks error code and message."""
    error_body = '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    with (
        mock.patch("mlflow.utils.databricks_utils.get_databricks_host_creds"),
        mock.patch(
            "mlflow.utils.rest_utils.http_request",
            return_value=mock.Mock(text=error_body, status_code=400),
        ),
    ):
        runner = DatabricksJobRunner(construct_db_uri_from_profile("profile"))
        with pytest.raises(
            MlflowException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"
        ):
            runner._run_shell_command_job("/project", "command", {}, {})
486  
487  
def test_run_databricks_generates_valid_mlflow_run_cmd():
    """The generated cluster command starts with "mlflow" and its arguments
    are accepted by the MLflow CLI."""
    run_cmd = _get_cluster_mlflow_run_cmd(
        project_dir="my_project_dir",
        run_id="hi",
        entry_point="main",
        parameters={"a": "b"},
        env_manager="conda",
    )
    assert run_cmd[0] == "mlflow"
    # Feed everything after the executable name to the CLI parser; the actual
    # project run is mocked out.
    with mock.patch("mlflow.projects.run"):
        invoke_cli_runner(cli.cli, run_cmd[1:])