# test_credentials.py
from unittest import mock
from unittest.mock import patch

import pytest

from mlflow import get_tracking_uri
from mlflow.environment_variables import MLFLOW_TRACKING_PASSWORD, MLFLOW_TRACKING_USERNAME
from mlflow.exceptions import MlflowException
from mlflow.utils.credentials import login, read_mlflow_creds


def test_read_mlflow_creds_file(tmp_path, monkeypatch):
    """Check `read_mlflow_creds` against a credentials file in several states.

    The environment variables are removed first so that only the (patched)
    credentials file can supply values. The states covered are: file missing,
    file empty, password missing, username missing, and both values present.
    """
    monkeypatch.delenv(MLFLOW_TRACKING_USERNAME.name, raising=False)
    monkeypatch.delenv(MLFLOW_TRACKING_PASSWORD.name, raising=False)

    creds_file = tmp_path.joinpath("credentials")
    with mock.patch("mlflow.utils.credentials._get_credentials_path", return_value=str(creds_file)):
        # credentials file does not exist
        creds = read_mlflow_creds()
        assert creds.username is None
        assert creds.password is None

        # credentials file is empty
        # Actually create the file with no content; without this write the case
        # would just repeat the "does not exist" scenario above.
        creds_file.write_text("")
        creds = read_mlflow_creds()
        assert creds.username is None
        assert creds.password is None

        # password is missing
        creds_file.write_text(
            """
[mlflow]
mlflow_tracking_username = username
"""
        )
        creds = read_mlflow_creds()
        assert creds.username == "username"
        assert creds.password is None

        # username is missing
        creds_file.write_text(
            """
[mlflow]
mlflow_tracking_password = password
"""
        )
        creds = read_mlflow_creds()
        assert creds.username is None
        assert creds.password == "password"

        # valid credentials
        creds_file.write_text(
            """
[mlflow]
mlflow_tracking_username = username
mlflow_tracking_password = password
"""
        )
        creds = read_mlflow_creds()
        assert creds is not None
        assert creds.username == "username"
        assert creds.password == "password"


@pytest.mark.parametrize(
    ("username", "password"),
    [
        ("username", "password"),
        ("username", None),
        (None, "password"),
        (None, None),
    ],
)
def test_read_mlflow_creds_env(username, password, monkeypatch):
    """Check that `read_mlflow_creds` reflects the tracking environment variables.

    A parameter value of `None` means the corresponding variable is unset;
    any other value is exported before the credentials are read.
    """
    if username is None:
        monkeypatch.delenv(MLFLOW_TRACKING_USERNAME.name, raising=False)
    else:
        monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, username)

    if password is None:
        monkeypatch.delenv(MLFLOW_TRACKING_PASSWORD.name, raising=False)
    else:
        monkeypatch.setenv(MLFLOW_TRACKING_PASSWORD.name, password)

    creds = read_mlflow_creds()
    assert creds.username == username
    assert creds.password == password


def test_read_mlflow_creds_env_takes_precedence_over_file(tmp_path, monkeypatch):
    """Check that environment variables win when the file also has credentials."""
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "username_env")
    monkeypatch.setenv(MLFLOW_TRACKING_PASSWORD.name, "password_env")
    creds_file = tmp_path.joinpath("credentials")
    with mock.patch("mlflow.utils.credentials._get_credentials_path", return_value=str(creds_file)):
        creds_file.write_text(
            """
[mlflow]
mlflow_tracking_username = username_file
mlflow_tracking_password = password_file
"""
        )
        creds = read_mlflow_creds()
        assert creds.username == "username_env"
        assert creds.password == "password_env"


def test_mlflow_login(tmp_path, monkeypatch):
    """Check that an interactive Databricks login writes the profile and sets the URI.

    The Databricks config file and profile are redirected to a temporary
    location through environment variables, so nothing outside `tmp_path`
    is written.
    """
    # Mock `input()` and `getpass()` to return host, username and password in order.
    with (
        patch(
            "builtins.input",
            side_effect=["https://community.cloud.databricks.com/", "dummyusername"],
        ),
        patch("getpass.getpass", side_effect=["dummypassword"]),
    ):
        file_name = f"{tmp_path}/.databrickscfg"
        profile = "TEST"
        monkeypatch.setenv("DATABRICKS_CONFIG_FILE", file_name)
        monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", profile)

        # The first validation call raises; the second returns None (success).
        # Presumably the initial failure is what makes `login` prompt for
        # credentials -- confirm against `mlflow.utils.credentials.login`.
        with patch(
            "mlflow.utils.credentials._validate_databricks_auth",
            side_effect=[MlflowException("Invalid databricks credentials."), None],
        ):
            login("databricks")

    with open(file_name) as f:
        lines = f.readlines()
        assert lines[0] == "[TEST]\n"
        assert lines[1] == "host = https://community.cloud.databricks.com/\n"
        assert lines[2] == "username = dummyusername\n"
        assert lines[3] == "password = dummypassword\n"

    # Assert that the tracking URI is set to the databricks.
    assert get_tracking_uri() == "databricks"


def test_mlflow_login_noninteractive():
    """Check that a non-interactive login raises when validation fails."""
    # Forces mlflow.utils.credentials._validate_databricks_auth to raise `MlflowException()`
    with (
        patch(
            "mlflow.utils.credentials._validate_databricks_auth",
            side_effect=MlflowException("Failed to validate databricks credentials."),
        ),
        pytest.raises(
            MlflowException,
            match="No valid Databricks credentials found while running in non-interactive mode",
        ),
    ):
        login(backend="databricks", interactive=False)