/ mlflow / data / huggingface_dataset_source.py
huggingface_dataset_source.py
  1  from typing import TYPE_CHECKING, Any, Mapping, Sequence, Union
  2  
  3  from packaging.version import Version
  4  
  5  from mlflow.data.dataset_source import DatasetSource
  6  
  7  if TYPE_CHECKING:
  8      import datasets
  9  
 10  
 11  class HuggingFaceDatasetSource(DatasetSource):
 12      """Represents the source of a Hugging Face dataset used in MLflow Tracking."""
 13  
 14      def __init__(
 15          self,
 16          path: str,
 17          config_name: str | None = None,
 18          data_dir: str | None = None,
 19          data_files: str | Sequence[str] | Mapping[str, str | Sequence[str]] | None = None,
 20          split: Union[str, "datasets.Split"] | None = None,
 21          revision: Union[str, "datasets.Version"] | None = None,
 22          trust_remote_code: bool | None = None,
 23      ):
 24          """Create a `HuggingFaceDatasetSource` instance.
 25  
 26          Arguments in `__init__` match arguments of the same name in
 27          `datasets.load_dataset() <https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/loading_methods#datasets.load_dataset>`_.
 28          The only exception is `config_name` matches `name` in `datasets.load_dataset()`, because
 29          we need to differentiate from `mlflow.data.Dataset` `name` attribute.
 30  
 31          Args:
 32              path: The path of the Hugging Face dataset, if it is a dataset from HuggingFace hub,
 33                  `path` must match the hub path, e.g., "databricks/databricks-dolly-15k".
 34              config_name: The name of of the Hugging Face dataset configuration.
 35              data_dir: The `data_dir` of the Hugging Face dataset configuration.
 36              data_files: Paths to source data file(s) for the Hugging Face dataset configuration.
 37              split: Which split of the data to load.
 38              revision: Version of the dataset script to load.
 39              trust_remote_code: Whether to trust remote code from the dataset repo.
 40          """
 41          self.path = path
 42          self.config_name = config_name
 43          self.data_dir = data_dir
 44          self.data_files = data_files
 45          self.split = split
 46          self.revision = revision
 47          self.trust_remote_code = trust_remote_code
 48  
 49      @staticmethod
 50      def _get_source_type() -> str:
 51          return "hugging_face"
 52  
 53      def load(self, **kwargs):
 54          """Load the Hugging Face dataset based on `HuggingFaceDatasetSource`.
 55  
 56          Args:
 57              kwargs: Additional keyword arguments used for loading the dataset with the Hugging Face
 58                  `datasets.load_dataset()` method.
 59  
 60          Returns:
 61              An instance of `datasets.Dataset`.
 62          """
 63          import datasets
 64  
 65          load_kwargs = {
 66              "path": self.path,
 67              "name": self.config_name,
 68              "data_dir": self.data_dir,
 69              "data_files": self.data_files,
 70              "split": self.split,
 71              "revision": self.revision,
 72          }
 73  
 74          # this argument only exists in >= 2.16.0
 75          if Version(datasets.__version__) >= Version("2.16.0"):
 76              load_kwargs["trust_remote_code"] = self.trust_remote_code
 77  
 78          if intersecting_keys := set(load_kwargs.keys()) & set(kwargs.keys()):
 79              raise KeyError(
 80                  f"Found duplicated arguments in `HuggingFaceDatasetSource` and "
 81                  f"`kwargs`: {intersecting_keys}. Please remove them from `kwargs`."
 82              )
 83          load_kwargs.update(kwargs)
 84          return datasets.load_dataset(**load_kwargs)
 85  
 86      @staticmethod
 87      def _can_resolve(raw_source: Any):
 88          # NB: Initially, we expect that Hugging Face dataset sources will only be used with
 89          # Hugging Face datasets constructed by from_huggingface_dataset, which can create
 90          # an instance of HuggingFaceDatasetSource directly without the need for resolution
 91          return False
 92  
 93      @classmethod
 94      def _resolve(cls, raw_source: str) -> "HuggingFaceDatasetSource":
 95          raise NotImplementedError
 96  
 97      def to_dict(self) -> dict[Any, Any]:
 98          return {
 99              "path": self.path,
100              "config_name": self.config_name,
101              "data_dir": self.data_dir,
102              "data_files": self.data_files,
103              "split": str(self.split),
104              "revision": self.revision,
105          }
106  
107      @classmethod
108      def from_dict(cls, source_dict: dict[Any, Any]) -> "HuggingFaceDatasetSource":
109          return cls(
110              path=source_dict.get("path"),
111              config_name=source_dict.get("config_name"),
112              data_dir=source_dict.get("data_dir"),
113              data_files=source_dict.get("data_files"),
114              split=source_dict.get("split"),
115              revision=source_dict.get("revision"),
116          )