registry.py
import logging
import warnings

from mlflow.tracking import get_tracking_uri
from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
from mlflow.tracking.default_experiment.databricks_notebook_experiment_provider import (
    DatabricksNotebookExperimentProvider,
)
from mlflow.utils.plugins import get_entry_points
from mlflow.utils.uri import is_databricks_uri

_logger = logging.getLogger(__name__)
# Listed below are the providers used to supply MLflow Experiment IDs based on the current
# context in which the MLflow client is running, when the user has not explicitly set an
# experiment. The order below is the order in which these providers are registered, and
# therefore the order in which `get_experiment_id` queries them.
_EXPERIMENT_PROVIDERS = (DatabricksNotebookExperimentProvider,)


class DefaultExperimentProviderRegistry:
    """Registry for default experiment provider implementations.

    This class allows the registration of default experiment providers, which are used to provide
    MLflow Experiment IDs based on the current context where the MLflow client is running when
    the user has not explicitly set an experiment. Implementations declared through the
    `mlflow.default_experiment_provider` entrypoint group can be automatically registered through
    the `register_entrypoints` method.
    """

    def __init__(self):
        # Provider *instances* (not classes), kept in registration order.
        self._registry = []

    def register(self, default_experiment_provider_cls):
        """Instantiate a provider class and add the instance to the registry.

        Args:
            default_experiment_provider_cls: A provider class. It is called with no arguments,
                and the resulting instance is appended after any previously registered providers.
        """
        self._registry.append(default_experiment_provider_cls())

    def register_entrypoints(self):
        """Register default experiment providers published by other packages.

        Providers are discovered through the `mlflow.default_experiment_provider` entrypoint
        group. If loading or instantiating a provider raises `AttributeError` or `ImportError`,
        a warning is emitted and the remaining entrypoints are still processed.
        """
        for entrypoint in get_entry_points("mlflow.default_experiment_provider"):
            try:
                # `register` instantiates the loaded class, so constructor errors of the two
                # caught types are also handled here.
                self.register(entrypoint.load())
            except (AttributeError, ImportError) as exc:
                warnings.warn(
                    "Failure attempting to register default experiment "
                    f'context provider "{entrypoint.name}": {exc}',
                    stacklevel=2,
                )

    def __iter__(self):
        """Iterate over the registered provider instances in registration order."""
        return iter(self._registry)


_default_experiment_provider_registry = DefaultExperimentProviderRegistry()
for exp_provider in _EXPERIMENT_PROVIDERS:
    _default_experiment_provider_registry.register(exp_provider)

# Plugin providers are registered after the built-in ones, so built-ins are queried first.
_default_experiment_provider_registry.register_entrypoints()


def get_experiment_id() -> str | None:
    """Get an experiment ID for the current context.

    Providers are queried in the order in which they were registered: the built-in providers in
    `_EXPERIMENT_PROVIDERS` first, then those discovered through entrypoints. The first provider
    whose `in_context()` is truthy supplies the result of its `get_experiment_id()`. Exceptions
    raised by a provider are logged as warnings, and the next provider is tried.

    Returns:
        The experiment ID from the first in-context provider. If no provider is in context (or
        every in-context provider raised), `None` is returned when the tracking URI is a
        Databricks URI; otherwise `DEFAULT_EXPERIMENT_ID` is returned.
    """
    for provider in _default_experiment_provider_registry:
        try:
            if provider.in_context():
                return provider.get_experiment_id()
        except Exception as e:
            # Best effort: one misbehaving provider must not stop the others from being tried.
            _logger.warning("Encountered unexpected error while getting experiment_id: %s", e)

    if is_databricks_uri(get_tracking_uri()):
        return None
    return DEFAULT_EXPERIMENT_ID