/ mlflow / data / dataset_source_registry.py
dataset_source_registry.py
  1  import warnings
  2  from typing import Any
  3  
  4  from mlflow.data.dataset_source import DatasetSource
  5  from mlflow.data.http_dataset_source import HTTPDatasetSource
  6  from mlflow.exceptions import MlflowException
  7  from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
  8  from mlflow.utils.plugins import get_entry_points
  9  
 10  
 11  class DatasetSourceRegistry:
 12      def __init__(self):
 13          self.sources = []
 14  
 15      def register(self, source: DatasetSource):
 16          """Registers a DatasetSource for use with MLflow Tracking.
 17  
 18          Args:
 19              source: The DatasetSource to register.
 20          """
 21          self.sources.append(source)
 22  
 23      def register_entrypoints(self):
 24          """
 25          Registers dataset sources defined as Python entrypoints. For reference, see
 26          https://mlflow.org/docs/latest/plugins.html#defining-a-plugin.
 27          """
 28          for entrypoint in get_entry_points("mlflow.dataset_source"):
 29              try:
 30                  self.register(entrypoint.load())
 31              except (AttributeError, ImportError) as exc:
 32                  warnings.warn(
 33                      "Failure attempting to register dataset constructor"
 34                      + f' "{entrypoint}": {exc}',
 35                      stacklevel=2,
 36                  )
 37  
 38      def resolve(
 39          self, raw_source: Any, candidate_sources: list[DatasetSource] | None = None
 40      ) -> DatasetSource:
 41          """Resolves a raw source object, such as a string URI, to a DatasetSource for use with
 42          MLflow Tracking.
 43  
 44          Args:
 45              raw_source: The raw source, e.g. a string like "s3://mybucket/path/to/iris/data" or a
 46                  HuggingFace :py:class:`datasets.Dataset` object.
 47              candidate_sources: A list of DatasetSource classes to consider as potential sources
 48                  when resolving the raw source. Subclasses of the specified candidate sources are
 49                  also considered. If unspecified, all registered sources are considered.
 50  
 51          Raises:
 52              MlflowException: If no DatasetSource class can resolve the raw source.
 53  
 54          Returns:
 55              The resolved DatasetSource.
 56          """
 57          matching_sources = []
 58          for source in self.sources:
 59              if candidate_sources and not any(
 60                  issubclass(source, candidate_src) for candidate_src in candidate_sources
 61              ):
 62                  continue
 63              try:
 64                  if source._can_resolve(raw_source):
 65                      matching_sources.append(source)
 66              except Exception as e:
 67                  warnings.warn(
 68                      f"Failed to determine whether {source.__name__} can resolve source"
 69                      f" information for '{raw_source}'. Exception: {e}",
 70                      stacklevel=2,
 71                  )
 72                  continue
 73  
 74          if len(matching_sources) > 1:
 75              source_class_names_str = ", ".join([source.__name__ for source in matching_sources])
 76              warnings.warn(
 77                  f"The specified dataset source can be interpreted in multiple ways:"
 78                  f" {source_class_names_str}. MLflow will assume that this is a"
 79                  f" {matching_sources[-1].__name__} source.",
 80                  stacklevel=2,
 81              )
 82  
 83          for matching_source in reversed(matching_sources):
 84              try:
 85                  return matching_source._resolve(raw_source)
 86              except Exception as e:
 87                  warnings.warn(
 88                      f"Encountered an unexpected error while using {matching_source.__name__} to"
 89                      f" resolve source information for '{raw_source}'. Exception: {e}",
 90                      stacklevel=2,
 91                  )
 92                  continue
 93  
 94          raise MlflowException(
 95              f"Could not find a source information resolver for the specified"
 96              f" dataset source: {raw_source}.",
 97              RESOURCE_DOES_NOT_EXIST,
 98          )
 99  
100      def get_source_from_json(self, source_json: str, source_type: str) -> DatasetSource:
101          """Parses and returns a DatasetSource object from its JSON representation.
102  
103          Args:
104              source_json: The JSON representation of the DatasetSource.
105              source_type: The string type of the DatasetSource, which indicates how to parse the
106                  source JSON.
107          """
108          for source in reversed(self.sources):
109              if source._get_source_type() == source_type:
110                  return source.from_json(source_json)
111  
112          raise MlflowException(
113              f"Could not parse dataset source from JSON due to unrecognized"
114              f" source type: {source_type}.",
115              RESOURCE_DOES_NOT_EXIST,
116          )
117  
118  
def register_dataset_source(source: type[DatasetSource]):
    """Registers a DatasetSource for use with MLflow Tracking.

    The class is appended to the module-level registry. Because the registry searches its
    sources from the most recently registered one backwards, a source registered here takes
    precedence over every source registered before it.

    Args:
        source: The DatasetSource class (not an instance) to register.
    """
    _dataset_source_registry.register(source)
126  
127  
def resolve_dataset_source(
    raw_source: Any, candidate_sources: list[type[DatasetSource]] | None = None
) -> DatasetSource:
    """Resolves a raw source object, such as a string URI, to a DatasetSource for use with
    MLflow Tracking.

    This delegates to the module-level registry's ``resolve`` method.

    Args:
        raw_source: The raw source, e.g. a string like "s3://mybucket/path/to/iris/data" or a
            HuggingFace :py:class:`datasets.Dataset` object.
        candidate_sources: A list of DatasetSource classes to consider as potential sources
            when resolving the raw source. Subclasses of the specified candidate
            sources are also considered. If unspecified, all registered sources
            are considered.

    Raises:
        MlflowException: If no DatasetSource class can resolve the raw source.

    Returns:
        The resolved DatasetSource.
    """
    return _dataset_source_registry.resolve(
        raw_source=raw_source, candidate_sources=candidate_sources
    )
151  
152  
def get_dataset_source_from_json(source_json: str, source_type: str) -> DatasetSource:
    """Builds a DatasetSource from its JSON form using the module-level registry.

    Args:
        source_json: The JSON representation of the DatasetSource.
        source_type: The string type of the DatasetSource, which selects the registered
            class used to parse ``source_json``.

    Raises:
        MlflowException: If no registered DatasetSource class reports ``source_type``.

    Returns:
        The parsed DatasetSource.
    """
    parse_with_registry = _dataset_source_registry.get_source_from_json
    return parse_with_registry(source_json=source_json, source_type=source_type)
164  
165  
def get_registered_sources() -> list[type[DatasetSource]]:
    """Obtains the registered dataset sources.

    Returns:
        A list of the registered DatasetSource classes, in registration order. This is the
        module-level registry's own list object rather than a copy, so mutating it changes
        the registry.

    """
    return _dataset_source_registry.sources
174  
175  
# NB: The ordering here is important. The last dataset source to be registered takes precedence
# when resolving dataset information for a raw source (e.g. a string like "s3://mybucket/my/path"),
# because the registry searches its sources from the most recently registered one backwards.
# Registration order in this module:
#   1. Dataset sources derived from artifact repositories. These are the most generic / provide
#      the most general information about dataset source locations, so they are registered first.
#   2. HTTPDatasetSource.
#   3. Externally-defined dataset sources from the "mlflow.dataset_source" entrypoint group.
#   4. Specialized dataset platform sources (HuggingFaceDatasetSource, SparkDatasetSource, ...),
#      each registered only if its module can be imported.
# NOTE(review): the stated intent was for externally-defined dataset sources to be registered
# last, so that externally-defined behavior takes precedence over internally-defined behavior.
# As written, step 4 runs after step 3, so the specialized built-in sources take precedence over
# entrypoint-defined sources. Confirm which precedence is intended before reordering.
_dataset_source_registry = DatasetSourceRegistry()

# Register artifact sources first (they should take lower precedence)
# NOTE(review): presumably imported here rather than at the top of the file because it relies on
# the registry created above -- confirm against mlflow.data.artifact_dataset_sources.
from mlflow.data.artifact_dataset_sources import register_artifact_dataset_sources

register_artifact_dataset_sources()

_dataset_source_registry.register(HTTPDatasetSource)
_dataset_source_registry.register_entrypoints()

# Each source below is registered on a best-effort basis: if importing its module raises
# ImportError, the source is silently skipped (presumably because it needs an optional
# dependency -- confirm). Later entries take precedence over earlier ones.
try:
    from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource

    _dataset_source_registry.register(HuggingFaceDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.spark_dataset_source import SparkDatasetSource

    _dataset_source_registry.register(SparkDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.delta_dataset_source import DeltaDatasetSource

    _dataset_source_registry.register(DeltaDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.code_dataset_source import CodeDatasetSource

    _dataset_source_registry.register(CodeDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.uc_volume_dataset_source import UCVolumeDatasetSource

    _dataset_source_registry.register(UCVolumeDatasetSource)
except ImportError:
    pass
try:
    # Both Databricks sources come from one module, so they are registered or skipped together.
    from mlflow.genai.datasets.databricks_evaluation_dataset_source import (
        DatabricksEvaluationDatasetSource,
        DatabricksUCTableDatasetSource,
    )

    _dataset_source_registry.register(DatabricksEvaluationDatasetSource)
    _dataset_source_registry.register(DatabricksUCTableDatasetSource)
except ImportError:
    pass