/ mlflow / data / http_dataset_source.py
http_dataset_source.py
  1  import os
  2  import re
  3  from typing import Any
  4  from urllib.parse import urlparse
  5  
  6  from mlflow.data.dataset_source import DatasetSource
  7  from mlflow.exceptions import MlflowException
  8  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
  9  from mlflow.utils.file_utils import create_tmp_dir
 10  from mlflow.utils.rest_utils import augmented_raise_for_status, cloud_storage_http_request
 11  
 12  
 13  def _is_path(filename: str) -> bool:
 14      """
 15      Return True if `filename` is a path, False otherwise. For example,
 16      "foo/bar" is a path, but "bar" is not.
 17      """
 18      return os.path.basename(filename) != filename
 19  
 20  
 21  class HTTPDatasetSource(DatasetSource):
 22      """
 23      Represents the source of a dataset stored at a web location and referred to
 24      by an HTTP or HTTPS URL.
 25      """
 26  
 27      def __init__(self, url):
 28          self._url = url
 29  
 30      @property
 31      def url(self):
 32          """The HTTP/S URL referring to the dataset source location.
 33  
 34          Returns:
 35              The HTTP/S URL referring to the dataset source location.
 36  
 37          """
 38          return self._url
 39  
 40      @staticmethod
 41      def _get_source_type() -> str:
 42          return "http"
 43  
 44      def _extract_filename(self, response) -> str:
 45          """
 46          Extracts a filename from the Content-Disposition header or the URL's path.
 47          """
 48          if content_disposition := response.headers.get("Content-Disposition"):
 49              for match in re.finditer(r"filename=(.+)", content_disposition):
 50                  filename = match[1].strip("'\"")
 51                  if _is_path(filename):
 52                      raise MlflowException.invalid_parameter_value(
 53                          f"Invalid filename in Content-Disposition header: {filename}. "
 54                          "It must be a file name, not a path."
 55                      )
 56                  return filename
 57  
 58          # Extract basename from URL if no valid filename in Content-Disposition
 59          return os.path.basename(urlparse(self.url).path)
 60  
 61      def load(self, dst_path=None) -> str:
 62          """Downloads the dataset source to the local filesystem.
 63  
 64          Args:
 65              dst_path: Path of the local filesystem destination directory to which to download the
 66                  dataset source. If the directory does not exist, it is created. If
 67                  unspecified, the dataset source is downloaded to a new uniquely-named
 68                  directory on the local filesystem.
 69  
 70          Returns:
 71              The path to the downloaded dataset source on the local filesystem.
 72  
 73          """
 74          resp = cloud_storage_http_request(
 75              method="GET",
 76              url=self.url,
 77              stream=True,
 78          )
 79          augmented_raise_for_status(resp)
 80  
 81          basename = self._extract_filename(resp)
 82  
 83          if not basename:
 84              basename = "dataset_source"
 85  
 86          if dst_path is None:
 87              dst_path = create_tmp_dir()
 88  
 89          dst_path = os.path.join(dst_path, basename)
 90          with open(dst_path, "wb") as f:
 91              chunk_size = 1024 * 1024  # 1 MB
 92              for chunk in resp.iter_content(chunk_size=chunk_size):
 93                  f.write(chunk)
 94  
 95          return dst_path
 96  
 97      @staticmethod
 98      def _can_resolve(raw_source: Any) -> bool:
 99          """
100          Args:
101              raw_source: The raw source, e.g. a string like "http://mysite/mydata.tar.gz".
102  
103          Returns:
104              True if this DatasetSource can resolve the raw source, False otherwise.
105          """
106          if not isinstance(raw_source, str):
107              return False
108  
109          try:
110              parsed_source = urlparse(str(raw_source))
111              return parsed_source.scheme in ["http", "https"]
112          except Exception:
113              return False
114  
115      @classmethod
116      def _resolve(cls, raw_source: Any) -> "HTTPDatasetSource":
117          """
118          Args:
119              raw_source: The raw source, e.g. a string like "http://mysite/mydata.tar.gz".
120          """
121          return HTTPDatasetSource(raw_source)
122  
123      def to_dict(self) -> dict[Any, Any]:
124          """
125          Returns:
126              A JSON-compatible dictionary representation of the HTTPDatasetSource.
127          """
128          return {
129              "url": self.url,
130          }
131  
132      @classmethod
133      def from_dict(cls, source_dict: dict[Any, Any]) -> "HTTPDatasetSource":
134          """
135          Args:
136              source_dict: A dictionary representation of the HTTPDatasetSource.
137          """
138          url = source_dict.get("url")
139          if url is None:
140              raise MlflowException(
141                  'Failed to parse HTTPDatasetSource. Missing expected key: "url"',
142                  INVALID_PARAMETER_VALUE,
143              )
144  
145          return cls(url=url)