/ tests / utils / test_credentials.py
test_credentials.py
  1  from unittest import mock
  2  from unittest.mock import patch
  3  
  4  import pytest
  5  
  6  from mlflow import get_tracking_uri
  7  from mlflow.environment_variables import MLFLOW_TRACKING_PASSWORD, MLFLOW_TRACKING_USERNAME
  8  from mlflow.exceptions import MlflowException
  9  from mlflow.utils.credentials import login, read_mlflow_creds
 10  
 11  
 12  def test_read_mlflow_creds_file(tmp_path, monkeypatch):
 13      monkeypatch.delenv(MLFLOW_TRACKING_USERNAME.name, raising=False)
 14      monkeypatch.delenv(MLFLOW_TRACKING_PASSWORD.name, raising=False)
 15  
 16      creds_file = tmp_path.joinpath("credentials")
 17      with mock.patch("mlflow.utils.credentials._get_credentials_path", return_value=str(creds_file)):
 18          # credentials file does not exist
 19          creds = read_mlflow_creds()
 20          assert creds.username is None
 21          assert creds.password is None
 22  
 23          # credentials file is empty
 24          creds = read_mlflow_creds()
 25          assert creds.username is None
 26          assert creds.password is None
 27  
 28          # password is missing
 29          creds_file.write_text(
 30              """
 31  [mlflow]
 32  mlflow_tracking_username = username
 33  """
 34          )
 35          creds = read_mlflow_creds()
 36          assert creds.username == "username"
 37          assert creds.password is None
 38  
 39          # username is missing
 40          creds_file.write_text(
 41              """
 42  [mlflow]
 43  mlflow_tracking_password = password
 44  """
 45          )
 46          creds = read_mlflow_creds()
 47          assert creds.username is None
 48          assert creds.password == "password"
 49  
 50          # valid credentials
 51          creds_file.write_text(
 52              """
 53  [mlflow]
 54  mlflow_tracking_username = username
 55  mlflow_tracking_password = password
 56  """
 57          )
 58          creds = read_mlflow_creds()
 59          assert creds is not None
 60          assert creds.username == "username"
 61          assert creds.password == "password"
 62  
 63  
 64  @pytest.mark.parametrize(
 65      ("username", "password"),
 66      [
 67          ("username", "password"),
 68          ("username", None),
 69          (None, "password"),
 70          (None, None),
 71      ],
 72  )
 73  def test_read_mlflow_creds_env(username, password, monkeypatch):
 74      if username is None:
 75          monkeypatch.delenv(MLFLOW_TRACKING_USERNAME.name, raising=False)
 76      else:
 77          monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, username)
 78  
 79      if password is None:
 80          monkeypatch.delenv(MLFLOW_TRACKING_PASSWORD.name, raising=False)
 81      else:
 82          monkeypatch.setenv(MLFLOW_TRACKING_PASSWORD.name, password)
 83  
 84      creds = read_mlflow_creds()
 85      assert creds.username == username
 86      assert creds.password == password
 87  
 88  
 89  def test_read_mlflow_creds_env_takes_precedence_over_file(tmp_path, monkeypatch):
 90      monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "username_env")
 91      monkeypatch.setenv(MLFLOW_TRACKING_PASSWORD.name, "password_env")
 92      creds_file = tmp_path.joinpath("credentials")
 93      with mock.patch("mlflow.utils.credentials._get_credentials_path", return_value=str(creds_file)):
 94          creds_file.write_text(
 95              """
 96  [mlflow]
 97  mlflow_tracking_username = username_file
 98  mlflow_tracking_password = password_file
 99  """
100          )
101          creds = read_mlflow_creds()
102          assert creds.username == "username_env"
103          assert creds.password == "password_env"
104  
105  
def test_mlflow_login(tmp_path, monkeypatch):
    """Interactive ``login("databricks")`` writes a Databricks profile and sets the tracking URI.

    The first call to ``_validate_databricks_auth`` raises and the second one returns
    normally. In between, the mocked ``input()`` supplies the host and then the username,
    and the mocked ``getpass()`` supplies the password. The resulting profile is expected in
    the config file named by ``DATABRICKS_CONFIG_FILE`` under the ``TEST`` profile.
    """
    config_path = tmp_path / ".databrickscfg"
    monkeypatch.setenv("DATABRICKS_CONFIG_FILE", str(config_path))
    monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "TEST")

    with (
        # Answers returned, in order, to `input()` (host, username) and `getpass()` (password).
        patch(
            "builtins.input",
            side_effect=["https://community.cloud.databricks.com/", "dummyusername"],
        ),
        patch("getpass.getpass", side_effect=["dummypassword"]),
        # First validation raises; the second call yields `None`, i.e. returns normally.
        patch(
            "mlflow.utils.credentials._validate_databricks_auth",
            side_effect=[MlflowException("Invalid databricks credentials."), None],
        ),
    ):
        login("databricks")

    with config_path.open() as f:
        lines = f.readlines()
    assert lines[:4] == [
        "[TEST]\n",
        "host = https://community.cloud.databricks.com/\n",
        "username = dummyusername\n",
        "password = dummypassword\n",
    ]

    # Assert that the tracking URI is set to the databricks.
    # NOTE(review): this mutates the process-wide tracking URI; presumably a conftest
    # fixture restores it between tests -- confirm.
    assert get_tracking_uri() == "databricks"
138  
139  
def test_mlflow_login_noninteractive():
    """With ``interactive=False``, a failed credential validation surfaces as an error.

    ``_validate_databricks_auth`` is forced to raise, and ``login`` is expected to raise an
    ``MlflowException`` stating that no valid Databricks credentials were found.
    """
    validation_error = MlflowException("Failed to validate databricks credentials.")
    with (
        patch("mlflow.utils.credentials._validate_databricks_auth", side_effect=validation_error),
        pytest.raises(
            MlflowException,
            match="No valid Databricks credentials found while running in non-interactive mode",
        ),
    ):
        login(backend="databricks", interactive=False)