utils.py
  1  import importlib
  2  from functools import partial
  3  
  4  from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES, MLFLOW_REGISTRY_URI
  5  from mlflow.store.db.db_types import DATABASE_ENGINES
  6  from mlflow.store.model_registry.databricks_workspace_model_registry_rest_store import (
  7      DatabricksWorkspaceModelRegistryRestStore,
  8  )
  9  from mlflow.store.model_registry.file_store import FileStore
 10  from mlflow.store.model_registry.rest_store import RestStore
 11  from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
 12  from mlflow.tracking._tracking_service.utils import (
 13      _resolve_tracking_uri,
 14  )
 15  from mlflow.utils._spark_utils import _get_active_spark_session
 16  from mlflow.utils.credentials import get_default_host_creds
 17  from mlflow.utils.databricks_utils import (
 18      is_in_databricks_serverless_runtime,
 19      warn_on_deprecated_cross_workspace_registry_uri,
 20  )
 21  from mlflow.utils.uri import (
 22      _DATABRICKS_UNITY_CATALOG_SCHEME,
 23      _OSS_UNITY_CATALOG_SCHEME,
 24      construct_db_uc_uri_from_profile,
 25      get_db_info_from_uri,
 26      is_databricks_uri,
 27  )
 28  
# NOTE: in contrast to tracking, we do not support the following ways to specify
# the model registry URI:
#  - via environment variables like MLFLOW_TRACKING_URI, MLFLOW_TRACKING_USERNAME, ...
# We do support specifying it
#  - via the ``model_registry_uri`` parameter when creating an ``MlflowClient`` or
#    ``ModelRegistryClient``.
#  - via a utility method ``mlflow.set_registry_uri``
#  - via the ``MLFLOW_REGISTRY_URI`` environment variable or, failing that, the
#    ``spark.mlflow.modelRegistryUri`` conf of the active Spark session.
#  - by not specifying anything: in this case the model registry store URI is derived from
#    the tracking store URI. For Databricks tracking URIs it is imputed as the matching
#    ``databricks-uc`` URI; otherwise it is the same as the tracking store URI. In the
#    latter case Tracking and Model Registry are backed by the same backend DB/Rest server.
#    However, note that we access them via different ``Store`` classes
#    (e.g. ``mlflow.store.tracking.SQLAlchemyStore`` &
#    ``mlflow.store.model_registry.SQLAlchemyStore``).
# This means the following combinations are not supported:
#  - Tracking RestStore & Model Registry RestStore that use different credentials.
 43  
# Process-wide registry URI assigned by ``set_registry_uri``. ``None`` means no URI has been
# set explicitly, in which case ``_get_registry_uri_from_context`` consults other sources.
_registry_uri = None
 45  
 46  
 47  def set_registry_uri(uri: str) -> None:
 48      """Set the registry server URI. This method is especially useful if you have a registry server
 49      that's different from the tracking server.
 50  
 51      Args:
 52          uri: An empty string, or a local file path, prefixed with ``file:/``. Data is stored
 53              locally at the provided file (or ``./mlruns`` if empty). An HTTP URI like
 54              ``https://my-tracking-server:5000`` or ``http://my-oss-uc-server:8080``. A Databricks
 55              workspace, provided as the string "databricks" or, to use a Databricks CLI
 56              `profile <https://github.com/databricks/databricks-cli#installation>`_,
 57              "databricks://<profileName>".
 58  
 59      .. code-block:: python
 60          :caption: Example
 61  
 62          import mflow
 63  
 64          # Set model registry uri, fetch the set uri, and compare
 65          # it with the tracking uri. They should be different
 66          mlflow.set_registry_uri("sqlite:////tmp/registry.db")
 67          mr_uri = mlflow.get_registry_uri()
 68          print(f"Current registry uri: {mr_uri}")
 69          tracking_uri = mlflow.get_tracking_uri()
 70          print(f"Current tracking uri: {tracking_uri}")
 71  
 72          # They should be different
 73          assert tracking_uri != mr_uri
 74  
 75      .. code-block:: text
 76          :caption: Output
 77  
 78          Current registry uri: sqlite:////tmp/registry.db
 79          Current tracking uri: file:///.../mlruns
 80  
 81      """
 82      global _registry_uri
 83      _registry_uri = uri
 84      if uri:
 85          # Set 'MLFLOW_REGISTRY_URI' environment variable
 86          # so that subprocess can inherit it.
 87          MLFLOW_REGISTRY_URI.set(_registry_uri)
 88  
 89  
 90  def _get_registry_uri_from_spark_session():
 91      session = _get_active_spark_session()
 92      if session is None:
 93          return None
 94  
 95      if is_in_databricks_serverless_runtime():
 96          # Connected to Serverless
 97          return "databricks-uc"
 98  
 99      from pyspark.sql.utils import AnalysisException
100  
101      try:
102          return session.conf.get("spark.mlflow.modelRegistryUri", None)
103      except AnalysisException:
104          # In serverless clusters, session.conf.get() is unsupported
105          # and raises an AnalysisException. We may encounter this case
106          # when DBConnect is used to connect to a serverless cluster,
107          # in which case the prior `is_in_databricks_serverless_runtime()`
108          # check will have returned false (as of 2025-06-07, it checks
109          # an environment variable that isn't set by DBConnect)
110          return None
111  
112  
def _get_registry_uri_from_context():
    """Return the registry URI available from the current process context.

    Sources are consulted in order: the module-level value assigned by ``set_registry_uri``
    (used whenever it is not ``None``, even if empty), then the ``MLFLOW_REGISTRY_URI``
    environment variable, then the active Spark session. Falsy values from the latter two
    are skipped; if nothing is found the (``None``) module-level value is returned.
    """
    if _registry_uri is not None:
        return _registry_uri
    # ``or`` short-circuits, so the Spark session is only inspected when the
    # environment variable yields nothing usable.
    return MLFLOW_REGISTRY_URI.get() or _get_registry_uri_from_spark_session() or _registry_uri
119  
120  
121  def _get_default_registry_uri_for_tracking_uri(tracking_uri: str | None) -> str | None:
122      """
123      Get the default registry URI for a given tracking URI.
124  
125      If the tracking URI starts with "databricks", returns "databricks-uc" with profile if present.
126      Otherwise, returns the tracking URI itself.
127  
128      Args:
129          tracking_uri: The tracking URI to get the default registry URI for
130  
131      Returns:
132          The default registry URI
133      """
134      if tracking_uri is not None and is_databricks_uri(tracking_uri):
135          # If the tracking URI is "databricks", we impute the registry URI as "databricks-uc"
136          # corresponding to Databricks Unity Catalog Model Registry, which is the recommended
137          # model registry offering on Databricks
138          if tracking_uri == "databricks":
139              return _DATABRICKS_UNITY_CATALOG_SCHEME
140          else:
141              # Extract profile from tracking URI and construct databricks-uc URI
142              profile, key_prefix = get_db_info_from_uri(tracking_uri)
143              if profile:
144                  # Reconstruct the profile string including key_prefix if present
145                  profile_string = f"{profile}:{key_prefix}" if key_prefix else profile
146                  return construct_db_uc_uri_from_profile(profile_string)
147              else:
148                  return _DATABRICKS_UNITY_CATALOG_SCHEME
149  
150      # For non-databricks tracking URIs, use the tracking URI as the registry URI
151      return tracking_uri
152  
153  
def get_registry_uri() -> str:
    """Return the registry URI currently in effect.

    When no registry URI has been specified, the value is derived from the tracking URI.

    Returns:
        The registry URI.

    .. code-block:: python

        # Get the current model registry uri
        mr_uri = mlflow.get_registry_uri()
        print(f"Current model registry uri: {mr_uri}")

        # Get the current tracking uri
        tracking_uri = mlflow.get_tracking_uri()
        print(f"Current tracking uri: {tracking_uri}")

        # They should be the same
        assert mr_uri == tracking_uri

    .. code-block:: text

        Current model registry uri: file:///.../mlruns
        Current tracking uri: file:///.../mlruns

    """
    # No explicit arguments: resolution relies purely on context and the tracking URI.
    resolved_uri = _resolve_registry_uri()
    return resolved_uri
180  
181  
182  def _resolve_registry_uri(
183      registry_uri: str | None = None, tracking_uri: str | None = None
184  ) -> str | None:
185      """
186      Resolve the registry URI following the same logic as get_registry_uri().
187      """
188      return (
189          registry_uri
190          or _get_registry_uri_from_context()
191          or _get_default_registry_uri_for_tracking_uri(_resolve_tracking_uri(tracking_uri))
192      )
193  
194  
def _get_sqlalchemy_store(store_uri):
    """Build the SQLAlchemy-backed model registry store for ``store_uri``.

    The workspace-aware store class is used when ``MLFLOW_ENABLE_WORKSPACES`` is truthy;
    otherwise the plain ``SqlAlchemyStore`` is used. The store modules are imported here
    rather than at module import time.
    """
    from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
    from mlflow.store.model_registry.sqlalchemy_workspace_store import (
        WorkspaceAwareSqlAlchemyStore,
    )

    if MLFLOW_ENABLE_WORKSPACES.get():
        return WorkspaceAwareSqlAlchemyStore(store_uri)
    return SqlAlchemyStore(store_uri)
203  
204  
def _get_rest_store(store_uri, **_):
    """Build a ``RestStore`` for ``store_uri``; any extra keyword arguments are ignored."""
    # The store receives a zero-argument callable that produces host credentials for the URI.
    host_creds_provider = partial(get_default_host_creds, store_uri)
    return RestStore(host_creds_provider)
207  
208  
def _get_databricks_rest_store(store_uri, tracking_uri, **_):
    """Build the Databricks workspace model registry store.

    Invokes the deprecated cross-workspace registry URI warning helper with ``store_uri``
    before constructing the store. Extra keyword arguments are ignored.
    """
    warn_on_deprecated_cross_workspace_registry_uri(registry_uri=store_uri)
    databricks_store = DatabricksWorkspaceModelRegistryRestStore(store_uri, tracking_uri)
    return databricks_store
212  
213  
# Cached ``ModelRegistryStoreRegistry`` instance, created on the first call to
# ``_get_store_registry`` and reused afterwards.
# We define the global variable as `None` so that instantiating the store does not lead to circular
# dependency issues.
_model_registry_store_registry = None
217  
218  
def _get_file_store(store_uri, **_):
    """Build a ``FileStore`` for ``store_uri``; any extra keyword arguments are ignored."""
    file_store = FileStore(store_uri)
    return file_store
221  
222  
def _get_store_registry():
    """Return the module-wide ``ModelRegistryStoreRegistry``, building it on first use.

    On the first call a registry is created and populated with the built-in store builders:
    ``databricks``, the Databricks and OSS Unity Catalog schemes, ``http``/``https``, the
    database engine schemes (only when ``sqlalchemy`` is importable), and the empty and
    ``file`` schemes. Entrypoint-provided builders are registered last. Later calls return
    the cached instance.

    Returns:
        The shared ``ModelRegistryStoreRegistry`` instance.
    """
    global _model_registry_store_registry

    # Fast path: nothing to import or build once the registry has been published.
    if _model_registry_store_registry is not None:
        return _model_registry_store_registry

    # ``import importlib`` alone does not guarantee that the ``importlib.util`` submodule
    # is loaded, so import it explicitly before using ``find_spec``.
    import importlib.util

    from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
    from mlflow.store._unity_catalog.registry.uc_oss_rest_store import UnityCatalogOssStore

    # Populate a local instance and publish it only once it is complete, so that neither a
    # concurrent caller nor a caller following a failed build observes a partial registry.
    registry = ModelRegistryStoreRegistry()
    registry.register("databricks", _get_databricks_rest_store)
    # Unity Catalog schemes are served directly by their store classes.
    registry.register(_DATABRICKS_UNITY_CATALOG_SCHEME, UcModelRegistryStore)
    registry.register(_OSS_UNITY_CATALOG_SCHEME, UnityCatalogOssStore)

    for scheme in ["http", "https"]:
        registry.register(scheme, _get_rest_store)

    # Database-backed stores are only offered when SQLAlchemy is installed.
    if importlib.util.find_spec("sqlalchemy") is not None:
        for scheme in DATABASE_ENGINES:
            registry.register(scheme, _get_sqlalchemy_store)

    for scheme in ["", "file"]:
        registry.register(scheme, _get_file_store)

    registry.register_entrypoints()

    _model_registry_store_registry = registry
    return registry
250  
251  
def _get_store(store_uri=None, tracking_uri=None):
    """Obtain a model registry store from the shared store registry.

    Args:
        store_uri: Registry store URI, forwarded to the registry's ``get_store``.
        tracking_uri: Tracking URI, forwarded to the registry's ``get_store``.

    Returns:
        Whatever the registry's ``get_store`` returns for the given URIs.
    """
    store_registry = _get_store_registry()
    return store_registry.get_store(store_uri, tracking_uri)