# huggingface_dataset_source.py
from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union

from packaging.version import Version

from mlflow.data.dataset_source import DatasetSource

if TYPE_CHECKING:
    import datasets


class HuggingFaceDatasetSource(DatasetSource):
    """Represents the source of a Hugging Face dataset used in MLflow Tracking."""

    def __init__(
        self,
        path: str,
        config_name: str | None = None,
        data_dir: str | None = None,
        data_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
        split: Union[str, "datasets.Split"] | None = None,
        revision: Union[str, "datasets.Version"] | None = None,
        trust_remote_code: bool | None = None,
    ):
        """Create a `HuggingFaceDatasetSource` instance.

        Arguments in `__init__` match arguments of the same name in
        `datasets.load_dataset() <https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/loading_methods#datasets.load_dataset>`_.
        The only exception is `config_name` matches `name` in `datasets.load_dataset()`, because
        we need to differentiate from `mlflow.data.Dataset` `name` attribute.

        Args:
            path: The path of the Hugging Face dataset, if it is a dataset from HuggingFace hub,
                `path` must match the hub path, e.g., "databricks/databricks-dolly-15k".
            config_name: The name of the Hugging Face dataset configuration.
            data_dir: The `data_dir` of the Hugging Face dataset configuration.
            data_files: Paths to source data file(s) for the Hugging Face dataset configuration.
            split: Which split of the data to load.
            revision: Version of the dataset script to load.
            trust_remote_code: Whether to trust remote code from the dataset repo.
        """
        self.path = path
        self.config_name = config_name
        self.data_dir = data_dir
        self.data_files = data_files
        self.split = split
        self.revision = revision
        self.trust_remote_code = trust_remote_code

    @staticmethod
    def _get_source_type() -> str:
        return "hugging_face"

    def load(self, **kwargs):
        """Load the Hugging Face dataset based on `HuggingFaceDatasetSource`.

        Args:
            kwargs: Additional keyword arguments used for loading the dataset with the Hugging Face
                `datasets.load_dataset()` method.

        Returns:
            An instance of `datasets.Dataset`.

        Raises:
            KeyError: If `kwargs` duplicates any argument already captured by this source.
        """
        import datasets

        load_kwargs = {
            "path": self.path,
            "name": self.config_name,
            "data_dir": self.data_dir,
            "data_files": self.data_files,
            "split": self.split,
            "revision": self.revision,
        }

        # The `trust_remote_code` argument only exists in datasets >= 2.16.0; passing it
        # to older versions would raise a TypeError.
        if Version(datasets.__version__) >= Version("2.16.0"):
            load_kwargs["trust_remote_code"] = self.trust_remote_code

        # Reject conflicting kwargs explicitly rather than letting one silently win.
        if intersecting_keys := set(load_kwargs.keys()) & set(kwargs.keys()):
            raise KeyError(
                f"Found duplicated arguments in `HuggingFaceDatasetSource` and "
                f"`kwargs`: {intersecting_keys}. Please remove them from `kwargs`."
            )
        load_kwargs.update(kwargs)
        return datasets.load_dataset(**load_kwargs)

    @staticmethod
    def _can_resolve(raw_source: Any) -> bool:
        # NB: Initially, we expect that Hugging Face dataset sources will only be used with
        # Hugging Face datasets constructed by from_huggingface_dataset, which can create
        # an instance of HuggingFaceDatasetSource directly without the need for resolution
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "HuggingFaceDatasetSource":
        raise NotImplementedError

    def to_dict(self) -> dict[Any, Any]:
        """Serialize this source to a JSON-compatible dict."""
        return {
            "path": self.path,
            "config_name": self.config_name,
            "data_dir": self.data_dir,
            "data_files": self.data_files,
            # `split` may be a `datasets.Split` object; stringify it for JSON compatibility,
            # but preserve `None` instead of serializing it as the string "None".
            "split": str(self.split) if self.split is not None else None,
            "revision": self.revision,
            # Include `trust_remote_code` so it survives a to_dict/from_dict round trip.
            "trust_remote_code": self.trust_remote_code,
        }

    @classmethod
    def from_dict(cls, source_dict: dict[Any, Any]) -> "HuggingFaceDatasetSource":
        """Reconstruct a `HuggingFaceDatasetSource` from a dict produced by `to_dict`."""
        return cls(
            path=source_dict.get("path"),
            config_name=source_dict.get("config_name"),
            data_dir=source_dict.get("data_dir"),
            data_files=source_dict.get("data_files"),
            split=source_dict.get("split"),
            revision=source_dict.get("revision"),
            # Missing in dicts written by older versions; `.get` defaults to None.
            trust_remote_code=source_dict.get("trust_remote_code"),
        )