# test_oss_registry_utils.py
from unittest import mock

import pytest

from mlflow.exceptions import MlflowException
from mlflow.utils.oss_registry_utils import get_oss_host_creds
from mlflow.utils.rest_utils import MlflowHostCreds


@pytest.mark.parametrize(
    ("server_uri", "expected_creds"),
    [
        ("uc:databricks-uc", MlflowHostCreds(host="databricks-uc")),
        ("uc:http://localhost:8081", MlflowHostCreds(host="http://localhost:8081")),
        ("invalid_scheme:http://localhost:8081", MlflowException),
        ("databricks-uc", MlflowException),
    ],
)
def test_get_oss_host_creds(server_uri, expected_creds):
    """Check that ``get_oss_host_creds`` resolves ``uc:`` URIs and rejects other forms.

    ``expected_creds`` is either the ``MlflowHostCreds`` the call should return, or the
    ``MlflowException`` class itself, used as a sentinel meaning the call should raise
    with the "scheme should be 'uc'" message.

    ``get_databricks_host_creds`` is patched to return a fixed ``MlflowHostCreds``, so
    the ``uc:databricks-uc`` case compares against a known value rather than whatever
    the real helper would resolve.
    """
    with mock.patch(
        "mlflow.utils.oss_registry_utils.get_databricks_host_creds",
        return_value=MlflowHostCreds(host="databricks-uc"),
    ):
        # Identity check against the sentinel class. `is` avoids routing the
        # comparison through MlflowHostCreds.__eq__ with a class object as the
        # other operand.
        if expected_creds is MlflowException:
            with pytest.raises(
                MlflowException, match="The scheme of the server_uri should be 'uc'"
            ):
                get_oss_host_creds(server_uri)
        else:
            actual_creds = get_oss_host_creds(server_uri)
            assert actual_creds == expected_creds


def test_get_databricks_host_creds():
    """Check that a ``uc:databricks-uc`` URI is delegated to ``get_databricks_host_creds``.

    After the ``uc:`` prefix, the remainder (``"databricks-uc"``) must be passed as the
    single positional argument, and the helper must be called exactly once.
    """
    server_uri = "uc:databricks-uc"
    with mock.patch(
        "mlflow.utils.oss_registry_utils.get_databricks_host_creds"
    ) as mock_get_databricks_host_creds:
        get_oss_host_creds(server_uri)
    assert mock_get_databricks_host_creds.call_args_list == [mock.call("databricks-uc")]