/ mlflow / utils / file_utils.py
file_utils.py
  1  import atexit
  2  import codecs
  3  import errno
  4  import fnmatch
  5  import gzip
  6  import importlib.util
  7  import json
  8  import logging
  9  import math
 10  import os
 11  import pathlib
 12  import posixpath
 13  import shutil
 14  import stat
 15  import subprocess
 16  import sys
 17  import tarfile
 18  import tempfile
 19  import time
 20  import urllib.parse
 21  import urllib.request
 22  from concurrent.futures import as_completed
 23  from contextlib import contextmanager
 24  from dataclasses import dataclass
 25  from subprocess import CalledProcessError, TimeoutExpired
 26  from types import TracebackType
 27  from typing import Any
 28  from urllib.parse import unquote
 29  from urllib.request import pathname2url
 30  
 31  from mlflow.entities import FileInfo
 32  from mlflow.environment_variables import (
 33      _MLFLOW_MPD_NUM_RETRIES,
 34      _MLFLOW_MPD_RETRY_INTERVAL_SECONDS,
 35      MLFLOW_DOWNLOAD_CHUNK_TIMEOUT,
 36      MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR,
 37  )
 38  from mlflow.exceptions import MlflowException
 39  from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialType
 40  from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
 41  from mlflow.utils import download_cloud_file_chunk
 42  from mlflow.utils.databricks_utils import (
 43      get_databricks_local_temp_dir,
 44      get_databricks_nfs_temp_dir,
 45  )
 46  from mlflow.utils.os import is_windows
 47  from mlflow.utils.process import cache_return_value_per_process
 48  from mlflow.utils.request_utils import cloud_storage_http_request, download_chunk
 49  from mlflow.utils.rest_utils import augmented_raise_for_status
 50  
 51  ENCODING = "utf-8"
 52  _PROGRESS_BAR_DISPLAY_THRESHOLD = 500_000_000  # 500 MB
 53  
 54  _logger = logging.getLogger(__name__)
 55  
 56  # This is for backward compatibility with databricks-feature-engineering<=0.10.2
 57  if importlib.util.find_spec("yaml") is not None:
 58      try:
 59          from yaml import CSafeDumper as YamlSafeDumper
 60      except ImportError:
 61          from yaml import SafeDumper as YamlSafeDumper  # noqa: F401
 62  
 63  
 64  class ArtifactProgressBar:
 65      def __init__(self, desc, total, step, **kwargs) -> None:
 66          self.desc = desc
 67          self.total = total
 68          self.step = step
 69          self.pbar = None
 70          self.progress = 0
 71          self.kwargs = kwargs
 72  
 73      def set_pbar(self):
 74          if MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR.get():
 75              try:
 76                  from tqdm.auto import tqdm
 77  
 78                  self.pbar = tqdm(total=self.total, desc=self.desc, **self.kwargs)
 79              except ImportError:
 80                  pass
 81  
 82      @classmethod
 83      def chunks(cls, file_size, desc, chunk_size):
 84          bar = cls(
 85              desc,
 86              total=file_size,
 87              step=chunk_size,
 88              unit="iB",
 89              unit_scale=True,
 90              unit_divisor=1024,
 91              miniters=1,
 92          )
 93          if file_size >= _PROGRESS_BAR_DISPLAY_THRESHOLD:
 94              bar.set_pbar()
 95          return bar
 96  
 97      @classmethod
 98      def files(cls, desc, total):
 99          bar = cls(desc, total=total, step=1)
100          bar.set_pbar()
101          return bar
102  
103      def update(self):
104          if self.pbar:
105              update_step = min(self.total - self.progress, self.step)
106              self.pbar.update(update_step)
107              self.pbar.refresh()
108              self.progress += update_step
109  
110      def __enter__(self):
111          return self
112  
113      def __exit__(self, *args):
114          if self.pbar:
115              self.pbar.close()
116  
117  
def is_directory(name):
    """Return True if ``name`` refers to an existing directory."""
    return os.path.isdir(name)
120  
121  
def is_file(name):
    """Return True if ``name`` refers to an existing regular file."""
    return os.path.isfile(name)
124  
125  
def exists(name):
    """Return True if ``name`` exists on the local filesystem."""
    return os.path.exists(name)
128  
129  
def list_all(root, filter_func=lambda x: True, full_path=False):
    """List all entries directly under ``root`` that satisfy ``filter_func``.

    Args:
        root: Directory to scan.
        filter_func: Predicate receiving the joined path of each entry.
        full_path: If True, return results as full paths including ``root``.

    Returns:
        List of matching entry names (or full paths when ``full_path``).
    """
    if not is_directory(root):
        raise Exception(f"Invalid parent directory '{root}'")
    selected = [name for name in os.listdir(root) if filter_func(os.path.join(root, name))]
    if full_path:
        return [os.path.join(root, name) for name in selected]
    return selected
146  
147  
def list_subdirs(dir_name, full_path=False):
    """List the directories directly under ``dir_name``.

    Equivalent to UNIX command:
      ``find $dir_name -depth 1 -type d``

    Args:
        dir_name: Name of directory to start search.
        full_path: If True, return results as full paths including ``dir_name``.

    Returns:
        List of all directories directly under ``dir_name``.
    """
    return list_all(dir_name, filter_func=os.path.isdir, full_path=full_path)
161  
162  
def list_files(dir_name, full_path=False):
    """List the regular files directly under ``dir_name``.

    Equivalent to UNIX command:
      ``find $dir_name -depth 1 -type f``

    Args:
        dir_name: Name of directory to start search.
        full_path: If True, return results as full paths including ``dir_name``.

    Returns:
        List of all files directly under ``dir_name``.
    """
    return list_all(dir_name, filter_func=os.path.isfile, full_path=full_path)
176  
177  
def find(root, name, full_path=False):
    """Search for an entry named ``name`` directly under ``root``.

    Equivalent to: ``find $root -name "$name" -depth 1``

    Args:
        root: Name of root directory for find.
        name: Name of file or directory to find directly under root directory.
        full_path: If True, return results as full paths including ``root``.

    Returns:
        List of matching files or directories (at most one entry).
    """
    wanted = os.path.join(root, name)
    return list_all(root, filter_func=lambda p: p == wanted, full_path=full_path)
192  
193  
def mkdir(root, name=None):
    """Create directory ``root/name``, or just ``root`` when ``name`` is None.

    Args:
        root: Name of parent directory.
        name: Optional name of leaf directory.

    Returns:
        Path to the (possibly pre-existing) directory.
    """
    target = root if name is None else os.path.join(root, name)
    try:
        os.makedirs(target, exist_ok=True)
    except OSError as e:
        # A pre-existing *directory* is acceptable; anything else is re-raised.
        if e.errno != errno.EEXIST or not os.path.isdir(target):
            raise e
    return target
211  
212  
def make_containing_dirs(path):
    """
    Create the base directory for a given file path if it does not exist; also creates parent
    directories.
    """
    dir_name = os.path.dirname(path)
    # `dir_name` is empty when `path` is a bare filename; nothing to create then
    # (the original raised FileNotFoundError via `os.makedirs("")` in that case).
    # `exist_ok=True` also removes the TOCTOU race between check and creation.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
221  
222  
class TempDir:
    """Context manager providing a temporary directory.

    Args:
        chdr: If True, chdir into the temp directory while inside the block.
        remove_on_exit: If True, delete the directory tree on exit.
    """

    def __init__(self, chdr=False, remove_on_exit=True):
        self._dir = None
        self._path = None
        self._chdr = chdr
        self._remove = remove_on_exit

    def __enter__(self):
        self._path = os.path.abspath(create_tmp_dir())
        assert os.path.exists(self._path)
        if self._chdr:
            # Remember the previous working dir so __exit__ can restore it.
            self._dir = os.path.abspath(os.getcwd())
            os.chdir(self._path)
        return self

    def __exit__(self, tp, val, traceback):
        if self._chdr and self._dir:
            os.chdir(self._dir)
            self._dir = None
        if self._remove and os.path.exists(self._path):
            shutil.rmtree(self._path)

        assert not self._remove or not os.path.exists(self._path)

    def path(self, *path):
        """Join ``path`` components under the temp dir (relative when chdir'd)."""
        base = "./" if self._chdr else self._path
        return os.path.join(base, *path)
249  
250  
def read_file_lines(parent_path, file_name):
    """Read a text file and return its contents as a list of lines.

    Args:
        parent_path: Full path to the directory that contains the file.
        file_name: Leaf file name.

    Returns:
        All lines in the file as a list (newline characters retained).
    """
    full_path = os.path.join(parent_path, file_name)
    with codecs.open(full_path, mode="r", encoding=ENCODING) as f:
        return f.readlines()
265  
266  
def read_file(parent_path, file_name):
    """Read a text file and return its entire contents as one string.

    Args:
        parent_path: Full path to the directory that contains the file.
        file_name: Leaf file name.

    Returns:
        The contents of the file.
    """
    full_path = os.path.join(parent_path, file_name)
    with codecs.open(full_path, mode="r", encoding=ENCODING) as f:
        return f.read()
281  
282  
def get_file_info(path, rel_path):
    """Build a ``FileInfo`` describing the entry at ``path``.

    Args:
        path: Path to artifact.
        rel_path: Relative path recorded in the returned object.

    Returns:
        ``FileInfo`` with ``is_dir`` set; size is None for directories.
    """
    if is_directory(path):
        return FileInfo(rel_path, True, None)
    return FileInfo(rel_path, False, os.path.getsize(path))
297  
298  
def mv(target, new_parent):
    """Move ``target`` into/onto ``new_parent`` (thin wrapper over ``shutil.move``)."""
    shutil.move(target, new_parent)
301  
302  
def write_to(filename, data):
    """Write ``data`` to ``filename`` as UTF-8 text, replacing any existing content."""
    with codecs.open(filename, mode="w", encoding=ENCODING) as handle:
        handle.write(data)
306  
307  
def append_to(filename, data):
    """Append ``data`` to ``filename``, creating the file if it is missing.

    Opens the file with an explicit UTF-8 encoding so appended text matches
    what `write_to` produced, instead of depending on the platform locale
    default (which differs e.g. on Windows).
    """
    with open(filename, "a", encoding=ENCODING) as handle:
        handle.write(data)
311  
312  
def make_tarfile(output_filename, source_dir, archive_name, custom_filter=None):
    """Create a reproducible gzipped tarball of ``source_dir``.

    Args:
        output_filename: Path of the ``.tar.gz`` file to write.
        source_dir: Directory to archive.
        archive_name: Name of the root entry inside the archive.
        custom_filter: Optional ``tarfile`` add-filter applied after the
            timestamp is zeroed; return None from it to exclude an entry.
    """

    # Helper for filtering out modification timestamps so that identical
    # trees produce byte-identical archives.
    def _filter_timestamps(tar_info):
        tar_info.mtime = 0
        return tar_info if custom_filter is None else custom_filter(tar_info)

    unzipped_file_handle, unzipped_filename = tempfile.mkstemp()
    try:
        with tarfile.open(unzipped_filename, "w") as tar:
            tar.add(source_dir, arcname=archive_name, filter=_filter_timestamps)
        # When gzipping the tar, don't include the tar's filename or modification time in the
        # zipped archive (see https://docs.python.org/3/library/gzip.html#gzip.GzipFile).
        # The output file object is managed by `with` because GzipFile.close() does not
        # close the fileobj it wraps (the original implementation never closed it).
        with (
            open(output_filename, "wb") as out_fileobj,
            gzip.GzipFile(filename="", fileobj=out_fileobj, mode="wb", mtime=0) as gzipped_tar,
            open(unzipped_filename, "rb") as tar,
        ):
            gzipped_tar.write(tar.read())
    finally:
        os.close(unzipped_file_handle)
        # The original implementation closed the fd but leaked the temp file itself.
        os.remove(unzipped_filename)
334  
335  
336  def _copy_project(src_path, dst_path=""):
337      """Internal function used to copy MLflow project during development.
338  
339      Copies the content of the whole directory tree except patterns defined in .dockerignore.
340      The MLflow is assumed to be accessible as a local directory in this case.
341  
342      Args:
343          src_path: Path to the original MLflow project
344          dst_path: MLflow will be copied here
345  
346      Returns:
347          Name of the MLflow project directory.
348      """
349  
350      def _docker_ignore(mlflow_root):
351          docker_ignore = os.path.join(mlflow_root, ".dockerignore")
352          patterns = []
353          if os.path.exists(docker_ignore):
354              with open(docker_ignore) as f:
355                  patterns = [x.strip() for x in f]
356  
357          def ignore(_, names):
358              res = set()
359              for p in patterns:
360                  res.update(set(fnmatch.filter(names, p)))
361              return list(res)
362  
363          return ignore if patterns else None
364  
365      mlflow_dir = "mlflow-project"
366      # check if we have project root
367      assert os.path.isfile(os.path.join(src_path, "pyproject.toml")), "file not found " + str(
368          os.path.abspath(os.path.join(src_path, "pyproject.toml"))
369      )
370      shutil.copytree(src_path, os.path.join(dst_path, mlflow_dir), ignore=_docker_ignore(src_path))
371      return mlflow_dir
372  
373  
374  def _copy_file_or_tree(src, dst, dst_dir=None):
375      """
376      Returns:
377          The path to the copied artifacts, relative to `dst`.
378      """
379      dst_subpath = os.path.basename(os.path.abspath(src))
380      if dst_dir is not None:
381          dst_subpath = os.path.join(dst_dir, dst_subpath)
382      dst_path = os.path.join(dst, dst_subpath)
383      if os.path.isfile(src):
384          dst_dirpath = os.path.dirname(dst_path)
385          if not os.path.exists(dst_dirpath):
386              os.makedirs(dst_dirpath)
387          shutil.copy(src=src, dst=dst_path)
388      else:
389          shutil.copytree(src=src, dst=dst_path, ignore=shutil.ignore_patterns("__pycache__"))
390      return dst_subpath
391  
392  
393  def _get_local_project_dir_size(project_path):
394      """Internal function for reporting the size of a local project directory before copying to
395      destination for cli logging reporting to stdout.
396  
397      Args:
398          project_path: local path of the project directory
399  
400      Returns:
401          directory file sizes in KB, rounded to single decimal point for legibility
402      """
403  
404      total_size = 0
405      for root, _, files in os.walk(project_path):
406          for f in files:
407              path = os.path.join(root, f)
408              total_size += os.path.getsize(path)
409      return round(total_size / 1024.0, 1)
410  
411  
412  def _get_local_file_size(file):
413      """
414      Get the size of a local file in KB
415      """
416      return round(os.path.getsize(file) / 1024.0, 1)
417  
418  
def get_parent_dir(path):
    """Return the absolute path of the parent directory of ``path``."""
    return os.path.abspath(os.path.join(path, os.pardir))
421  
422  
def relative_path_to_artifact_path(path):
    """Convert a relative local filesystem path to a POSIX-style artifact path."""
    if os.path == posixpath:
        # Separators are already POSIX; nothing to translate.
        return path
    if os.path.abspath(path) == path:
        raise Exception("This method only works with relative paths.")
    return unquote(pathname2url(path))
429  
430  
def path_to_local_file_uri(path):
    """
    Convert local filesystem path to local file uri.
    """
    absolute = os.path.abspath(path)
    return pathlib.Path(absolute).as_uri()
436  
437  
def path_to_local_sqlite_uri(path):
    """
    Convert local filesystem path to sqlite uri.
    """
    url_path = posixpath.abspath(pathname2url(os.path.abspath(path)))
    # Windows URL-paths already start with a drive component, so they take one
    # fewer slash after the scheme.
    if sys.platform == "win32":
        return "sqlite://" + url_path
    return "sqlite:///" + url_path
445  
446  
def local_file_uri_to_path(uri):
    """
    Convert URI to local filesystem path.
    No-op if the uri does not have the expected scheme.
    """
    if not uri.startswith("file:"):
        return urllib.request.url2pathname(uri)
    parsed = urllib.parse.urlparse(uri)
    # Fix for retaining server name in UNC path.
    if is_windows() and parsed.netloc:
        return urllib.request.url2pathname(rf"\\{parsed.netloc}{parsed.path}")
    return urllib.request.url2pathname(parsed.path)
460  
461  
def get_local_path_or_none(path_or_uri):
    """Check if the argument is a local path (no scheme or file:///) and return local path if true,
    None otherwise.
    """
    parsed = urllib.parse.urlparse(path_or_uri)
    is_plain_path = len(parsed.scheme) == 0
    is_local_file_uri = parsed.scheme == "file" and len(parsed.netloc) == 0
    if is_plain_path or is_local_file_uri:
        return local_file_uri_to_path(path_or_uri)
    return None
471  
472  
def download_file_using_http_uri(http_uri, download_path, chunk_size=100000000, headers=None):
    """Download the file at ``http_uri`` to the local ``download_path``.

    Streams the response in ``chunk_size`` pieces so large files do not need
    to fit in memory.

    Note: this function is meant to download files using presigned urls from
    various cloud providers.
    """
    headers = headers or {}
    with cloud_storage_http_request("get", http_uri, stream=True, headers=headers) as response:
        augmented_raise_for_status(response)
        with open(download_path, "wb") as output_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                # An empty chunk signals the end of the stream.
                if not chunk:
                    break
                output_file.write(chunk)
490  
491  
@dataclass(frozen=True)
class _Chunk:
    """One contiguous byte range of a remote file, used for parallel downloads."""

    index: int  # ordinal position of the chunk within the file
    start: int  # first byte of the range (inclusive)
    end: int  # last byte of the range (inclusive)
    path: str  # remote file path this chunk belongs to
498  
499  
def _yield_chunks(path, file_size, chunk_size):
    """Yield ``_Chunk``s covering ``file_size`` bytes in ``chunk_size`` pieces.

    Byte ranges are inclusive on both ends (HTTP ``Range`` header semantics);
    the final chunk is truncated to the end of the file. Yields nothing for a
    zero-byte file.
    """
    # Exact integer ceiling division. The previous `int(math.ceil(a / float(b)))`
    # goes through float and can mis-round for sizes beyond 2**53 bytes.
    num_requests = -(-file_size // chunk_size)
    for i in range(num_requests):
        range_start = i * chunk_size
        range_end = min(range_start + chunk_size - 1, file_size - 1)
        yield _Chunk(i, range_start, range_end, path)
506  
507  
def parallelized_download_file_using_http_uri(
    thread_pool_executor,
    http_uri,
    download_path,
    remote_file_path,
    file_size,
    uri_type,
    chunk_size,
    env,
    headers=None,
):
    """
    Downloads a file specified using the `http_uri` to a local `download_path`. This function
    sends multiple requests in parallel each specifying its own desired byte range as a header,
    then reconstructs the file from the downloaded chunks. This allows for downloads of large files
    without OOM risk.

    Note : This function is meant to download files using presigned urls from various cloud
            providers.
    Returns a dict of chunk index : exception, if one was thrown for that index.
    """

    # Download one byte range in a separate subprocess; running in a subprocess
    # (rather than in-thread) lets the per-chunk timeout bound a hung transfer.
    def run_download(chunk: _Chunk):
        try:
            subprocess.run(
                [
                    sys.executable,
                    download_cloud_file_chunk.__file__,
                    "--range-start",
                    str(chunk.start),
                    "--range-end",
                    str(chunk.end),
                    "--headers",
                    json.dumps(headers or {}),
                    "--download-path",
                    download_path,
                    "--http-uri",
                    http_uri,
                ],
                text=True,
                check=True,
                capture_output=True,
                timeout=MLFLOW_DOWNLOAD_CHUNK_TIMEOUT.get(),
                env=env,
            )
        except (TimeoutExpired, CalledProcessError) as e:
            # Surface the worker's stdout/stderr so the failure is diagnosable.
            raise MlflowException(
                f"""
----- stdout -----
{e.stdout.strip()}

----- stderr -----
{e.stderr.strip()}
"""
            ) from e

    chunks = _yield_chunks(remote_file_path, file_size, chunk_size)
    # Create file if it doesn't exist or erase the contents if it does. We should do this here
    # before sending to the workers so they can each individually seek to their respective positions
    # and write chunks without overwriting.
    with open(download_path, "w"):
        pass
    if uri_type == ArtifactCredentialType.GCP_SIGNED_URL or uri_type is None:
        chunk = next(chunks)
        # GCP files could be transcoded, in which case the range header is ignored.
        # Test if this is the case by downloading one chunk and seeing if it's larger than the
        # requested size. If yes, let that be the file; if not, continue downloading more chunks.
        download_chunk(
            range_start=chunk.start,
            range_end=chunk.end,
            headers=headers,
            download_path=download_path,
            http_uri=http_uri,
        )
        downloaded_size = os.path.getsize(download_path)
        # If downloaded size was equal to the chunk size it would have been downloaded serially,
        # so we don't need to consider this here
        if downloaded_size > chunk_size:
            return {}

    # Fan the remaining chunks out to the executor; map each future back to its
    # chunk so failures can be reported (and later retried) per-chunk.
    futures = {thread_pool_executor.submit(run_download, chunk): chunk for chunk in chunks}
    failed_downloads = {}
    with ArtifactProgressBar.chunks(file_size, f"Downloading {download_path}", chunk_size) as pbar:
        for future in as_completed(futures):
            chunk = futures[future]
            try:
                future.result()
            except Exception as e:
                # Collect rather than raise: callers pass `failed_downloads`
                # to `download_chunk_retries` for a serial retry pass.
                _logger.debug(
                    f"Failed to download chunk {chunk.index} for {chunk.path}: {e}. "
                    f"The download of this chunk will be retried later."
                )
                failed_downloads[chunk] = future.exception()
            else:
                pbar.update()

    return failed_downloads
605  
606  
def download_chunk_retries(*, chunks, http_uri, headers, download_path):
    """Serially retry downloading chunks that failed during the parallel pass.

    Each chunk is attempted up to `_MLFLOW_MPD_NUM_RETRIES` times with
    `_MLFLOW_MPD_RETRY_INTERVAL_SECONDS` seconds between attempts; the final
    failure for a chunk is re-raised to the caller.

    Args:
        chunks: Iterable of `_Chunk`s that previously failed to download.
        http_uri: Presigned URI to download from.
        headers: HTTP headers to send with each range request.
        download_path: Local file the chunks are written into at their offsets.
    """
    num_retries = _MLFLOW_MPD_NUM_RETRIES.get()
    interval = _MLFLOW_MPD_RETRY_INTERVAL_SECONDS.get()
    for chunk in chunks:
        _logger.info(f"Retrying download of chunk {chunk.index} for {chunk.path}")
        for retry in range(num_retries):
            try:
                download_chunk(
                    range_start=chunk.start,
                    range_end=chunk.end,
                    headers=headers,
                    download_path=download_path,
                    http_uri=http_uri,
                )
                _logger.info(f"Successfully downloaded chunk {chunk.index} for {chunk.path}")
                break
            except Exception:
                if retry == num_retries - 1:
                    raise
            # Reached only after a failed attempt (success `break`s above):
            # wait before the next retry.
            time.sleep(interval)
627  
628  
def _handle_readonly_on_windows(func, path, exc_info):
    """
    This function should not be called directly but should be passed to `onerror` of
    `shutil.rmtree` in order to reattempt the removal of a read-only file after making
    it writable on Windows.

    References:
    - https://bugs.python.org/issue19643
    - https://bugs.python.org/issue43657
    """
    exc_type, exc_value = exc_info[:2]
    # Only reattempt when the deletion itself (unlink/rmdir) failed with
    # Windows "access denied" (winerror 5), which is what a read-only file
    # produces; any other failure is propagated unchanged.
    should_reattempt = (
        is_windows()
        and func in (os.unlink, os.rmdir)
        and issubclass(exc_type, PermissionError)
        and exc_value.winerror == 5
    )
    if not should_reattempt:
        raise exc_value
    # Clear the read-only attribute, then retry the failed operation once.
    os.chmod(path, stat.S_IWRITE)
    func(path)
650  
651  
def _get_tmp_dir():
    """Return a Databricks-specific temp-dir root, or None to use the default."""
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime

    if not is_in_databricks_runtime():
        return None

    try:
        return get_databricks_local_temp_dir()
    except Exception:
        pass

    repl_id = get_repl_id()
    if repl_id:
        return os.path.join("/tmp", "repl_tmp_data", repl_id)

    return None
665  
666  
def create_tmp_dir():
    """Create a fresh temp directory, preferring the Databricks-specific root."""
    base = _get_tmp_dir()
    if base:
        os.makedirs(base, exist_ok=True)
        return tempfile.mkdtemp(dir=base)
    return tempfile.mkdtemp()
673  
674  
@cache_return_value_per_process
def get_or_create_tmp_dir():
    """
    Get or create a temporary directory which will be removed once python process exit.

    Returns:
        Path of a per-process temp directory; repeated calls within the same
        process return the same path (memoized by the decorator).
    """
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime

    if is_in_databricks_runtime() and get_repl_id() is not None:
        # Note: For python process attached to databricks notebook, atexit does not work.
        # The directory returned by `get_databricks_local_tmp_dir`
        # will be removed once databricks notebook detaches.
        # The temp directory is designed to be used by all kinds of applications,
        # so create a child directory "mlflow" for storing mlflow temp data.
        try:
            repl_local_tmp_dir = get_databricks_local_temp_dir()
        except Exception:
            # Fall back to the conventional REPL temp location.
            repl_local_tmp_dir = os.path.join("/tmp", "repl_tmp_data", get_repl_id())

        tmp_dir = os.path.join(repl_local_tmp_dir, "mlflow")
        os.makedirs(tmp_dir, exist_ok=True)
    else:
        tmp_dir = tempfile.mkdtemp()
        # mkdtemp creates a directory with permission 0o700
        # For Spark UDFs, we need to make it accessible to other processes
        # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o777
        # This allows read/execute but not write for group and others
        os.chmod(tmp_dir, 0o750)
        atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)

    return tmp_dir
705  
706  
@cache_return_value_per_process
def get_or_create_nfs_tmp_dir():
    """
    Get or create a temporary NFS directory which will be removed once python process exit.

    Returns:
        Path of a per-process temp directory under the NFS cache root;
        repeated calls within the same process return the same path
        (memoized by the decorator).
    """
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime
    from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir

    nfs_root_dir = get_nfs_cache_root_dir()

    if is_in_databricks_runtime() and get_repl_id() is not None:
        # Note: In databricks, atexit hook does not work.
        # The directory returned by `get_databricks_nfs_tmp_dir`
        # will be removed once databricks notebook detaches.
        # The temp directory is designed to be used by all kinds of applications,
        # so create a child directory "mlflow" for storing mlflow temp data.
        try:
            repl_nfs_tmp_dir = get_databricks_nfs_temp_dir()
        except Exception:
            # Fall back to the conventional REPL temp location under the NFS root.
            repl_nfs_tmp_dir = os.path.join(nfs_root_dir, "repl_tmp_data", get_repl_id())

        tmp_nfs_dir = os.path.join(repl_nfs_tmp_dir, "mlflow")
        os.makedirs(tmp_nfs_dir, exist_ok=True)
    else:
        tmp_nfs_dir = tempfile.mkdtemp(dir=nfs_root_dir)
        # mkdtemp creates a directory with permission 0o700
        # For Spark UDFs, we need to make it accessible to other processes
        # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o777
        os.chmod(tmp_nfs_dir, 0o750)
        atexit.register(shutil.rmtree, tmp_nfs_dir, ignore_errors=True)

    return tmp_nfs_dir
739  
740  
def shutil_copytree_without_file_permissions(src_dir, dst_dir):
    """
    Copies the directory src_dir into dst_dir, without preserving filesystem permissions
    """
    for dirpath, dirnames, filenames in os.walk(src_dir):
        # Mirror the directory structure under dst_dir.
        for dirname in dirnames:
            rel_dir = os.path.relpath(os.path.join(dirpath, dirname), src_dir)
            new_dir = os.path.join(dst_dir, rel_dir)
            if not os.path.exists(new_dir):
                os.mkdir(new_dir)
        # Copy each file into its mirrored location.
        for filename in filenames:
            src_file = os.path.join(dirpath, filename)
            rel_file = os.path.relpath(src_file, src_dir)
            shutil.copy2(src_file, os.path.join(dst_dir, rel_file))
760  
761  
def contains_path_separator(path):
    """
    Returns True if a path contains a path separator, False otherwise.
    """
    separators = [sep for sep in (os.path.sep, os.path.altsep) if sep is not None]
    return any(sep in path for sep in separators)
767  
768  
def contains_percent(path):
    """
    Returns True if a path contains a percent character, False otherwise.
    """
    return path.find("%") != -1
774  
775  
def read_chunk(path: os.PathLike, size: int, start_byte: int = 0) -> bytes:
    """Read up to ``size`` bytes from ``path`` starting at ``start_byte``.

    Args:
        path: Path to the file.
        size: Maximum number of bytes to read.
        start_byte: Offset to start reading from.

    Returns:
        The bytes read (shorter than ``size`` near end of file).
    """
    with open(path, "rb") as stream:
        if start_byte > 0:
            stream.seek(start_byte)
        return stream.read(size)
792  
793  
@contextmanager
def remove_on_error(path: os.PathLike, onerror=None):
    """A context manager that removes a file or directory if an exception is raised during
    execution.

    Args:
        path: Path to the file or directory.
        onerror: A callback function that will be called with the captured exception before
            the file or directory is removed. For example, you can use this callback to
            log the exception.

    """
    try:
        yield
    except Exception as e:
        if onerror:
            onerror(e)
        # Best-effort cleanup of whatever `path` points at.
        if os.path.isfile(path):
            os.remove(path)
        elif os.path.isdir(path):
            shutil.rmtree(path)
        _logger.warning(
            f"Failed to remove {path}" if os.path.exists(path) else f"Successfully removed {path}"
        )
        raise
820  
821  
822  def get_total_file_size(path: str | pathlib.Path) -> int | None:
823      """Return the size of all files under given path, including files in subdirectories.
824  
825      Args:
826          path: The absolute path of a local directory.
827  
828      Returns:
829          size in bytes.
830  
831      """
832      try:
833          if isinstance(path, pathlib.Path):
834              path = str(path)
835          if not os.path.exists(path):
836              raise MlflowException(
837                  message=f"The given {path} does not exist.", error_code=INVALID_PARAMETER_VALUE
838              )
839          if not os.path.isdir(path):
840              raise MlflowException(
841                  message=f"The given {path} is not a directory.", error_code=INVALID_PARAMETER_VALUE
842              )
843  
844          total_size = 0
845          for cur_path, dirs, files in os.walk(path):
846              full_paths = [os.path.join(cur_path, file) for file in files]
847              total_size += sum(map(os.path.getsize, full_paths))
848          return total_size
849      except Exception as e:
850          _logger.info(f"Failed to get the total size of {path} because of error :{e}")
851          return None
852  
853  
def write_yaml(
    root: str,
    file_name: str,
    data: dict[str, Any],
    overwrite: bool = False,
    sort_keys: bool = True,
    ensure_yaml_extension: bool = True,
) -> None:
    """
    NEVER TOUCH THIS FUNCTION. KEPT FOR BACKWARD COMPATIBILITY with
    databricks-feature-engineering<=0.10.2

    Serialize ``data`` as YAML to ``<root>/<file_name>``.

    NOTE: ``overwrite`` and ``ensure_yaml_extension`` are accepted only for
    signature compatibility — they are never read by this implementation.
    The target file is always truncated/overwritten ("w" mode) and
    ``file_name`` is used verbatim, with no ``.yaml`` suffix appended.
    """
    # Imported lazily — presumably so this module can be imported without
    # PyYAML installed (see the find_spec("yaml") guard at the top of the file).
    import yaml

    with open(os.path.join(root, file_name), "w") as f:
        yaml.safe_dump(
            data,
            f,
            default_flow_style=False,
            allow_unicode=True,
            sort_keys=sort_keys,
        )
876  
877  
def read_yaml(root: str, file_name: str) -> dict[str, Any]:
    """
    NEVER TOUCH THIS FUNCTION. KEPT FOR BACKWARD COMPATIBILITY with
    databricks-feature-engineering<=0.10.2

    Deserialize and return the YAML contents of ``<root>/<file_name>``.

    NOTE: the annotated return type is ``dict``, but ``yaml.safe_load``
    returns whatever the document's top-level node is (possibly ``None``
    for an empty file) — callers relying on a dict should verify.
    """
    # Imported lazily — presumably so this module can be imported without
    # PyYAML installed (see the find_spec("yaml") guard at the top of the file).
    import yaml

    with open(os.path.join(root, file_name)) as f:
        return yaml.safe_load(f)
887  
888  
889  class ExclusiveFileLock:
890      """
891      Exclusive file lock (only works on Unix system)
892      """
893  
894      def __init__(self, path: str):
895          if os.name == "nt":
896              raise MlflowException("ExclusiveFileLock class does not support Windows system.")
897          self.path = path
898          self.fd = None
899  
900      def __enter__(self) -> None:
901          # Python on Windows does not have `fcntl` module, so importing it lazily.
902          import fcntl  # clint: disable=lazy-import
903  
904          # Open file (create if missing)
905          self.fd = open(self.path, "w")
906          # Acquire exclusive lock (blocking)
907          fcntl.flock(self.fd, fcntl.LOCK_EX)
908  
909      def __exit__(
910          self,
911          exc_type: type[BaseException] | None,
912          exc_val: BaseException | None,
913          exc_tb: TracebackType | None,
914      ):
915          # Python on Windows does not have `fcntl` module, so importing it lazily.
916          import fcntl  # clint: disable=lazy-import
917  
918          # Release lock
919          fcntl.flock(self.fd, fcntl.LOCK_UN)
920          self.fd.close()
921  
922  
def check_tarfile_security(archive_path: str) -> None:
    """
    Check the tar file content.
    If its members contain any of the following paths:
     * An absolute path.
     * A relative path that escapes the extraction directory.
     * A relative path that goes through a symlink.
    then raise an error.

    Args:
        archive_path: Path to the tar archive to validate.

    Raises:
        MlflowException: If any member violates one of the rules above, or if a
            hard link's target would resolve outside the extraction directory.
    """
    with tarfile.open(archive_path, "r") as tar:
        # Pass 1: validate each member's own path, and collect the destination
        # paths of every link member (both symlinks and hard links) so pass 2
        # can reject regular files whose paths traverse them.
        symlink_set = set()
        for m in tar.getmembers():
            # Normalize backslashes to forward slashes before path validation to prevent
            # bypass on Windows where backslashes are treated as directory separators.
            path = posixpath.normpath(m.name.replace("\\", "/"))
            _check_path_is_safe(path)
            if m.issym():
                symlink_set.add(path)
            elif m.islnk():
                symlink_set.add(path)
                # Hard link targets are dangerous: tar.extract creates an actual hard
                # link to the target path, so validate they don't escape.
                link_target = posixpath.normpath(m.linkname.replace("\\", "/"))
                _check_path_is_safe(link_target, context=f"hard link target of {path}")
                # The target is additionally resolved as if relative to the link's
                # own directory; reject anything resolving above the extraction root.
                link_parent = posixpath.dirname(path)
                resolved = posixpath.normpath(posixpath.join(link_parent, link_target))
                if resolved == ".." or resolved.startswith("../"):
                    raise MlflowException.invalid_parameter_value(
                        "Hard link target that escapes the extraction directory is not "
                        f"allowed, but got {path} -> {link_target}."
                    )
        # Pass 2: for every non-link member, check each ancestor prefix of its
        # path ("a", "a/b", ...) against the link paths collected in pass 1 —
        # extraction through a link could write outside the destination.
        for m in tar.getmembers():
            if not m.issym() and not m.islnk():
                path = posixpath.normpath(m.name.replace("\\", "/"))
                path_parts = path.split("/")
                for prefix_len in range(1, len(path_parts) + 1):
                    prefix_path = "/".join(path_parts[:prefix_len])
                    if prefix_path in symlink_set:
                        raise MlflowException.invalid_parameter_value(
                            "Destination path in the archive file can not go through a symlink, "
                            f"but got path {path}."
                        )
965  
966  
967  def _check_path_is_safe(path: str, context: str = "") -> None:
968      label = f" ({context})" if context else ""
969      # Reject Unix absolute paths
970      if path.startswith("/"):
971          raise MlflowException.invalid_parameter_value(
972              f"Absolute path destination in the archive file is not allowed, "
973              f"but got path {path}{label}."
974          )
975      # Reject Windows drive-letter absolute paths (e.g., C:/...) and UNC paths (//server/...)
976      if len(path) >= 3 and path[0].isalpha() and path[1:3] == ":/":
977          raise MlflowException.invalid_parameter_value(
978              f"Absolute path destination in the archive file is not allowed, "
979              f"but got path {path}{label}."
980          )
981      path_parts = path.split("/")
982      if path_parts[0] == "..":
983          raise MlflowException.invalid_parameter_value(
984              f"Escaped path destination in the archive file is not allowed, "
985              f"but got path {path}{label}."
986          )