utils.py
1 import importlib 2 from functools import partial 3 4 from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES, MLFLOW_REGISTRY_URI 5 from mlflow.store.db.db_types import DATABASE_ENGINES 6 from mlflow.store.model_registry.databricks_workspace_model_registry_rest_store import ( 7 DatabricksWorkspaceModelRegistryRestStore, 8 ) 9 from mlflow.store.model_registry.file_store import FileStore 10 from mlflow.store.model_registry.rest_store import RestStore 11 from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry 12 from mlflow.tracking._tracking_service.utils import ( 13 _resolve_tracking_uri, 14 ) 15 from mlflow.utils._spark_utils import _get_active_spark_session 16 from mlflow.utils.credentials import get_default_host_creds 17 from mlflow.utils.databricks_utils import ( 18 is_in_databricks_serverless_runtime, 19 warn_on_deprecated_cross_workspace_registry_uri, 20 ) 21 from mlflow.utils.uri import ( 22 _DATABRICKS_UNITY_CATALOG_SCHEME, 23 _OSS_UNITY_CATALOG_SCHEME, 24 construct_db_uc_uri_from_profile, 25 get_db_info_from_uri, 26 is_databricks_uri, 27 ) 28 29 # NOTE: in contrast to tracking, we do not support the following ways to specify 30 # the model registry URI: 31 # - via environment variables like MLFLOW_TRACKING_URI, MLFLOW_TRACKING_USERNAME, ... 32 # We do support specifying it 33 # - via the ``model_registry_uri`` parameter when creating an ``MlflowClient`` or 34 # ``ModelRegistryClient``. 35 # - via a utility method ``mlflow.set_registry_uri`` 36 # - by not specifying anything: in this case we assume the model registry store URI is 37 # the same as the tracking store URI. This means Tracking and Model Registry are 38 # backed by the same backend DB/Rest server. However, note that we access them via 39 # different ``Store`` classes (e.g. ``mlflow.store.tracking.SQLAlchemyStore`` & 40 # ``mlflow.store.model_registry.SQLAlchemyStore``). 
# This means the following combinations are not supported:
# - Tracking RestStore & Model Registry RestStore that use different credentials.

# Registry URI set in-process via ``set_registry_uri``; ``None`` means "not set".
_registry_uri = None


def set_registry_uri(uri: str) -> None:
    """Set the registry server URI. This method is especially useful if you have a registry server
    that's different from the tracking server.

    A non-empty ``uri`` is also exported to the ``MLFLOW_REGISTRY_URI`` environment variable so
    that subprocesses inherit it.

    Args:
        uri: An empty string, in which case the registry URI falls back to the default derived
            from the tracking URI. A local file path, prefixed with ``file:/``; data is stored
            locally at the provided file. An HTTP URI like ``https://my-tracking-server:5000``
            or ``http://my-oss-uc-server:8080``. A Databricks workspace, provided as the string
            "databricks" or, to use a Databricks CLI
            `profile <https://github.com/databricks/databricks-cli#installation>`_,
            "databricks://<profileName>".

    .. code-block:: python
        :caption: Example

        import mlflow

        # Set model registry uri, fetch the set uri, and compare
        # it with the tracking uri. They should be different
        mlflow.set_registry_uri("sqlite:////tmp/registry.db")
        mr_uri = mlflow.get_registry_uri()
        print(f"Current registry uri: {mr_uri}")
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

        # They should be different
        assert tracking_uri != mr_uri

    .. code-block:: text
        :caption: Output

        Current registry uri: sqlite:////tmp/registry.db
        Current tracking uri: file:///.../mlruns

    """
    global _registry_uri
    _registry_uri = uri
    if uri:
        # Set 'MLFLOW_REGISTRY_URI' environment variable
        # so that subprocess can inherit it.
        MLFLOW_REGISTRY_URI.set(_registry_uri)


def _get_registry_uri_from_spark_session():
    """Return the registry URI configured on the active Spark session, if any.

    Returns:
        ``None`` when there is no active Spark session or its conf cannot be read,
        ``"databricks-uc"`` when running on Databricks serverless compute, and otherwise the
        value of the ``spark.mlflow.modelRegistryUri`` conf (``None`` if unset).
    """
    session = _get_active_spark_session()
    if session is None:
        return None

    if is_in_databricks_serverless_runtime():
        # Connected to Serverless
        return "databricks-uc"

    from pyspark.sql.utils import AnalysisException

    try:
        return session.conf.get("spark.mlflow.modelRegistryUri", None)
    except AnalysisException:
        # In serverless clusters, session.conf.get() is unsupported
        # and raises an AnalysisException. We may encounter this case
        # when DBConnect is used to connect to a serverless cluster,
        # in which case the prior `is_in_databricks_serverless_runtime()`
        # check will have returned false (as of 2025-06-07, it checks
        # an environment variable that isn't set by DBConnect)
        return None


def _get_registry_uri_from_context():
    """Return the registry URI configured in the current context, or ``None``.

    Precedence:
      1. The value passed to ``set_registry_uri`` (returned even if it is an empty string).
      2. The ``MLFLOW_REGISTRY_URI`` environment variable, if non-empty.
      3. The registry URI from the active Spark session, if non-empty.
    """
    if _registry_uri is not None:
        return _registry_uri
    if uri := MLFLOW_REGISTRY_URI.get():
        return uri
    if uri := _get_registry_uri_from_spark_session():
        return uri
    # Nothing configured: callers fall back to a default derived from the tracking URI.
    return None


def _get_default_registry_uri_for_tracking_uri(tracking_uri: str | None) -> str | None:
    """
    Get the default registry URI for a given tracking URI.

    If ``is_databricks_uri(tracking_uri)`` holds, the result is the Databricks Unity Catalog
    registry scheme ("databricks-uc"), carrying over the Databricks CLI profile (and key prefix,
    if present) encoded in the tracking URI. Otherwise, the tracking URI itself is returned
    unchanged (including ``None``).

    Args:
        tracking_uri: The tracking URI to get the default registry URI for

    Returns:
        The default registry URI
    """
    if tracking_uri is None or not is_databricks_uri(tracking_uri):
        # For non-databricks tracking URIs, use the tracking URI as the registry URI
        return tracking_uri

    # Databricks tracking URIs default to the Unity Catalog Model Registry, which is the
    # recommended model registry offering on Databricks.
    if tracking_uri == "databricks":
        return _DATABRICKS_UNITY_CATALOG_SCHEME

    profile, key_prefix = get_db_info_from_uri(tracking_uri)
    if not profile:
        return _DATABRICKS_UNITY_CATALOG_SCHEME

    # Reconstruct the profile string including key_prefix if present
    profile_string = f"{profile}:{key_prefix}" if key_prefix else profile
    return construct_db_uc_uri_from_profile(profile_string)


def get_registry_uri() -> str:
    """Get the current registry URI.

    If none has been specified (via ``set_registry_uri``, the ``MLFLOW_REGISTRY_URI``
    environment variable, or the active Spark session), it defaults to a URI derived from the
    tracking URI: the Unity Catalog registry (``databricks-uc``) for Databricks tracking URIs,
    and the tracking URI itself otherwise.

    Returns:
        The registry URI.

    .. code-block:: python

        # Get the current model registry uri
        mr_uri = mlflow.get_registry_uri()
        print(f"Current model registry uri: {mr_uri}")

        # Get the current tracking uri
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

        # They should be the same (for a non-Databricks tracking URI)
        assert mr_uri == tracking_uri

    .. code-block:: text

        Current model registry uri: file:///.../mlruns
        Current tracking uri: file:///.../mlruns

    """
    return _resolve_registry_uri()


def _resolve_registry_uri(
    registry_uri: str | None = None, tracking_uri: str | None = None
) -> str | None:
    """
    Resolve the registry URI following the same logic as get_registry_uri().

    Args:
        registry_uri: An explicit registry URI; used as-is when non-empty.
        tracking_uri: A tracking URI used to derive the default registry URI when neither
            ``registry_uri`` nor the context provides one. It is first passed through
            ``_resolve_tracking_uri``.

    Returns:
        ``registry_uri`` if non-empty, else the context registry URI if non-empty, else the
        default registry URI derived from the resolved tracking URI.
    """
    return (
        registry_uri
        or _get_registry_uri_from_context()
        or _get_default_registry_uri_for_tracking_uri(_resolve_tracking_uri(tracking_uri))
    )


def _get_sqlalchemy_store(store_uri, **_):
    """Build a SQLAlchemy-backed model registry store for ``store_uri``.

    The workspace-aware variant is used when ``MLFLOW_ENABLE_WORKSPACES`` is enabled. Extra
    keyword arguments (e.g. ``tracking_uri``) are accepted and ignored, matching the other
    store builders registered with the store registry.
    """
    from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
    from mlflow.store.model_registry.sqlalchemy_workspace_store import (
        WorkspaceAwareSqlAlchemyStore,
    )

    store_cls = WorkspaceAwareSqlAlchemyStore if MLFLOW_ENABLE_WORKSPACES.get() else SqlAlchemyStore
    return store_cls(store_uri)


def _get_rest_store(store_uri, **_):
    """Build a REST model registry store whose host credentials are derived from ``store_uri``."""
    # A callable (not the creds themselves) is passed so RestStore can obtain creds on demand.
    return RestStore(partial(get_default_host_creds, store_uri))


def _get_databricks_rest_store(store_uri, tracking_uri, **_):
    """Build a Databricks workspace model registry store, warning on deprecated usage."""
    warn_on_deprecated_cross_workspace_registry_uri(registry_uri=store_uri)
    return DatabricksWorkspaceModelRegistryRestStore(store_uri, tracking_uri)


# We define the global variable as `None` so that instantiating the store does not lead to circular
# dependency issues.
_model_registry_store_registry = None


def _get_file_store(store_uri, **_):
    """Build a file-based model registry store for ``store_uri``; extra kwargs are ignored."""
    return FileStore(store_uri)


def _get_store_registry():
    """Return the module-level ``ModelRegistryStoreRegistry``, building it on first use.

    The registry maps URI schemes to store builders:
      - ``databricks``: Databricks workspace model registry REST store.
      - the Databricks Unity Catalog and OSS Unity Catalog schemes: UC registry stores.
      - ``http`` / ``https``: generic REST store.
      - every scheme in ``DATABASE_ENGINES``: SQLAlchemy store (only if ``sqlalchemy`` is
        importable).
      - ``""`` / ``file``: file store.
    Plugin-provided schemes are registered last via ``register_entrypoints()``.

    The global cache is only populated once registration has fully succeeded, so a failure
    part-way through does not leave a half-populated registry cached.
    """
    global _model_registry_store_registry
    if _model_registry_store_registry is not None:
        return _model_registry_store_registry

    # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded,
    # so import the function we need explicitly.
    from importlib.util import find_spec

    # NOTE(review): presumably imported lazily to avoid import cycles at module load — confirm.
    from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
    from mlflow.store._unity_catalog.registry.uc_oss_rest_store import UnityCatalogOssStore

    registry = ModelRegistryStoreRegistry()
    registry.register("databricks", _get_databricks_rest_store)
    # Unity Catalog model registries: Databricks-hosted and open-source.
    registry.register(_DATABRICKS_UNITY_CATALOG_SCHEME, UcModelRegistryStore)
    registry.register(_OSS_UNITY_CATALOG_SCHEME, UnityCatalogOssStore)

    for scheme in ["http", "https"]:
        registry.register(scheme, _get_rest_store)

    # Database-backed registries are only available when SQLAlchemy is installed.
    if find_spec("sqlalchemy") is not None:
        for scheme in DATABASE_ENGINES:
            registry.register(scheme, _get_sqlalchemy_store)

    for scheme in ["", "file"]:
        registry.register(scheme, _get_file_store)

    registry.register_entrypoints()

    _model_registry_store_registry = registry
    return registry


def _get_store(store_uri=None, tracking_uri=None):
    """Return the model registry store for ``store_uri`` (and ``tracking_uri``).

    Delegates to ``ModelRegistryStoreRegistry.get_store``; how ``None`` arguments are resolved
    is determined there.
    """
    return _get_store_registry().get_store(store_uri, tracking_uri)