import builtins
import json
import os
import platform
import sys
import time
from unittest import mock

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.legacy_databricks_cli.configure.provider import (
    DatabricksConfig,
    DatabricksModelServingConfigProvider,
)
from mlflow.utils import databricks_utils
from mlflow.utils.databricks_utils import (
    DatabricksConfigProvider,
    DatabricksRuntimeVersion,
    _NoDbutilsError,
    check_databricks_secret_scope_access,
    get_databricks_host_creds,
    get_databricks_runtime_major_minor_version,
    get_databricks_runtime_version,
    get_databricks_workspace_client_config,
    get_dbconnect_udf_sandbox_info,
    get_mlflow_credential_context_by_run_id,
    get_sgc_job_run_id,
    get_workspace_info_from_databricks_secrets,
    get_workspace_info_from_dbutils,
    get_workspace_url,
    is_databricks_default_tracking_uri,
    is_running_in_ipython_environment,
)
from mlflow.utils.os import is_windows

from tests.helper_functions import mock_method_chain
from tests.pyfunc.test_spark import spark  # noqa: F401


def test_no_throw():
    """
    Outside of Databricks the databricks_utils methods should never throw and should only return
    None.
    """
    assert not databricks_utils.is_in_databricks_notebook()
    assert not databricks_utils.is_in_databricks_repo_notebook()
    assert not databricks_utils.is_in_databricks_job()
    assert not databricks_utils.is_dbfs_fuse_available()
    assert not databricks_utils.is_in_databricks_runtime()


def test_databricks_registry_profile():
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = None
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.get.return_value = "random"
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
        ),
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
    ):
        params = databricks_utils.get_databricks_host_creds("databricks://profile:prefix")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-host", scope="profile")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-token", scope="profile")
        assert params.host == "random"
        assert params.token == "random"


def test_databricks_no_creds_found():
    with pytest.raises(MlflowException, match="Reading Databricks credential configuration failed"):
        databricks_utils.get_databricks_host_creds()


def test_databricks_no_creds_found_in_model_serving(monkeypatch):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    with pytest.raises(
        MlflowException, match="Reading Databricks credential configuration in model serving failed"
    ):
        databricks_utils.get_databricks_host_creds()


def test_databricks_single_slash_in_uri_scheme_throws():
    with pytest.raises(MlflowException, match="URI is formatted incorrectly"):
        databricks_utils.get_databricks_host_creds("databricks:/profile:path")


@pytest.fixture
def oauth_file(tmp_path):
    token_contents = {"OAUTH_TOKEN": [{"oauthTokenValue": "token2"}]}
    oauth_file = tmp_path.joinpath("model-dependencies-oauth-token")
    with open(oauth_file, "w") as f:
        json.dump(token_contents, f)
    return oauth_file


def test_get_model_dependency_token(oauth_file):
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        token = databricks_utils.get_model_dependency_oauth_token()
        assert token == "token2"


def test_get_model_dependency_oauth_token_model_serving_throws():
    with pytest.raises(MlflowException, match="Unable to read Oauth credentials"):
        databricks_utils.get_model_dependency_oauth_token()


@pytest.mark.parametrize(
    ("model_serving_env_var"),
    [
        ("DATABRICKS_MODEL_SERVING_HOST_URL"),
        ("DB_MODEL_SERVING_HOST_URL"),
    ],
)
def test_databricks_params_model_serving_oauth_cache_databricks(
    monkeypatch, oauth_file, model_serving_env_var
):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv(model_serving_env_var, "host")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE", "token")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS", str(time.time() + 5))
    # the oauth file still needs to be present for should_fetch_model_serving_environment_oauth()
    # to return True
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        assert params.host == "host"
        # should use the token from the cache, rather than the token from the oauth file
        assert params.token == "token"


def test_databricks_params_model_serving_oauth_cache_expired(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE", "token")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS", str(time.time() - 5))
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        # the cache should get updated with the new token
        assert os.environ["DB_DEPENDENCY_OAUTH_CACHE"] == "token2"
        assert float(os.environ["DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS"]) > time.time()
        assert params.host == "host"
        # should use token2 from the oauth file, rather than the token from the cache
        assert params.token == "token2"


def test_databricks_params_model_serving_read_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        assert os.environ["DB_DEPENDENCY_OAUTH_CACHE"] == "token2"
        assert float(os.environ["DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS"]) > time.time()
        assert params.host == "host"
        assert params.token == "token2"


def test_databricks_params_env_var_overrides_model_serving_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    monkeypatch.setenv("DATABRICKS_HOST", "host_envvar")
    monkeypatch.setenv("DATABRICKS_TOKEN", "pat_token")
    # the oauth file still needs to be present for should_fetch_model_serving_environment_oauth()
    # to return True
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        # should use the host and token from the env vars, not the token from the oauth file
        assert params.use_databricks_sdk


def test_model_serving_config_provider_errors_caught():
    provider = DatabricksModelServingConfigProvider()
    with mock.patch.object(
        provider,
        "_get_databricks_model_serving_config",
        side_effect=Exception("Failed to Read OAuth Creds"),
    ):
        assert provider.get_config() is None


def test_get_workspace_info_from_databricks_secrets():
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.get.return_value = "workspace-placeholder-info"
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(
            "databricks://profile:prefix"
        )
        mock_dbutils.secrets.get.assert_any_call(key="prefix-host", scope="profile")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-workspace-id", scope="profile")
        assert workspace_host == "workspace-placeholder-info"
        assert workspace_id == "workspace-placeholder-info"


def test_get_workspace_info_from_dbutils():
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(
        mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com"
    )
    mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111")

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_no_browser_host_name():
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(mock_dbutils, methods + ["browserHostName", "get"], return_value=None)
    mock_method_chain(
        mock_dbutils, methods + ["apiUrl", "get"], return_value="https://mlflow.databricks.com"
    )
    mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111")
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_old_runtimes():
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(
        mock_dbutils,
        methods + ["toJson", "get"],
        return_value='{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}',
    )
    mock_method_chain(
        mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com"
    )

    # Mock out workspace ID tag
    mock_workspace_id_tag_opt = mock.MagicMock()
    mock_workspace_id_tag_opt.isDefined.return_value = True
    mock_workspace_id_tag_opt.get.return_value = "1111"
    mock_method_chain(
        mock_dbutils, methods + ["tags", "get"], return_value=mock_workspace_id_tag_opt
    )

    # Mimic old runtimes by raising an exception when the nonexistent "workspaceId" method is called
    mock_method_chain(
        mock_dbutils,
        methods + ["workspaceId"],
        side_effect=Exception("workspaceId method not defined!"),
    )
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_when_no_dbutils_available():
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=None):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host is None
        assert workspace_id is None


@pytest.mark.parametrize(
    ("tracking_uri", "result"),
    [
        ("databricks", True),
        ("databricks://profile:prefix", False),
        ("databricks://profile/prefix", False),
        ("nondatabricks", False),
        ("databricks\t\r", True),
        ("databricks\n", True),
        ("databricks://", False),
        ("databricks://aAbB", False),
    ],
)
def test_is_databricks_default_tracking_uri(tracking_uri, result):
    assert is_databricks_default_tracking_uri(tracking_uri) == result


def test_databricks_params_throws_errors():
    # No hostname
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        None, "user", "pass", insecure=True
    )
    with mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
    ):
        with pytest.raises(
            Exception, match="Reading Databricks credential configuration failed with"
        ):
            databricks_utils.get_databricks_host_creds()

    # No authentication
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        "host", None, None, insecure=True
    )
    with mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
    ):
        with pytest.raises(
            Exception, match="Reading Databricks credential configuration failed with"
        ):
            databricks_utils.get_databricks_host_creds()


def test_is_in_databricks_runtime(monkeypatch):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "11.x")
    assert databricks_utils.is_in_databricks_runtime()

    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION")
    assert not databricks_utils.is_in_databricks_runtime()


@pytest.mark.parametrize("val", ["true", "1"])
def test_is_in_databricks_model_serving_environment(monkeypatch, val):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", val)
    assert databricks_utils.is_in_databricks_model_serving_environment()

    monkeypatch.delenv("IS_IN_DB_MODEL_SERVING_ENV")
    assert not databricks_utils.is_in_databricks_model_serving_environment()


# test that both is_in_databricks_model_serving_environment and
# should_fetch_model_serving_environment_oauth return appropriate values
def test_should_fetch_model_serving_environment_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    # should return False if the file mount is not configured, even if the env var is set
    assert not databricks_utils.should_fetch_model_serving_environment_oauth()

    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        # when both the file mount and the env var exist, should return True
        assert databricks_utils.should_fetch_model_serving_environment_oauth()

        # the file mount without the env var should return False
        monkeypatch.delenv("IS_IN_DB_MODEL_SERVING_ENV")
        assert not databricks_utils.should_fetch_model_serving_environment_oauth()


def test_get_repl_id():
    # Outside of Databricks environments, the Databricks REPL ID should be absent
    assert databricks_utils.get_repl_id() is None

    mock_client = mock.MagicMock()
    mock_client.getReplId.return_value = "testReplId1"
    with mock.patch(
        "mlflow.utils.databricks_utils._get_runtime_integration_client",
        return_value=mock_client,
    ):
        assert databricks_utils.get_repl_id() == "testReplId1"
        mock_client.getReplId.assert_called_once()

    # When runtime_integration_client is unavailable, fall back to entry_point.
    mock_dbutils = mock.MagicMock()
    mock_dbutils.entry_point.getReplId.return_value = "testReplId1"
    with (
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            side_effect=Exception("unavailable"),
        ),
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
    ):
        assert databricks_utils.get_repl_id() == "testReplId1"
        mock_dbutils.entry_point.getReplId.assert_called_once()

    mock_sparkcontext_inst = mock.MagicMock()
    mock_sparkcontext_inst.getLocalProperty.return_value = "testReplId2"
    mock_sparkcontext_class = mock.MagicMock()
    mock_sparkcontext_class.getOrCreate.return_value = mock_sparkcontext_inst
    mock_spark = mock.MagicMock()
    mock_spark.SparkContext = mock_sparkcontext_class

    original_import = builtins.__import__

    def mock_import(name, *args, **kwargs):
        if name == "pyspark":
            return mock_spark
        else:
            return original_import(name, *args, **kwargs)

    with mock.patch("builtins.__import__", side_effect=mock_import):
        assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_repl_context_if_available(tmp_path, monkeypatch):
    # Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable.
    with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"):
        from dbruntime.databricks_repl_context import get_context  # noqa: F401

    command_context_mock = mock.MagicMock()
    command_context_mock.jobId().get.return_value = "job_id"
    command_context_mock.tags().get("jobType").get.return_value = "NORMAL"
    with mock.patch(
        "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
    ) as mock_get_command_context:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_command_context.assert_called_once()

    # Create a fake databricks_repl_context module
    dbruntime = tmp_path.joinpath("dbruntime")
    dbruntime.mkdir()
    dbruntime.joinpath("databricks_repl_context.py").write_text(
        """
def get_context():
    pass
"""
    )
    monkeypatch.syspath_prepend(str(tmp_path))

    # Simulate a case where the REPL context object is not initialized.
    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=None,
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
        ) as mock_get_command_context,
    ):
        assert databricks_utils.get_job_id() == "job_id"
        assert mock_get_command_context.call_count == 1

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(jobId="job_id"),
        ) as mock_get_context,
        mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils,
    ):
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_context.assert_called_once()
        mock_dbutils.assert_not_called()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(
                notebookId="notebook_id", notebookPath="/Repos/notebook_path"
            ),
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context,
    ):
        assert databricks_utils.get_notebook_id() == "notebook_id"
        assert databricks_utils.is_in_databricks_repo_notebook()
        assert mock_get_context.call_count == 2
        mock_spark_context.assert_not_called()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(
                notebookId="notebook_id", notebookPath="/Users/notebook_path"
            ),
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context,
    ):
        assert not databricks_utils.is_in_databricks_repo_notebook()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(isInCluster=True),
        ) as mock_get_context,
        mock.patch("mlflow.utils._spark_utils._get_active_spark_session") as mock_spark_session,
    ):
        assert databricks_utils.is_in_cluster()
        mock_get_context.assert_called_once()
        mock_spark_session.assert_not_called()


@pytest.mark.parametrize("get_ipython", [True, None])
def test_is_running_in_ipython_environment_works(get_ipython):
    mod_name = "IPython"
    if mod_name in sys.modules:
        ipython_mod = sys.modules.pop(mod_name)
        assert not is_running_in_ipython_environment()
        sys.modules["IPython"] = ipython_mod

        with mock.patch("IPython.get_ipython", return_value=get_ipython):
            assert is_running_in_ipython_environment() == (get_ipython is not None)


def test_get_mlflow_credential_context_by_run_id():
    with (
        mock.patch(
            "mlflow.tracking.artifact_utils.get_artifact_uri", return_value="dbfs:/path/to/artifact"
        ) as mock_get_artifact_uri,
        mock.patch(
            "mlflow.utils.uri.get_databricks_profile_uri_from_artifact_uri",
            return_value="databricks://path/to/profile",
        ) as mock_get_databricks_profile,
        mock.patch(
            "mlflow.utils.databricks_utils.MlflowCredentialContext"
        ) as mock_credential_context,
    ):
        get_mlflow_credential_context_by_run_id(run_id="abc")
        mock_get_artifact_uri.assert_called_once_with(run_id="abc")
        mock_get_databricks_profile.assert_called_once_with("dbfs:/path/to/artifact")
        mock_credential_context.assert_called_once_with("databricks://path/to/profile")


def test_check_databricks_secret_scope_access():
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.list.return_value = "random"
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        check_databricks_secret_scope_access("scope")
        mock_dbutils.secrets.list.assert_called_once_with("scope")


def test_check_databricks_secret_scope_access_error():
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.list.side_effect = Exception("no scope access")
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
        mock.patch("mlflow.utils.databricks_utils._logger.warning") as mock_warning,
    ):
        check_databricks_secret_scope_access("scope")
        mock_warning.assert_called_once_with(
            "Unable to access Databricks secret scope 'scope' for OpenAI credentials that will be "
            "used to deploy the model to Databricks Model Serving. Please verify that the current "
            "Databricks user has 'READ' permission for this scope. For more information, see "
            "https://mlflow.org/docs/latest/python_api/openai/index.html#credential-management-for-openai-on-databricks. "  # noqa: E501
            "Error: no scope access"
        )
        mock_dbutils.secrets.list.assert_called_once_with("scope")


@pytest.mark.parametrize(
    ("version_str", "is_client_image", "major", "minor"),
    [
        ("client.0", True, 0, 0),
        ("client.1", True, 1, 0),
        ("client.1.6", True, 1, 6),
        ("15.1", False, 15, 1),
        ("12.1.1", False, 12, 1),
    ],
)
def test_get_databricks_runtime_major_minor_version(
    monkeypatch, version_str, is_client_image, major, minor
):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", version_str)
    dbr_version = get_databricks_runtime_major_minor_version()

    assert dbr_version.is_client_image == is_client_image
    assert dbr_version.major == major
    assert dbr_version.minor == minor


def test_get_dbr_major_minor_version_throws_on_invalid_version_key(monkeypatch):
    # minor version is not allowed to be a string
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.x")
    with pytest.raises(MlflowException, match="Failed to parse databricks runtime version"):
        get_databricks_runtime_major_minor_version()


def test_prioritize_env_var_config_provider(monkeypatch):
    monkeypatch.setenv("DATABRICKS_HOST", "my_host1")
    monkeypatch.setenv("DATABRICKS_TOKEN", "token1")

    class MyProvider(DatabricksConfigProvider):
        def get_config(self):
            return DatabricksConfig(host="my_host2", token="token2")

    monkeypatch.setattr(databricks_utils, "_dynamic_token_config_provider", MyProvider)

    hc = get_databricks_host_creds("databricks")
    assert hc.host == "my_host1"
    assert hc.token == "token1"


@pytest.mark.parametrize(
    ("input_url", "expected_result"),
    [
        # Test with a valid URL without https:// prefix
        ("example.com", "https://example.com"),
        # Test with a valid URL with https:// prefix
        ("https://example.com", "https://example.com"),
        # Test with None URL
        (None, None),
    ],
)
def test_get_workspace_url(input_url, expected_result):
    with mock.patch("mlflow.utils.databricks_utils._get_workspace_url", return_value=input_url):
        result = get_workspace_url()
        assert result == expected_result


@pytest.mark.skipif(is_windows(), reason="This test doesn't work on Windows")
def test_get_dbconnect_udf_sandbox_info(spark, monkeypatch):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "client.1.2")
    databricks_utils._dbconnect_udf_sandbox_info_cache = None

    spark.udf.register(
        "current_version",
        lambda: {"dbr_version": "15.4.x-scala2.12"},
        returnType="dbr_version string",
    )

    info = get_dbconnect_udf_sandbox_info(spark)
    assert info.mlflow_version == mlflow.__version__
    assert info.image_version == "client.1.2"
    assert info.runtime_version == "15.4"
    assert info.platform_machine == platform.machine()

    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION")
    databricks_utils._dbconnect_udf_sandbox_info_cache = None

    info = get_dbconnect_udf_sandbox_info(spark)
    assert info.mlflow_version == mlflow.__version__
    assert info.image_version == "15.4"
    assert info.runtime_version == "15.4"
    assert info.platform_machine == platform.machine()


def test_construct_databricks_uc_registered_model_url():
    # Test case with workspace ID
    workspace_url = "https://databricks.com"
    registered_model_name = "name.mlflow.echo_model"
    version = "6"
    workspace_id = "123"

    expected_url = (
        "https://databricks.com/explore/data/models/name/mlflow/echo_model/version/6?o=123"
    )

    result = databricks_utils._construct_databricks_uc_registered_model_url(
        workspace_url=workspace_url,
        registered_model_name=registered_model_name,
        version=version,
        workspace_id=workspace_id,
    )

    assert result == expected_url

    # Test case without workspace ID
    expected_url_no_workspace = (
        "https://databricks.com/explore/data/models/name/mlflow/echo_model/version/6"
    )

    result_no_workspace = databricks_utils._construct_databricks_uc_registered_model_url(
        workspace_url=workspace_url,
        registered_model_name=registered_model_name,
        version=version,
    )

    assert result_no_workspace == expected_url_no_workspace


def test_construct_databricks_logged_model_url():
    # Test case with workspace ID
    workspace_url = "https://databricks.com"
    experiment_id = "123456"
    model_id = "model_789"
    workspace_id = "123"

    expected_url = "https://databricks.com/ml/experiments/123456/models/model_789?o=123"

    result = databricks_utils._construct_databricks_logged_model_url(
        workspace_url=workspace_url,
        experiment_id=experiment_id,
        model_id=model_id,
        workspace_id=workspace_id,
    )

    assert result == expected_url

    # Test case without workspace ID
    expected_url_no_workspace = "https://databricks.com/ml/experiments/123456/models/model_789"

    result_no_workspace = databricks_utils._construct_databricks_logged_model_url(
        workspace_url=workspace_url,
        experiment_id=experiment_id,
        model_id=model_id,
    )

    assert result_no_workspace == expected_url_no_workspace


def test_print_databricks_deployment_job_url():
    workspace_url = "https://databricks.com"
    job_id = "123"
    workspace_id = "456"

    expected_url_no_workspace = "https://databricks.com/jobs/123"
    expected_url = f"{expected_url_no_workspace}?o=456"
    model_name = "main.models.name"

    with (
        mock.patch("mlflow.utils.databricks_utils.eprint") as mock_eprint,
        mock.patch("mlflow.utils.databricks_utils.get_workspace_url", return_value=workspace_url),
    ):
        # Test case with a workspace ID
        with mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_id", return_value=workspace_id
        ):
            result = databricks_utils._print_databricks_deployment_job_url(
                model_name=model_name,
                job_id=job_id,
            )

            assert result == expected_url
            mock_eprint.assert_called_once_with(
                f"🔗 Linked deployment job to '{model_name}': {expected_url}"
            )
            mock_eprint.reset_mock()

        # Test case without a workspace ID
        with mock.patch("mlflow.utils.databricks_utils.get_workspace_id", return_value=None):
            result_no_workspace = databricks_utils._print_databricks_deployment_job_url(
                model_name=model_name,
                job_id=job_id,
            )

            assert result_no_workspace == expected_url_no_workspace
            mock_eprint.assert_called_once_with(
                f"🔗 Linked deployment job to '{model_name}': {expected_url_no_workspace}"
            )


@pytest.mark.parametrize(
    ("version_str", "expected_is_client", "expected_major", "expected_minor", "expected_is_gpu"),
    [
        ("client.2.0", True, 2, 0, False),
        ("client.3.1", True, 3, 1, False),
        ("13.2", False, 13, 2, False),
        ("15.4", False, 15, 4, False),
        ("client.8.1-gpu", True, 8, 1, True),
        ("client.10.0-gpu", True, 10, 0, True),
        ("14.3-gpu", False, 14, 3, True),
        ("15.1-gpu", False, 15, 1, True),
    ],
)
def test_databricks_runtime_version_parse(
    version_str,
    expected_is_client,
    expected_major,
    expected_minor,
    expected_is_gpu,
):
    version = DatabricksRuntimeVersion.parse(version_str)
    assert version.is_client_image == expected_is_client
    assert version.major == expected_major
    assert version.minor == expected_minor
    assert version.is_gpu_image == expected_is_gpu


@pytest.mark.parametrize(
    ("env_version", "expected_is_client", "expected_major", "expected_minor", "expected_is_gpu"),
    [
        ("client.2.0", True, 2, 0, False),
        ("13.2", False, 13, 2, False),
        ("client.8.1-gpu", True, 8, 1, True),
        ("14.3-gpu", False, 14, 3, True),
    ],
)
def test_databricks_runtime_version_parse_default(
    monkeypatch,
    env_version,
    expected_is_client,
    expected_major,
    expected_minor,
    expected_is_gpu,
):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", env_version)
    version = DatabricksRuntimeVersion.parse()
    assert version.is_client_image == expected_is_client
    assert version.major == expected_major
    assert version.minor == expected_minor
    assert version.is_gpu_image == expected_is_gpu


def test_databricks_runtime_version_parse_default_no_env(monkeypatch):
    """Test that DatabricksRuntimeVersion.parse() raises an error when no environment
    variable is set.
    """
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.delenv("DATABRICKS_ENV_VERSION", raising=False)
    with pytest.raises(Exception, match="Failed to parse databricks runtime version"):
        DatabricksRuntimeVersion.parse()


@pytest.mark.parametrize(
    ("env_version", "accelerator", "expected"),
    [
        ("4", "A10G", "client.4-gpu"),
        ("4", "NVIDIA H100", "client.4-gpu"),
        ("4", None, "client.4"),
    ],
)
def test_get_databricks_runtime_version_from_env_version(
    monkeypatch, env_version, accelerator, expected
):
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", env_version)
    if accelerator:
        monkeypatch.setenv("DATABRICKS_ACCELERATOR", accelerator)
    else:
        monkeypatch.delenv("DATABRICKS_ACCELERATOR", raising=False)
    assert get_databricks_runtime_version() == expected


@pytest.mark.parametrize(
    ("accelerator", "expected"),
    [
        ("A10G", "client.4-gpu"),
        (None, "client.4"),
    ],
)
def test_databricks_env_version_takes_priority_over_runtime_version(
    monkeypatch, accelerator, expected
):
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", "4")
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "client.4.1")
    if accelerator:
        monkeypatch.setenv("DATABRICKS_ACCELERATOR", accelerator)
    else:
        monkeypatch.delenv("DATABRICKS_ACCELERATOR", raising=False)
    assert get_databricks_runtime_version() == expected


def test_databricks_runtime_version_parse_from_env_version(monkeypatch):
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", "4")
    monkeypatch.setenv("DATABRICKS_ACCELERATOR", "A10")
    version = DatabricksRuntimeVersion.parse()
    assert version.is_client_image is True
    assert version.major == 4
    assert version.minor == 0
    assert version.is_gpu_image is True


@pytest.mark.parametrize(
    "invalid_version",
    [
        "invalid",
        "client",
        "client.invalid",
        "13",
    ],
)
def test_databricks_runtime_version_parse_invalid(invalid_version):
    with pytest.raises(Exception, match="Failed to parse databricks runtime version"):
        DatabricksRuntimeVersion.parse(invalid_version)


def test_get_databricks_workspace_client_config_with_tracking_uri_provider():
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    # Mock TrackingURIConfigProvider
    mock_uri_config = mock.MagicMock()
    mock_uri_config.host = "https://test.databricks.com"
    mock_uri_config.token = "test_token"

    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri",
            return_value=("profile_name", "key_prefix"),
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
        mock.patch("mlflow.utils.databricks_utils.TrackingURIConfigProvider") as mock_provider,
    ):
        mock_provider.return_value.get_config.return_value = mock_uri_config

        result = get_databricks_workspace_client_config("databricks://profile:prefix")

        # Verify the WorkspaceClient was created with correct parameters
        mock_workspace_client.assert_called_once_with(
            host="https://test.databricks.com", token="test_token"
        )
        assert result == mock_config


def test_get_databricks_workspace_client_config_with_profile():
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri",
            return_value=("profile_name", None),
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
    ):
        result = get_databricks_workspace_client_config("databricks://profile_name")

        # Verify the WorkspaceClient was created with profile
        mock_workspace_client.assert_called_once_with(profile="profile_name")
        assert result == mock_config


def test_get_databricks_workspace_client_config_env_profile(monkeypatch):
    monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "env_profile")
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    with (
        mock.patch("mlflow.utils.databricks_utils.get_db_info_from_uri", return_value=(None, None)),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
    ):
        result = get_databricks_workspace_client_config("databricks")

        # Verify the WorkspaceClient was created with environment profile
        mock_workspace_client.assert_called_once_with(profile="env_profile")
        assert result == mock_config


def test_get_databricks_workspace_client_config_client_creation_error():
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri", return_value=("profile", None)
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", side_effect=Exception("Client creation failed")
        ),
    ):
        with pytest.raises(Exception, match="Client creation failed"):
            get_databricks_workspace_client_config("databricks://profile")


def test_get_sgc_job_run_id_success(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.return_value = "test_job_run_id_12345"

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "test_job_run_id_12345"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_no_dbutils(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", side_effect=_NoDbutilsError()):
        result = get_sgc_job_run_id()
        assert result is None


def test_get_sgc_job_run_id_no_dbutils_with_env_var(monkeypatch):
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_456")
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", side_effect=_NoDbutilsError()):
        result = get_sgc_job_run_id()
        assert result == "env_job_run_id_456"


def test_get_sgc_job_run_id_value_error(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.side_effect = ValueError("Widget not found")

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result is None
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_value_error_with_env_var(monkeypatch):
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_789")
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.side_effect = ValueError("Widget not found")

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "env_job_run_id_789"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_empty_widget_with_env_var(monkeypatch):
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_999")
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.return_value = ""

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "env_job_run_id_999"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_none_widget_with_env_var(monkeypatch):
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_111")
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.return_value = None

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "env_job_run_id_111"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_widget_takes_precedence_over_env_var(monkeypatch):
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_222")
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.return_value = "widget_job_run_id_333"

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "widget_job_run_id_333"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_databricks_config_profile_env_var_is_respected(tmp_path, monkeypatch):
    file_path = tmp_path / ".databrickscfg"
    monkeypatch.setenv("MLFLOW_TRACKING_URI", "databricks")
    monkeypatch.setenv("DATABRICKS_CONFIG_FILE", str(file_path))
    monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "test")

    file_path.write_text("""[DEFAULT]
host = http://default-workspace.databricks.com
token = default-token

[test]
host = https://test-workspace.databricks.com
token = test-token
""")

    # the resulting config should be the one from the [test] section
    result = get_databricks_host_creds("databricks")
    assert result.host == "https://test-workspace.databricks.com"
    assert result.token == "test-token"


def test_get_databricks_nfs_temp_dir():
    mock_dbutils = mock.MagicMock()
    mock_client = mock.MagicMock()
    mock_client.getUserNFSTempDir.return_value = "/nfs/user/grpc"

    # When runtime_integration_client is available, use getUserNFSTempDir from client
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            return_value=mock_client,
        ),
    ):
        assert databricks_utils.get_databricks_nfs_temp_dir() == "/nfs/user/grpc"
        mock_client.getUserNFSTempDir.assert_called_once()

    # When runtime_integration_client raises, fall back to entry_point.getUserNFSTempDir
    mock_dbutils2 = mock.MagicMock()
    mock_dbutils2.entry_point.getUserNFSTempDir.return_value = "/nfs/user"
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils2),
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            side_effect=Exception("unavailable"),
        ),
    ):
        assert databricks_utils.get_databricks_nfs_temp_dir() == "/nfs/user"
        mock_dbutils2.entry_point.getUserNFSTempDir.assert_called_once()


def test_get_databricks_local_temp_dir():
    mock_dbutils = mock.MagicMock()
    mock_client = mock.MagicMock()
    mock_client.getUserLocalTempDir.return_value = "/local/user/grpc"

    # When runtime_integration_client is available, use getUserLocalTempDir from client
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            return_value=mock_client,
        ),
    ):
        assert databricks_utils.get_databricks_local_temp_dir() == "/local/user/grpc"
        mock_client.getUserLocalTempDir.assert_called_once()

    # When runtime_integration_client raises, fall back to entry_point.getUserLocalTempDir
    mock_dbutils2 = mock.MagicMock()
    mock_dbutils2.entry_point.getUserLocalTempDir.return_value = "/local/user"
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils2),
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            side_effect=Exception("unavailable"),
        ),
    ):
        assert databricks_utils.get_databricks_local_temp_dir() == "/local/user"
        mock_dbutils2.entry_point.getUserLocalTempDir.assert_called_once()


def test_get_databricks_host_creds_propagates_workspace_id(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true")
    monkeypatch.setenv("DATABRICKS_HOST", "https://spog.databricks.com")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")

    mock_config = mock.MagicMock()
    mock_config.workspace_id = "6051921418418893"

    mock_ws = mock.MagicMock()
    mock_ws.config = mock_config

    with mock.patch("databricks.sdk.WorkspaceClient", return_value=mock_ws) as mock_ws_cls:
        result = get_databricks_host_creds("databricks")
        mock_ws_cls.assert_called_once_with(profile=None)
        assert result.workspace_id == "6051921418418893"
        assert result.use_databricks_sdk


def test_get_databricks_host_creds_workspace_id_none_when_not_set(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true")
    monkeypatch.setenv("DATABRICKS_HOST", "https://workspace.databricks.com")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")

    mock_config = mock.MagicMock()
    mock_config.workspace_id = None

    mock_ws = mock.MagicMock()
    mock_ws.config = mock_config

    with mock.patch("databricks.sdk.WorkspaceClient", return_value=mock_ws):
        result = get_databricks_host_creds("databricks")
        assert result.workspace_id is None


def test_get_databricks_host_creds_workspace_id_from_config_on_sdk_failure(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true")
    monkeypatch.setenv("DATABRICKS_HOST", "https://spog.databricks.com")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")

    mock_config = mock.MagicMock()
    mock_config.workspace_id = "6051921418418893"

    with (
        mock.patch(
            "databricks.sdk.WorkspaceClient",
            side_effect=Exception("SDK auth failed"),
        ),
        mock.patch(
            "databricks.sdk.config.Config",
            return_value=mock_config,
        ),
    ):
        result = get_databricks_host_creds("databricks")
        assert result.workspace_id == "6051921418418893"
        assert not result.use_databricks_sdk


def test_get_databricks_host_creds_workspace_id_none_on_full_failure(monkeypatch):
    monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true")
    monkeypatch.setenv("DATABRICKS_HOST", "https://workspace.databricks.com")
    monkeypatch.setenv("DATABRICKS_TOKEN", "test-token")

    with (
        mock.patch(
            "databricks.sdk.WorkspaceClient",
            side_effect=Exception("SDK auth failed"),
        ),
        mock.patch(
            "databricks.sdk.config.Config",
            side_effect=Exception("Config failed"),
        ),
    ):
        result = get_databricks_host_creds("databricks")
        assert result.workspace_id is None
        assert not result.use_databricks_sdk