/ tests / utils / test_oss_registry_utils.py
test_oss_registry_utils.py
 1  from unittest import mock
 2  
 3  import pytest
 4  
 5  from mlflow.exceptions import MlflowException
 6  from mlflow.utils.oss_registry_utils import get_oss_host_creds
 7  from mlflow.utils.rest_utils import MlflowHostCreds
 8  
 9  
@pytest.mark.parametrize(
    ("server_uri", "expected_creds"),
    [
        ("uc:databricks-uc", MlflowHostCreds(host="databricks-uc")),
        ("uc:http://localhost:8081", MlflowHostCreds(host="http://localhost:8081")),
        ("invalid_scheme:http://localhost:8081", MlflowException),
        ("databricks-uc", MlflowException),
    ],
)
def test_get_oss_host_creds(server_uri, expected_creds):
    """Check ``get_oss_host_creds`` for valid and invalid server URIs.

    Each case supplies either the ``MlflowHostCreds`` expected back from
    ``get_oss_host_creds`` or the ``MlflowException`` class itself, meaning the
    call must raise with the "scheme should be 'uc'" message.

    ``get_databricks_host_creds`` is patched to return a fixed
    ``MlflowHostCreds(host="databricks-uc")`` so the ``uc:databricks-uc`` case
    does not depend on any real Databricks configuration.
    """
    with mock.patch(
        "mlflow.utils.oss_registry_utils.get_databricks_host_creds",
        return_value=MlflowHostCreds(host="databricks-uc"),
    ):
        # The exception class is used as a sentinel value in the parameter
        # table, so compare by identity rather than relying on the __eq__ of
        # MlflowHostCreds accepting a class operand.
        if expected_creds is MlflowException:
            with pytest.raises(
                MlflowException, match="The scheme of the server_uri should be 'uc'"
            ):
                get_oss_host_creds(server_uri)
        else:
            assert get_oss_host_creds(server_uri) == expected_creds
32  
33  
def test_get_databricks_host_creds():
    """Resolving ``uc:databricks-uc`` must delegate to ``get_databricks_host_creds``.

    The helper is patched out, and the test verifies it was invoked exactly
    once, with the single positional argument ``"databricks-uc"``.
    """
    patch_target = "mlflow.utils.oss_registry_utils.get_databricks_host_creds"
    with mock.patch(patch_target) as fake_creds_lookup:
        get_oss_host_creds("uc:databricks-uc")
        fake_creds_lookup.assert_called_once_with("databricks-uc")