dataset_source_registry.py
import warnings
from typing import Any

from mlflow.data.dataset_source import DatasetSource
from mlflow.data.http_dataset_source import HTTPDatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
from mlflow.utils.plugins import get_entry_points


class DatasetSourceRegistry:
    """Ordered registry of DatasetSource classes.

    Sources are kept in registration order. Both :py:meth:`resolve` and
    :py:meth:`get_source_from_json` consult them from most recently registered to least
    recently registered, so a later registration takes precedence over an earlier one.
    """

    def __init__(self):
        # DatasetSource classes (not instances), in registration order.
        self.sources: list[type[DatasetSource]] = []

    def register(self, source: type[DatasetSource]):
        """Registers a DatasetSource for use with MLflow Tracking.

        Sources registered later take precedence over sources registered earlier.

        Args:
            source: The DatasetSource class to register.
        """
        self.sources.append(source)

    def register_entrypoints(self):
        """
        Registers dataset sources defined as Python entrypoints. For reference, see
        https://mlflow.org/docs/latest/plugins.html#defining-a-plugin.

        An entrypoint whose loading raises ``AttributeError`` or ``ImportError`` is skipped
        with a warning; the remaining entrypoints are still registered.
        """
        for entrypoint in get_entry_points("mlflow.dataset_source"):
            try:
                self.register(entrypoint.load())
            except (AttributeError, ImportError) as exc:
                warnings.warn(
                    "Failure attempting to register dataset constructor"
                    f' "{entrypoint}": {exc}',
                    stacklevel=2,
                )

    def resolve(
        self, raw_source: Any, candidate_sources: list[type[DatasetSource]] | None = None
    ) -> DatasetSource:
        """Resolves a raw source object, such as a string URI, to a DatasetSource for use with
        MLflow Tracking.

        Each registered source (restricted to subclasses of ``candidate_sources`` when given)
        is asked via ``_can_resolve`` whether it understands ``raw_source``. The matching
        sources are then tried from most to least recently registered, and the first one whose
        ``_resolve`` call succeeds is returned. Exceptions raised by individual sources during
        either step are reported as warnings and that source is skipped.

        Args:
            raw_source: The raw source, e.g. a string like "s3://mybucket/path/to/iris/data" or a
                HuggingFace :py:class:`datasets.Dataset` object.
            candidate_sources: A list of DatasetSource classes to consider as potential sources
                when resolving the raw source. Subclasses of the specified candidate sources are
                also considered. If unspecified (``None`` or an empty list), all registered
                sources are considered.

        Raises:
            MlflowException: If no DatasetSource class can resolve the raw source, including
                the case where every matching source raised while resolving it.

        Returns:
            The resolved DatasetSource.
        """
        matching_sources = []
        for source in self.sources:
            # An empty or None ``candidate_sources`` disables filtering.
            if candidate_sources and not any(
                issubclass(source, candidate_src) for candidate_src in candidate_sources
            ):
                continue
            try:
                if source._can_resolve(raw_source):
                    matching_sources.append(source)
            except Exception as e:
                # Best-effort: a misbehaving source must not prevent others from resolving.
                warnings.warn(
                    f"Failed to determine whether {source.__name__} can resolve source"
                    f" information for '{raw_source}'. Exception: {e}",
                    stacklevel=2,
                )
                continue

        if len(matching_sources) > 1:
            source_class_names_str = ", ".join([source.__name__ for source in matching_sources])
            warnings.warn(
                f"The specified dataset source can be interpreted in multiple ways:"
                f" {source_class_names_str}. MLflow will assume that this is a"
                f" {matching_sources[-1].__name__} source.",
                stacklevel=2,
            )

        # Most recently registered matching source first; fall back to earlier ones on failure.
        for matching_source in reversed(matching_sources):
            try:
                return matching_source._resolve(raw_source)
            except Exception as e:
                warnings.warn(
                    f"Encountered an unexpected error while using {matching_source.__name__} to"
                    f" resolve source information for '{raw_source}'. Exception: {e}",
                    stacklevel=2,
                )
                continue

        raise MlflowException(
            f"Could not find a source information resolver for the specified"
            f" dataset source: {raw_source}.",
            RESOURCE_DOES_NOT_EXIST,
        )

    def get_source_from_json(self, source_json: str, source_type: str) -> DatasetSource:
        """Parses and returns a DatasetSource object from its JSON representation.

        When several registered sources report the same source type, the most recently
        registered one is used.

        Args:
            source_json: The JSON representation of the DatasetSource.
            source_type: The string type of the DatasetSource, which indicates how to parse the
                source JSON.

        Raises:
            MlflowException: If no registered source reports ``source_type``.

        Returns:
            The DatasetSource produced by the matching class's ``from_json``.
        """
        for source in reversed(self.sources):
            if source._get_source_type() == source_type:
                return source.from_json(source_json)

        raise MlflowException(
            f"Could not parse dataset source from JSON due to unrecognized"
            f" source type: {source_type}.",
            RESOURCE_DOES_NOT_EXIST,
        )


def register_dataset_source(source: type[DatasetSource]):
    """Registers a DatasetSource for use with MLflow Tracking.

    Sources registered later take precedence over sources registered earlier.

    Args:
        source: The DatasetSource class to register.
    """
    _dataset_source_registry.register(source)


def resolve_dataset_source(
    raw_source: Any, candidate_sources: list[type[DatasetSource]] | None = None
) -> DatasetSource:
    """Resolves a raw source object, such as a string URI, to a DatasetSource for use with
    MLflow Tracking.

    Args:
        raw_source: The raw source, e.g. a string like "s3://mybucket/path/to/iris/data" or a
            HuggingFace :py:class:`datasets.Dataset` object.
        candidate_sources: A list of DatasetSource classes to consider as potential sources
            when resolving the raw source. Subclasses of the specified candidate
            sources are also considered. If unspecified (``None`` or an empty list),
            all registered sources are considered.

    Raises:
        MlflowException: If no DatasetSource class can resolve the raw source.

    Returns:
        The resolved DatasetSource.
    """
    return _dataset_source_registry.resolve(
        raw_source=raw_source, candidate_sources=candidate_sources
    )


def get_dataset_source_from_json(source_json: str, source_type: str) -> DatasetSource:
    """Parses and returns a DatasetSource object from its JSON representation.

    Args:
        source_json: The JSON representation of the DatasetSource.
        source_type: The string type of the DatasetSource, which indicates how to parse the
            source JSON.

    Raises:
        MlflowException: If no registered source reports ``source_type``.

    Returns:
        The parsed DatasetSource.
    """
    return _dataset_source_registry.get_source_from_json(
        source_json=source_json, source_type=source_type
    )


def get_registered_sources() -> list[type[DatasetSource]]:
    """Obtains the registered dataset sources.

    Returns:
        The registry's own list of registered DatasetSource classes, in registration order
        (not a copy, so mutating it mutates the registry).
    """
    return _dataset_source_registry.sources


# NB: The ordering here is important. The last dataset source to be registered takes precedence
# when resolving dataset information for a raw source (e.g. a string like "s3://mybucket/my/path")
# and when parsing a dataset source from JSON. Sources are registered in this order:
#   1. Dataset sources derived from artifact repositories. These are the most generic and provide
#      the most general information about dataset source locations, so they are registered first
#      (lowest precedence).
#   2. HTTPDatasetSource.
#   3. Externally-defined dataset sources from the "mlflow.dataset_source" entrypoint group.
#   4. Specialized built-in sources (HuggingFace, Spark, Delta, Code, UC Volume, and the
#      Databricks evaluation / UC table sources), each registered only if its module can be
#      imported. Because they are registered after the entrypoints, these built-in sources take
#      precedence over externally-defined sources that can resolve the same raw source.
# NOTE(review): an earlier version of this comment stated that externally-defined sources are
# registered last so that they override internal behavior; the code registers them before the
# optional built-in sources. Confirm which precedence is intended before reordering.
_dataset_source_registry = DatasetSourceRegistry()

# Register artifact sources first (they should take lower precedence)
# NOTE(review): imported here rather than at the top of the module, presumably to avoid a
# circular import with this module — confirm before moving it.
from mlflow.data.artifact_dataset_sources import register_artifact_dataset_sources

register_artifact_dataset_sources()

_dataset_source_registry.register(HTTPDatasetSource)
_dataset_source_registry.register_entrypoints()

# Optional built-in sources: skipped silently when their module (or its dependencies) is missing.
try:
    from mlflow.data.huggingface_dataset_source import HuggingFaceDatasetSource

    _dataset_source_registry.register(HuggingFaceDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.spark_dataset_source import SparkDatasetSource

    _dataset_source_registry.register(SparkDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.delta_dataset_source import DeltaDatasetSource

    _dataset_source_registry.register(DeltaDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.code_dataset_source import CodeDatasetSource

    _dataset_source_registry.register(CodeDatasetSource)
except ImportError:
    pass
try:
    from mlflow.data.uc_volume_dataset_source import UCVolumeDatasetSource

    _dataset_source_registry.register(UCVolumeDatasetSource)
except ImportError:
    pass
try:
    from mlflow.genai.datasets.databricks_evaluation_dataset_source import (
        DatabricksEvaluationDatasetSource,
        DatabricksUCTableDatasetSource,
    )

    _dataset_source_registry.register(DatabricksEvaluationDatasetSource)
    _dataset_source_registry.register(DatabricksUCTableDatasetSource)
except ImportError:
    pass