# tests/utils/test_databricks_utils.py
"""Unit tests for ``mlflow.utils.databricks_utils``.

Covers credential resolution (profiles, env vars, model-serving OAuth files),
workspace info discovery via dbutils, Databricks runtime version parsing,
URL construction helpers, and serverless-GPU job-run-id lookup. All Databricks
interactions are mocked; no test talks to a real workspace.
"""

import builtins
import json
import os
import platform
import sys
import time
from unittest import mock

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.legacy_databricks_cli.configure.provider import (
    DatabricksConfig,
    DatabricksModelServingConfigProvider,
)
from mlflow.utils import databricks_utils
from mlflow.utils.databricks_utils import (
    DatabricksConfigProvider,
    DatabricksRuntimeVersion,
    _NoDbutilsError,
    check_databricks_secret_scope_access,
    get_databricks_host_creds,
    get_databricks_runtime_major_minor_version,
    get_databricks_runtime_version,
    get_databricks_workspace_client_config,
    get_dbconnect_udf_sandbox_info,
    get_mlflow_credential_context_by_run_id,
    get_sgc_job_run_id,
    get_workspace_info_from_databricks_secrets,
    get_workspace_info_from_dbutils,
    get_workspace_url,
    is_databricks_default_tracking_uri,
    is_running_in_ipython_environment,
)
from mlflow.utils.os import is_windows

from tests.helper_functions import mock_method_chain
from tests.pyfunc.test_spark import spark  # noqa: F401


def test_no_throw():
    """
    Outside of Databricks the databricks_utils methods should never throw and should only return
    None.
    """
    assert not databricks_utils.is_in_databricks_notebook()
    assert not databricks_utils.is_in_databricks_repo_notebook()
    assert not databricks_utils.is_in_databricks_job()
    assert not databricks_utils.is_dbfs_fuse_available()
    assert not databricks_utils.is_in_databricks_registry_notebook() if False else True  # noqa: placeholder removed


def test_databricks_registry_profile():
    # With no profile config available, creds for "databricks://profile:prefix" URIs
    # should be read from Databricks secrets under the given scope/prefix.
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = None
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.get.return_value = "random"
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
        ),
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
    ):
        params = databricks_utils.get_databricks_host_creds("databricks://profile:prefix")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-host", scope="profile")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-token", scope="profile")
        assert params.host == "random"
        assert params.token == "random"


def test_databricks_no_creds_found():
    with pytest.raises(MlflowException, match="Reading Databricks credential configuration failed"):
        databricks_utils.get_databricks_host_creds()


def test_databricks_no_creds_found_in_model_serving(monkeypatch):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    with pytest.raises(
        MlflowException, match="Reading Databricks credential configuration in model serving failed"
    ):
        databricks_utils.get_databricks_host_creds()


def test_databricks_single_slash_in_uri_scheme_throws():
    with pytest.raises(MlflowException, match="URI is formatted incorrectly"):
        databricks_utils.get_databricks_host_creds("databricks:/profile:path")


@pytest.fixture
def oauth_file(tmp_path):
    # Writes a fake model-serving OAuth token file containing token "token2".
    token_contents = {"OAUTH_TOKEN": [{"oauthTokenValue": "token2"}]}
    oauth_file = tmp_path.joinpath("model-dependencies-oauth-token")
    with open(oauth_file, "w") as f:
        json.dump(token_contents, f)
    return oauth_file


def test_get_model_dependency_token(oauth_file):
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        token = databricks_utils.get_model_dependency_oauth_token()
        assert token == "token2"


def test_get_model_dependency_oauth_token_model_serving_throws():
    with pytest.raises(MlflowException, match="Unable to read Oauth credentials"):
        databricks_utils.get_model_dependency_oauth_token()


@pytest.mark.parametrize(
    ("model_serving_env_var"),
    [
        ("DATABRICKS_MODEL_SERVING_HOST_URL"),
        ("DB_MODEL_SERVING_HOST_URL"),
    ],
)
def test_databricks_params_model_serving_oauth_cache_databricks(
    monkeypatch, oauth_file, model_serving_env_var
):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv(model_serving_env_var, "host")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE", "token")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS", str(time.time() + 5))
    # oauth file still needs to be present for should_fetch_model_serving_environment_oauth()
    # to evaluate true
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        assert params.host == "host"
        # should use token from cache, rather than token from oauthfile
        assert params.token == "token"


def test_databricks_params_model_serving_oauth_cache_expired(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE", "token")
    # Cache expiry timestamp in the past forces a re-read of the OAuth file.
    monkeypatch.setenv("DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS", str(time.time() - 5))
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        # cache should get updated with new token
        assert os.environ["DB_DEPENDENCY_OAUTH_CACHE"] == "token2"
        assert float(os.environ["DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS"]) > time.time()
        assert params.host == "host"
        # should use token2 from oauthfile, rather than token from cache
        assert params.token == "token2"


def test_databricks_params_model_serving_read_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        assert os.environ["DB_DEPENDENCY_OAUTH_CACHE"] == "token2"
        assert float(os.environ["DB_DEPENDENCY_OAUTH_CACHE_EXPIRY_TS"]) > time.time()
        assert params.host == "host"
        assert params.token == "token2"


def test_databricks_params_env_var_overrides_model_serving_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    monkeypatch.setenv("DATABRICKS_MODEL_SERVING_HOST_URL", "host")
    monkeypatch.setenv("DATABRICKS_HOST", "host_envvar")
    monkeypatch.setenv("DATABRICKS_TOKEN", "pat_token")
    # oauth file still needs to be present for should_fetch_model_serving_environment_oauth()
    # to evaluate true
    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        params = databricks_utils.get_databricks_host_creds()
        # should use token and host from envvar, rather than token from oauthfile
        assert params.use_databricks_sdk


def test_model_serving_config_provider_errors_caught():
    # Provider failures must be swallowed and surfaced as a None config.
    provider = DatabricksModelServingConfigProvider()
    with mock.patch.object(
        provider,
        "_get_databricks_model_serving_config",
        side_effect=Exception("Failed to Read OAuth Creds"),
    ):
        assert provider.get_config() is None


def test_get_workspace_info_from_databricks_secrets():
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.get.return_value = "workspace-placeholder-info"
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_databricks_secrets(
            "databricks://profile:prefix"
        )
        mock_dbutils.secrets.get.assert_any_call(key="prefix-host", scope="profile")
        mock_dbutils.secrets.get.assert_any_call(key="prefix-workspace-id", scope="profile")
        assert workspace_host == "workspace-placeholder-info"
        assert workspace_id == "workspace-placeholder-info"


def test_get_workspace_info_from_dbutils():
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(
        mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com"
    )
    mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111")

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        # Host is normalized to include the https:// scheme.
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_no_browser_host_name():
    # When browserHostName is absent, the apiUrl is used as the workspace host.
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(mock_dbutils, methods + ["browserHostName", "get"], return_value=None)
    mock_method_chain(
        mock_dbutils, methods + ["apiUrl", "get"], return_value="https://mlflow.databricks.com"
    )
    mock_method_chain(mock_dbutils, methods + ["workspaceId", "get"], return_value="1111")
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_old_runtimes():
    mock_dbutils = mock.MagicMock()
    methods = ["notebook.entry_point.getDbutils", "notebook", "getContext"]
    mock_method_chain(
        mock_dbutils,
        methods + ["toJson", "get"],
        return_value='{"tags": {"orgId" : "1111", "browserHostName": "mlflow.databricks.com"}}',
    )
    mock_method_chain(
        mock_dbutils, methods + ["browserHostName", "get"], return_value="mlflow.databricks.com"
    )

    # Mock out workspace ID tag
    mock_workspace_id_tag_opt = mock.MagicMock()
    mock_workspace_id_tag_opt.isDefined.return_value = True
    mock_workspace_id_tag_opt.get.return_value = "1111"
    mock_method_chain(
        mock_dbutils, methods + ["tags", "get"], return_value=mock_workspace_id_tag_opt
    )

    # Mimic old runtimes by raising an exception when the nonexistent "workspaceId" method is called
    mock_method_chain(
        mock_dbutils,
        methods + ["workspaceId"],
        side_effect=Exception("workspaceId method not defined!"),
    )
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host == "https://mlflow.databricks.com"
        assert workspace_id == "1111"


def test_get_workspace_info_from_dbutils_when_no_dbutils_available():
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=None):
        workspace_host, workspace_id = get_workspace_info_from_dbutils()
        assert workspace_host is None
        assert workspace_id is None


@pytest.mark.parametrize(
    ("tracking_uri", "result"),
    [
        ("databricks", True),
        ("databricks://profile:prefix", False),
        ("databricks://profile/prefix", False),
        ("nondatabricks", False),
        ("databricks\t\r", True),
        ("databricks\n", True),
        ("databricks://", False),
        ("databricks://aAbB", False),
    ],
)
def test_is_databricks_default_tracking_uri(tracking_uri, result):
    assert is_databricks_default_tracking_uri(tracking_uri) == result


def test_databricks_params_throws_errors():
    # No hostname
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        None, "user", "pass", insecure=True
    )
    with mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
    ):
        with pytest.raises(
            Exception, match="Reading Databricks credential configuration failed with"
        ):
            databricks_utils.get_databricks_host_creds()

    # No authentication
    mock_provider = mock.MagicMock()
    mock_provider.get_config.return_value = DatabricksConfig.from_password(
        "host", None, None, insecure=True
    )
    with mock.patch(
        "mlflow.utils.databricks_utils.ProfileConfigProvider", return_value=mock_provider
    ):
        with pytest.raises(
            Exception, match="Reading Databricks credential configuration failed with"
        ):
            databricks_utils.get_databricks_host_creds()


def test_is_in_databricks_runtime(monkeypatch):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "11.x")
    assert databricks_utils.is_in_databricks_runtime()

    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION")
    assert not databricks_utils.is_in_databricks_runtime()


@pytest.mark.parametrize("val", ["true", "1"])
def test_is_in_databricks_model_serving_environment(monkeypatch, val):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", val)
    assert databricks_utils.is_in_databricks_model_serving_environment()

    monkeypatch.delenv("IS_IN_DB_MODEL_SERVING_ENV")
    assert not databricks_utils.is_in_databricks_model_serving_environment()


# test both is_in_databricks_model_serving_environment and
# should_fetch_model_serving_environment_oauth return appropriate values
def test_should_fetch_model_serving_environment_oauth(monkeypatch, oauth_file):
    monkeypatch.setenv("IS_IN_DB_MODEL_SERVING_ENV", "true")
    # will return false if file mount is not configured even if env var set
    assert not databricks_utils.should_fetch_model_serving_environment_oauth()

    with mock.patch(
        "mlflow.utils.databricks_utils._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH", str(oauth_file)
    ):
        # both file mount and env var exist, both values should return true
        assert databricks_utils.should_fetch_model_serving_environment_oauth()

        # file mount without env var should return false
        monkeypatch.delenv("IS_IN_DB_MODEL_SERVING_ENV")
        assert not databricks_utils.should_fetch_model_serving_environment_oauth()


def test_get_repl_id():
    # Outside of Databricks environments, the Databricks REPL ID should be absent
    assert databricks_utils.get_repl_id() is None

    mock_client = mock.MagicMock()
    mock_client.getReplId.return_value = "testReplId1"
    with mock.patch(
        "mlflow.utils.databricks_utils._get_runtime_integration_client",
        return_value=mock_client,
    ):
        assert databricks_utils.get_repl_id() == "testReplId1"
        mock_client.getReplId.assert_called_once()

    # When runtime_integration_client is unavailable, fall back to entry_point.
    mock_dbutils = mock.MagicMock()
    mock_dbutils.entry_point.getReplId.return_value = "testReplId1"
    with (
        mock.patch(
            "mlflow.utils.databricks_utils._get_runtime_integration_client",
            side_effect=Exception("unavailable"),
        ),
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
    ):
        assert databricks_utils.get_repl_id() == "testReplId1"
        mock_dbutils.entry_point.getReplId.assert_called_once()

    # Final fallback: read the REPL ID from the SparkContext local property by
    # intercepting the `import pyspark` performed inside get_repl_id().
    mock_sparkcontext_inst = mock.MagicMock()
    mock_sparkcontext_inst.getLocalProperty.return_value = "testReplId2"
    mock_sparkcontext_class = mock.MagicMock()
    mock_sparkcontext_class.getOrCreate.return_value = mock_sparkcontext_inst
    mock_spark = mock.MagicMock()
    mock_spark.SparkContext = mock_sparkcontext_class

    original_import = builtins.__import__

    def mock_import(name, *args, **kwargs):
        if name == "pyspark":
            return mock_spark
        else:
            return original_import(name, *args, **kwargs)

    with mock.patch("builtins.__import__", side_effect=mock_import):
        assert databricks_utils.get_repl_id() == "testReplId2"


def test_use_repl_context_if_available(tmp_path, monkeypatch):
    # Simulate a case where `dbruntime.databricks_repl_context.get_context` is unavailable.
    with pytest.raises(ModuleNotFoundError, match="No module named 'dbruntime'"):
        from dbruntime.databricks_repl_context import get_context  # noqa: F401

    command_context_mock = mock.MagicMock()
    command_context_mock.jobId().get.return_value = "job_id"
    command_context_mock.tags().get("jobType").get.return_value = "NORMAL"
    with mock.patch(
        "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
    ) as mock_get_command_context:
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_command_context.assert_called_once()

    # Create a fake databricks_repl_context module
    dbruntime = tmp_path.joinpath("dbruntime")
    dbruntime.mkdir()
    dbruntime.joinpath("databricks_repl_context.py").write_text(
        """
def get_context():
    pass
"""
    )
    monkeypatch.syspath_prepend(str(tmp_path))

    # Simulate a case where the REPL context object is not initialized.
    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=None,
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_command_context", return_value=command_context_mock
        ) as mock_get_command_context,
    ):
        assert databricks_utils.get_job_id() == "job_id"
        assert mock_get_command_context.call_count == 1

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(jobId="job_id"),
        ) as mock_get_context,
        mock.patch("mlflow.utils.databricks_utils._get_dbutils") as mock_dbutils,
    ):
        assert databricks_utils.get_job_id() == "job_id"
        mock_get_context.assert_called_once()
        mock_dbutils.assert_not_called()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(
                notebookId="notebook_id", notebookPath="/Repos/notebook_path"
            ),
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context,
    ):
        assert databricks_utils.get_notebook_id() == "notebook_id"
        assert databricks_utils.is_in_databricks_repo_notebook()
        assert mock_get_context.call_count == 2
        mock_spark_context.assert_not_called()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(
                notebookId="notebook_id", notebookPath="/Users/notebook_path"
            ),
        ) as mock_get_context,
        mock.patch(
            "mlflow.utils.databricks_utils._get_property_from_spark_context"
        ) as mock_spark_context,
    ):
        assert not databricks_utils.is_in_databricks_repo_notebook()

    with (
        mock.patch(
            "dbruntime.databricks_repl_context.get_context",
            return_value=mock.MagicMock(isInCluster=True),
        ) as mock_get_context,
        mock.patch("mlflow.utils._spark_utils._get_active_spark_session") as mock_spark_session,
    ):
        assert databricks_utils.is_in_cluster()
        mock_get_context.assert_called_once()
        mock_spark_session.assert_not_called()


@pytest.mark.parametrize("get_ipython", [True, None])
def test_is_running_in_ipython_environment_works(get_ipython):
    mod_name = "IPython"
    if mod_name in sys.modules:
        # Temporarily remove IPython to verify the negative case, then restore it.
        ipython_mod = sys.modules.pop(mod_name)
        assert not is_running_in_ipython_environment()
        sys.modules["IPython"] = ipython_mod

    with mock.patch("IPython.get_ipython", return_value=get_ipython):
        assert is_running_in_ipython_environment() == (get_ipython is not None)


def test_get_mlflow_credential_context_by_run_id():
    with (
        mock.patch(
            "mlflow.tracking.artifact_utils.get_artifact_uri", return_value="dbfs:/path/to/artifact"
        ) as mock_get_artifact_uri,
        mock.patch(
            "mlflow.utils.uri.get_databricks_profile_uri_from_artifact_uri",
            return_value="databricks://path/to/profile",
        ) as mock_get_databricks_profile,
        mock.patch(
            "mlflow.utils.databricks_utils.MlflowCredentialContext"
        ) as mock_credential_context,
    ):
        get_mlflow_credential_context_by_run_id(run_id="abc")
        mock_get_artifact_uri.assert_called_once_with(run_id="abc")
        mock_get_databricks_profile.assert_called_once_with("dbfs:/path/to/artifact")
        mock_credential_context.assert_called_once_with("databricks://path/to/profile")


def test_check_databricks_secret_scope_access():
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.list.return_value = "random"
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        check_databricks_secret_scope_access("scope")
        mock_dbutils.secrets.list.assert_called_once_with("scope")


def test_check_databricks_secret_scope_access_error():
    # An inaccessible scope should only emit a warning, never raise.
    mock_dbutils = mock.MagicMock()
    mock_dbutils.secrets.list.side_effect = Exception("no scope access")
    with (
        mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils),
        mock.patch("mlflow.utils.databricks_utils._logger.warning") as mock_warning,
    ):
        check_databricks_secret_scope_access("scope")
        mock_warning.assert_called_once_with(
            "Unable to access Databricks secret scope 'scope' for OpenAI credentials that will be "
            "used to deploy the model to Databricks Model Serving. Please verify that the current "
            "Databricks user has 'READ' permission for this scope. For more information, see "
            "https://mlflow.org/docs/latest/python_api/openai/index.html#credential-management-for-openai-on-databricks. "  # noqa: E501
            "Error: no scope access"
        )
        mock_dbutils.secrets.list.assert_called_once_with("scope")


@pytest.mark.parametrize(
    ("version_str", "is_client_image", "major", "minor"),
    [
        ("client.0", True, 0, 0),
        ("client.1", True, 1, 0),
        ("client.1.6", True, 1, 6),
        ("15.1", False, 15, 1),
        ("12.1.1", False, 12, 1),
    ],
)
def test_get_databricks_runtime_major_minor_version(
    monkeypatch, version_str, is_client_image, major, minor
):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", version_str)
    dbr_version = get_databricks_runtime_major_minor_version()

    assert dbr_version.is_client_image == is_client_image
    assert dbr_version.major == major
    assert dbr_version.minor == minor


def test_get_dbr_major_minor_version_throws_on_invalid_version_key(monkeypatch):
    # minor version is not allowed to be a string
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "12.x")
    with pytest.raises(MlflowException, match="Failed to parse databricks runtime version"):
        get_databricks_runtime_major_minor_version()


def test_prioritize_env_var_config_provider(monkeypatch):
    # DATABRICKS_HOST/TOKEN env vars must win over a dynamic token config provider.
    monkeypatch.setenv("DATABRICKS_HOST", "my_host1")
    monkeypatch.setenv("DATABRICKS_TOKEN", "token1")

    class MyProvider(DatabricksConfigProvider):
        def get_config(self):
            return DatabricksConfig(host="my_host2", token="token2")

    monkeypatch.setattr(databricks_utils, "_dynamic_token_config_provider", MyProvider)

    hc = get_databricks_host_creds("databricks")
    assert hc.host == "my_host1"
    assert hc.token == "token1"


@pytest.mark.parametrize(
    ("input_url", "expected_result"),
    [
        # Test with a valid URL without https:// prefix
        ("example.com", "https://example.com"),
        # Test with a valid URL with https:// prefix
        ("https://example.com", "https://example.com"),
        # Test with None URL
        (None, None),
    ],
)
def test_get_workspace_url(input_url, expected_result):
    with mock.patch("mlflow.utils.databricks_utils._get_workspace_url", return_value=input_url):
        result = get_workspace_url()
        assert result == expected_result


@pytest.mark.skipif(is_windows(), reason="This test doesn't work on Windows")
def test_get_dbconnect_udf_sandbox_info(spark, monkeypatch):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "client.1.2")
    # Reset the module-level cache so each scenario recomputes the sandbox info.
    databricks_utils._dbconnect_udf_sandbox_info_cache = None

    spark.udf.register(
        "current_version",
        lambda: {"dbr_version": "15.4.x-scala2.12"},
        returnType="dbr_version string",
    )

    info = get_dbconnect_udf_sandbox_info(spark)
    assert info.mlflow_version == mlflow.__version__
    assert info.image_version == "client.1.2"
    assert info.runtime_version == "15.4"
    assert info.platform_machine == platform.machine()

    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION")
    databricks_utils._dbconnect_udf_sandbox_info_cache = None

    info = get_dbconnect_udf_sandbox_info(spark)
    assert info.mlflow_version == mlflow.__version__
    assert info.image_version == "15.4"
    assert info.runtime_version == "15.4"
    assert info.platform_machine == platform.machine()


def test_construct_databricks_uc_registered_model_url():
    # Test case with workspace ID
    workspace_url = "https://databricks.com"
    registered_model_name = "name.mlflow.echo_model"
    version = "6"
    workspace_id = "123"

    expected_url = (
        "https://databricks.com/explore/data/models/name/mlflow/echo_model/version/6?o=123"
    )

    result = databricks_utils._construct_databricks_uc_registered_model_url(
        workspace_url=workspace_url,
        registered_model_name=registered_model_name,
        version=version,
        workspace_id=workspace_id,
    )

    assert result == expected_url

    # Test case without workspace ID
    expected_url_no_workspace = (
        "https://databricks.com/explore/data/models/name/mlflow/echo_model/version/6"
    )

    result_no_workspace = databricks_utils._construct_databricks_uc_registered_model_url(
        workspace_url=workspace_url,
        registered_model_name=registered_model_name,
        version=version,
    )

    assert result_no_workspace == expected_url_no_workspace


def test_construct_databricks_logged_model_url():
    # Test case with workspace ID
    workspace_url = "https://databricks.com"
    experiment_id = "123456"
    model_id = "model_789"
    workspace_id = "123"

    expected_url = "https://databricks.com/ml/experiments/123456/models/model_789?o=123"

    result = databricks_utils._construct_databricks_logged_model_url(
        workspace_url=workspace_url,
        experiment_id=experiment_id,
        model_id=model_id,
        workspace_id=workspace_id,
    )

    assert result == expected_url

    # Test case without workspace ID
    expected_url_no_workspace = "https://databricks.com/ml/experiments/123456/models/model_789"

    result_no_workspace = databricks_utils._construct_databricks_logged_model_url(
        workspace_url=workspace_url,
        experiment_id=experiment_id,
        model_id=model_id,
    )

    assert result_no_workspace == expected_url_no_workspace


def test_print_databricks_deployment_job_url():
    workspace_url = "https://databricks.com"
    job_id = "123"
    workspace_id = "456"

    expected_url_no_workspace = "https://databricks.com/jobs/123"
    expected_url = f"{expected_url_no_workspace}?o=456"
    model_name = "main.models.name"

    with (
        mock.patch("mlflow.utils.databricks_utils.eprint") as mock_eprint,
        mock.patch("mlflow.utils.databricks_utils.get_workspace_url", return_value=workspace_url),
    ):
        # Test case with a workspace ID
        with mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_id", return_value=workspace_id
        ):
            result = databricks_utils._print_databricks_deployment_job_url(
                model_name=model_name,
                job_id=job_id,
            )

            assert result == expected_url
            mock_eprint.assert_called_once_with(
                f"🔗 Linked deployment job to '{model_name}': {expected_url}"
            )
            mock_eprint.reset_mock()

        # Test case without a workspace ID
        with mock.patch("mlflow.utils.databricks_utils.get_workspace_id", return_value=None):
            result_no_workspace = databricks_utils._print_databricks_deployment_job_url(
                model_name=model_name,
                job_id=job_id,
            )

            assert result_no_workspace == expected_url_no_workspace
            mock_eprint.assert_called_once_with(
                f"🔗 Linked deployment job to '{model_name}': {expected_url_no_workspace}"
            )


@pytest.mark.parametrize(
    ("version_str", "expected_is_client", "expected_major", "expected_minor", "expected_is_gpu"),
    [
        ("client.2.0", True, 2, 0, False),
        ("client.3.1", True, 3, 1, False),
        ("13.2", False, 13, 2, False),
        ("15.4", False, 15, 4, False),
        ("client.8.1-gpu", True, 8, 1, True),
        ("client.10.0-gpu", True, 10, 0, True),
        ("14.3-gpu", False, 14, 3, True),
        ("15.1-gpu", False, 15, 1, True),
    ],
)
def test_databricks_runtime_version_parse(
    version_str,
    expected_is_client,
    expected_major,
    expected_minor,
    expected_is_gpu,
):
    version = DatabricksRuntimeVersion.parse(version_str)
    assert version.is_client_image == expected_is_client
    assert version.major == expected_major
    assert version.minor == expected_minor
    assert version.is_gpu_image == expected_is_gpu


@pytest.mark.parametrize(
    ("env_version", "expected_is_client", "expected_major", "expected_minor", "expected_is_gpu"),
    [
        ("client.2.0", True, 2, 0, False),
        ("13.2", False, 13, 2, False),
        ("client.8.1-gpu", True, 8, 1, True),
        ("14.3-gpu", False, 14, 3, True),
    ],
)
def test_databricks_runtime_version_parse_default(
    monkeypatch,
    env_version,
    expected_is_client,
    expected_major,
    expected_minor,
    expected_is_gpu,
):
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", env_version)
    version = DatabricksRuntimeVersion.parse()
    assert version.is_client_image == expected_is_client
    assert version.major == expected_major
    assert version.minor == expected_minor
    assert version.is_gpu_image == expected_is_gpu


def test_databricks_runtime_version_parse_default_no_env(monkeypatch):
    """Test that DatabricksRuntimeVersion.parse() raises error when no environment variable is
    set.
    """
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.delenv("DATABRICKS_ENV_VERSION", raising=False)
    with pytest.raises(Exception, match="Failed to parse databricks runtime version"):
        DatabricksRuntimeVersion.parse()


@pytest.mark.parametrize(
    ("env_version", "accelerator", "expected"),
    [
        ("4", "A10G", "client.4-gpu"),
        ("4", "NVIDIA H100", "client.4-gpu"),
        ("4", None, "client.4"),
    ],
)
def test_get_databricks_runtime_version_from_env_version(
    monkeypatch, env_version, accelerator, expected
):
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", env_version)
    if accelerator:
        monkeypatch.setenv("DATABRICKS_ACCELERATOR", accelerator)
    else:
        monkeypatch.delenv("DATABRICKS_ACCELERATOR", raising=False)
    assert get_databricks_runtime_version() == expected


@pytest.mark.parametrize(
    ("accelerator", "expected"),
    [
        ("A10G", "client.4-gpu"),
        (None, "client.4"),
    ],
)
def test_databricks_env_version_takes_priority_over_runtime_version(
    monkeypatch, accelerator, expected
):
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", "4")
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "client.4.1")
    if accelerator:
        monkeypatch.setenv("DATABRICKS_ACCELERATOR", accelerator)
    else:
        monkeypatch.delenv("DATABRICKS_ACCELERATOR", raising=False)
    assert get_databricks_runtime_version() == expected


def test_databricks_runtime_version_parse_from_env_version(monkeypatch):
    monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False)
    monkeypatch.setenv("DATABRICKS_ENV_VERSION", "4")
    monkeypatch.setenv("DATABRICKS_ACCELERATOR", "A10")
    version = DatabricksRuntimeVersion.parse()
    assert version.is_client_image is True
    assert version.major == 4
    assert version.minor == 0
    assert version.is_gpu_image is True


@pytest.mark.parametrize(
    "invalid_version",
    [
        "invalid",
        "client",
        "client.invalid",
        "13",
    ],
)
def test_databricks_runtime_version_parse_invalid(invalid_version):
    with pytest.raises(Exception, match="Failed to parse databricks runtime version"):
        DatabricksRuntimeVersion.parse(invalid_version)


def test_get_databricks_workspace_client_config_with_tracking_uri_provider():
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    # Mock TrackingURIConfigProvider
    mock_uri_config = mock.MagicMock()
    mock_uri_config.host = "https://test.databricks.com"
    mock_uri_config.token = "test_token"

    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri",
            return_value=("profile_name", "key_prefix"),
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
        mock.patch("mlflow.utils.databricks_utils.TrackingURIConfigProvider") as mock_provider,
    ):
        mock_provider.return_value.get_config.return_value = mock_uri_config

        result = get_databricks_workspace_client_config("databricks://profile:prefix")

        # Verify the WorkspaceClient was created with correct parameters
        mock_workspace_client.assert_called_once_with(
            host="https://test.databricks.com", token="test_token"
        )
        assert result == mock_config


def test_get_databricks_workspace_client_config_with_profile():
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri",
            return_value=("profile_name", None),
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
    ):
        result = get_databricks_workspace_client_config("databricks://profile_name")

        # Verify the WorkspaceClient was created with profile
        mock_workspace_client.assert_called_once_with(profile="profile_name")
        assert result == mock_config


def test_get_databricks_workspace_client_config_env_profile(monkeypatch):
    monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "env_profile")
    # Mock the workspace client and its config
    mock_config = mock.MagicMock()
    mock_client_instance = mock.MagicMock()
    mock_client_instance.config = mock_config

    with (
        mock.patch("mlflow.utils.databricks_utils.get_db_info_from_uri", return_value=(None, None)),
        mock.patch(
            "databricks.sdk.WorkspaceClient", return_value=mock_client_instance
        ) as mock_workspace_client,
    ):
        result = get_databricks_workspace_client_config("databricks")

        # Verify the WorkspaceClient was created with environment profile
        mock_workspace_client.assert_called_once_with(profile="env_profile")
        assert result == mock_config


def test_get_databricks_workspace_client_config_client_creation_error():
    # WorkspaceClient construction failures should propagate to the caller.
    with (
        mock.patch(
            "mlflow.utils.databricks_utils.get_db_info_from_uri", return_value=("profile", None)
        ),
        mock.patch(
            "databricks.sdk.WorkspaceClient", side_effect=Exception("Client creation failed")
        ),
    ):
        with pytest.raises(Exception, match="Client creation failed"):
            get_databricks_workspace_client_config("databricks://profile")


def test_get_sgc_job_run_id_success(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.return_value = "test_job_run_id_12345"

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result == "test_job_run_id_12345"
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_no_dbutils(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", side_effect=_NoDbutilsError()):
        result = get_sgc_job_run_id()
        assert result is None


def test_get_sgc_job_run_id_no_dbutils_with_env_var(monkeypatch):
    # When dbutils is unavailable, the environment variable is the fallback source.
    monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_456")
    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", side_effect=_NoDbutilsError()):
        result = get_sgc_job_run_id()
        assert result == "env_job_run_id_456"


def test_get_sgc_job_run_id_value_error(monkeypatch):
    monkeypatch.delenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", raising=False)
    mock_dbutils = mock.MagicMock()
    mock_dbutils.widgets.get.side_effect = ValueError("Widget not found")

    with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils):
        result = get_sgc_job_run_id()
        assert result is None
        mock_dbutils.widgets.get.assert_called_once_with(
            "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID"
        )


def test_get_sgc_job_run_id_value_error_with_env_var(monkeypatch):
990 monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_789") 991 mock_dbutils = mock.MagicMock() 992 mock_dbutils.widgets.get.side_effect = ValueError("Widget not found") 993 994 with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): 995 result = get_sgc_job_run_id() 996 assert result == "env_job_run_id_789" 997 mock_dbutils.widgets.get.assert_called_once_with( 998 "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID" 999 ) 1000 1001 1002 def test_get_sgc_job_run_id_empty_widget_with_env_var(monkeypatch): 1003 monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_999") 1004 mock_dbutils = mock.MagicMock() 1005 mock_dbutils.widgets.get.return_value = "" 1006 1007 with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): 1008 result = get_sgc_job_run_id() 1009 assert result == "env_job_run_id_999" 1010 mock_dbutils.widgets.get.assert_called_once_with( 1011 "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID" 1012 ) 1013 1014 1015 def test_get_sgc_job_run_id_none_widget_with_env_var(monkeypatch): 1016 monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_111") 1017 mock_dbutils = mock.MagicMock() 1018 mock_dbutils.widgets.get.return_value = None 1019 1020 with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): 1021 result = get_sgc_job_run_id() 1022 assert result == "env_job_run_id_111" 1023 mock_dbutils.widgets.get.assert_called_once_with( 1024 "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID" 1025 ) 1026 1027 1028 def test_get_sgc_job_run_id_widget_takes_precedence_over_env_var(monkeypatch): 1029 monkeypatch.setenv("SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID", "env_job_run_id_222") 1030 mock_dbutils = mock.MagicMock() 1031 mock_dbutils.widgets.get.return_value = "widget_job_run_id_333" 1032 1033 with mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils): 1034 
result = get_sgc_job_run_id() 1035 assert result == "widget_job_run_id_333" 1036 mock_dbutils.widgets.get.assert_called_once_with( 1037 "SERVERLESS_GPU_COMPUTE_ASSOCIATED_JOB_RUN_ID" 1038 ) 1039 1040 1041 def test_databricks_config_profile_env_var_is_respected(tmp_path, monkeypatch): 1042 file_path = tmp_path / ".databrickscfg" 1043 monkeypatch.setenv("MLFLOW_TRACKING_URI", "databricks") 1044 monkeypatch.setenv("DATABRICKS_CONFIG_FILE", str(file_path)) 1045 monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "test") 1046 1047 file_path.write_text("""[DEFAULT] 1048 host = http://default-workspace.databricks.com 1049 token = default-token 1050 1051 [test] 1052 host = https://test-workspace.databricks.com 1053 token = test-token 1054 """) 1055 1056 # the resulting config should be the one from the [test] section 1057 result = get_databricks_host_creds("databricks") 1058 assert result.host == "https://test-workspace.databricks.com" 1059 assert result.token == "test-token" 1060 1061 1062 def test_get_databricks_nfs_temp_dir(): 1063 mock_dbutils = mock.MagicMock() 1064 mock_client = mock.MagicMock() 1065 mock_client.getUserNFSTempDir.return_value = "/nfs/user/grpc" 1066 1067 # When runtime_integration_client is available, use getUserNFSTempDir from client 1068 with ( 1069 mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils), 1070 mock.patch( 1071 "mlflow.utils.databricks_utils._get_runtime_integration_client", 1072 return_value=mock_client, 1073 ), 1074 ): 1075 assert databricks_utils.get_databricks_nfs_temp_dir() == "/nfs/user/grpc" 1076 mock_client.getUserNFSTempDir.assert_called_once() 1077 1078 # When runtime_integration_client raises, fall back to entry_point.getUserNFSTempDir 1079 mock_dbutils2 = mock.MagicMock() 1080 mock_dbutils2.entry_point.getUserNFSTempDir.return_value = "/nfs/user" 1081 with ( 1082 mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils2), 1083 mock.patch( 1084 
"mlflow.utils.databricks_utils._get_runtime_integration_client", 1085 side_effect=Exception("unavailable"), 1086 ), 1087 ): 1088 assert databricks_utils.get_databricks_nfs_temp_dir() == "/nfs/user" 1089 mock_dbutils2.entry_point.getUserNFSTempDir.assert_called_once() 1090 1091 1092 def test_get_databricks_local_temp_dir(): 1093 mock_dbutils = mock.MagicMock() 1094 mock_client = mock.MagicMock() 1095 mock_client.getUserLocalTempDir.return_value = "/local/user/grpc" 1096 1097 # When runtime_integration_client is available, use getUserLocalTempDir from client 1098 with ( 1099 mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils), 1100 mock.patch( 1101 "mlflow.utils.databricks_utils._get_runtime_integration_client", 1102 return_value=mock_client, 1103 ), 1104 ): 1105 assert databricks_utils.get_databricks_local_temp_dir() == "/local/user/grpc" 1106 mock_client.getUserLocalTempDir.assert_called_once() 1107 1108 # When runtime_integration_client raises, fall back to entry_point.getUserLocalTempDir 1109 mock_dbutils2 = mock.MagicMock() 1110 mock_dbutils2.entry_point.getUserLocalTempDir.return_value = "/local/user" 1111 with ( 1112 mock.patch("mlflow.utils.databricks_utils._get_dbutils", return_value=mock_dbutils2), 1113 mock.patch( 1114 "mlflow.utils.databricks_utils._get_runtime_integration_client", 1115 side_effect=Exception("unavailable"), 1116 ), 1117 ): 1118 assert databricks_utils.get_databricks_local_temp_dir() == "/local/user" 1119 mock_dbutils2.entry_point.getUserLocalTempDir.assert_called_once() 1120 1121 1122 def test_get_databricks_host_creds_propagates_workspace_id(monkeypatch): 1123 monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true") 1124 monkeypatch.setenv("DATABRICKS_HOST", "https://spog.databricks.com") 1125 monkeypatch.setenv("DATABRICKS_TOKEN", "test-token") 1126 1127 mock_config = mock.MagicMock() 1128 mock_config.workspace_id = "6051921418418893" 1129 1130 mock_ws = mock.MagicMock() 1131 mock_ws.config = mock_config 1132 
1133 with mock.patch("databricks.sdk.WorkspaceClient", return_value=mock_ws) as mock_ws_cls: 1134 result = get_databricks_host_creds("databricks") 1135 mock_ws_cls.assert_called_once_with(profile=None) 1136 assert result.workspace_id == "6051921418418893" 1137 assert result.use_databricks_sdk 1138 1139 1140 def test_get_databricks_host_creds_workspace_id_none_when_not_set(monkeypatch): 1141 monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true") 1142 monkeypatch.setenv("DATABRICKS_HOST", "https://workspace.databricks.com") 1143 monkeypatch.setenv("DATABRICKS_TOKEN", "test-token") 1144 1145 mock_config = mock.MagicMock() 1146 mock_config.workspace_id = None 1147 1148 mock_ws = mock.MagicMock() 1149 mock_ws.config = mock_config 1150 1151 with mock.patch("databricks.sdk.WorkspaceClient", return_value=mock_ws): 1152 result = get_databricks_host_creds("databricks") 1153 assert result.workspace_id is None 1154 1155 1156 def test_get_databricks_host_creds_workspace_id_from_config_on_sdk_failure(monkeypatch): 1157 monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true") 1158 monkeypatch.setenv("DATABRICKS_HOST", "https://spog.databricks.com") 1159 monkeypatch.setenv("DATABRICKS_TOKEN", "test-token") 1160 1161 mock_config = mock.MagicMock() 1162 mock_config.workspace_id = "6051921418418893" 1163 1164 with ( 1165 mock.patch( 1166 "databricks.sdk.WorkspaceClient", 1167 side_effect=Exception("SDK auth failed"), 1168 ), 1169 mock.patch( 1170 "databricks.sdk.config.Config", 1171 return_value=mock_config, 1172 ), 1173 ): 1174 result = get_databricks_host_creds("databricks") 1175 assert result.workspace_id == "6051921418418893" 1176 assert not result.use_databricks_sdk 1177 1178 1179 def test_get_databricks_host_creds_workspace_id_none_on_full_failure(monkeypatch): 1180 monkeypatch.setenv("MLFLOW_ENABLE_DB_SDK", "true") 1181 monkeypatch.setenv("DATABRICKS_HOST", "https://workspace.databricks.com") 1182 monkeypatch.setenv("DATABRICKS_TOKEN", "test-token") 1183 1184 with ( 1185 mock.patch( 
1186 "databricks.sdk.WorkspaceClient", 1187 side_effect=Exception("SDK auth failed"), 1188 ), 1189 mock.patch( 1190 "databricks.sdk.config.Config", 1191 side_effect=Exception("Config failed"), 1192 ), 1193 ): 1194 result = get_databricks_host_creds("databricks") 1195 assert result.workspace_id is None 1196 assert not result.use_databricks_sdk