auth.py
# Demo of MLflow basic-auth experiment permissions.
#
# Assumes an MLflow tracking server with the "basic-auth" app enabled is
# reachable at `tracking_uri` and that the accounts user_a / user_b already
# exist there -- this script does not create them.
import os
import uuid

import mlflow.server
from mlflow.exceptions import MlflowException


class User:
    """Context manager that makes MLflow calls authenticate as one user.

    On entry it sets the ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD`` environment variables to this user's
    credentials. On exit it restores whatever values those variables held
    before entry, or removes them if they were unset.
    """

    MLFLOW_TRACKING_USERNAME = "MLFLOW_TRACKING_USERNAME"
    MLFLOW_TRACKING_PASSWORD = "MLFLOW_TRACKING_PASSWORD"

    def __init__(self, username: str, password: str) -> None:
        self.username = username
        self.password = password
        # Pre-entry values of the credential env vars, keyed by variable name.
        self.env: dict[str, str] = {}

    def _record_env_var(self, key: str) -> None:
        """Remember the current value of env var ``key`` so it can be restored."""
        # Keep ``key`` as the variable *name*: _restore_env_var looks the
        # saved value up by name. An empty string is a real value and is kept.
        value = os.environ.get(key)
        if value is not None:
            self.env[key] = value

    def _restore_env_var(self, key: str) -> None:
        """Reset env var ``key`` to its recorded value, or unset it if none."""
        value = self.env.get(key)
        if value is not None:
            os.environ[key] = value
        else:
            # pop() rather than del: the variable may already have been
            # removed inside the ``with`` body, which must not fail on exit.
            os.environ.pop(key, None)

    def __enter__(self):
        self._record_env_var(User.MLFLOW_TRACKING_USERNAME)
        self._record_env_var(User.MLFLOW_TRACKING_PASSWORD)
        os.environ[User.MLFLOW_TRACKING_USERNAME] = self.username
        os.environ[User.MLFLOW_TRACKING_PASSWORD] = self.password
        return self

    def __exit__(self, *_exc):
        # Returns None, so any exception raised in the body propagates.
        self._restore_env_var(User.MLFLOW_TRACKING_USERNAME)
        self._restore_env_var(User.MLFLOW_TRACKING_PASSWORD)
        self.env.clear()


tracking_uri = "http://localhost:5000"
mlflow.set_tracking_uri(tracking_uri)
# Client for the basic-auth app's permission-management endpoints.
client = mlflow.server.get_app_client("basic-auth", tracking_uri)
A = User("user_a", "password_a")
B = User("user_b", "password_b")

# A creates a fresh experiment (random name) and logs a metric to it.
with A:
    exp_a = mlflow.set_experiment(uuid.uuid4().hex)
    with mlflow.start_run():
        mlflow.log_metric("a", 1)

# B has no EDIT permission on A's experiment yet, so the server is expected
# to reject creating a run there; print the error and carry on.
with B:
    mlflow.set_experiment(exp_a.name)
    try:
        with mlflow.start_run():  # not allowed
            mlflow.log_metric("b", 2)
    except MlflowException as e:
        print(str(e))

# Grant B permission to edit A's experiment
with A:
    client.create_experiment_permission(str(exp_a.experiment_id), B.username, "EDIT")

# B can edit now, should be able to log a metric
with B:
    mlflow.set_experiment(exp_a.name)
    with mlflow.start_run():
        mlflow.log_metric("b", 2)