file_utils.py
import atexit
import codecs
import errno
import fnmatch
import gzip
import importlib.util
import json
import logging
import math
import os
import pathlib
import posixpath
import shutil
import stat
import subprocess
import sys
import tarfile
import tempfile
import time
import urllib.parse
import urllib.request
from concurrent.futures import as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from subprocess import CalledProcessError, TimeoutExpired
from types import TracebackType
from typing import Any
from urllib.parse import unquote
from urllib.request import pathname2url

from mlflow.entities import FileInfo
from mlflow.environment_variables import (
    _MLFLOW_MPD_NUM_RETRIES,
    _MLFLOW_MPD_RETRY_INTERVAL_SECONDS,
    MLFLOW_DOWNLOAD_CHUNK_TIMEOUT,
    MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR,
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import ArtifactCredentialType
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils import download_cloud_file_chunk
from mlflow.utils.databricks_utils import (
    get_databricks_local_temp_dir,
    get_databricks_nfs_temp_dir,
)
from mlflow.utils.os import is_windows
from mlflow.utils.process import cache_return_value_per_process
from mlflow.utils.request_utils import cloud_storage_http_request, download_chunk
from mlflow.utils.rest_utils import augmented_raise_for_status

ENCODING = "utf-8"
_PROGRESS_BAR_DISPLAY_THRESHOLD = 500_000_000  # 500 MB

_logger = logging.getLogger(__name__)

# This is for backward compatibility with databricks-feature-engineering<=0.10.2
if importlib.util.find_spec("yaml") is not None:
    try:
        from yaml import CSafeDumper as YamlSafeDumper
    except ImportError:
        from yaml import SafeDumper as YamlSafeDumper  # noqa: F401


class ArtifactProgressBar:
    """Thin wrapper around an optional tqdm progress bar for artifact transfers.

    The bar is only materialized (``set_pbar``) when the
    ``MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR`` env var is enabled and tqdm is
    importable; otherwise all operations are no-ops.
    """

    def __init__(self, desc, total, step, **kwargs) -> None:
        self.desc = desc
        self.total = total
        self.step = step
        self.pbar = None
        self.progress = 0
        self.kwargs = kwargs

    def set_pbar(self):
        if MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR.get():
            try:
                from tqdm.auto import tqdm

                self.pbar = tqdm(total=self.total, desc=self.desc, **self.kwargs)
            except ImportError:
                # tqdm is an optional dependency; silently degrade to no bar.
                pass

    @classmethod
    def chunks(cls, file_size, desc, chunk_size):
        """Create a byte-based bar for chunked downloads of a single file.

        The bar is only displayed for files at least
        ``_PROGRESS_BAR_DISPLAY_THRESHOLD`` bytes large.
        """
        bar = cls(
            desc,
            total=file_size,
            step=chunk_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
        )
        if file_size >= _PROGRESS_BAR_DISPLAY_THRESHOLD:
            bar.set_pbar()
        return bar

    @classmethod
    def files(cls, desc, total):
        """Create a bar that counts whole files (step of 1)."""
        bar = cls(desc, total=total, step=1)
        bar.set_pbar()
        return bar

    def update(self):
        if self.pbar:
            # Clamp the final step so the bar never overshoots `total`.
            update_step = min(self.total - self.progress, self.step)
            self.pbar.update(update_step)
            self.pbar.refresh()
            self.progress += update_step

    def __enter__(self):
        return self

    def __exit__(self, *args):
        if self.pbar:
            self.pbar.close()


def is_directory(name):
    return os.path.isdir(name)


def is_file(name):
    return os.path.isfile(name)


def exists(name):
    return os.path.exists(name)


def list_all(root, filter_func=lambda x: True, full_path=False):
    """List all entities directly under 'root' that satisfy 'filter_func'

    Args:
        root: Name of directory to start search.
        filter_func: function or lambda that takes path.
        full_path: If True will return results as full path including `root`.

    Returns:
        list of all files or directories that satisfy the criteria.

    """
    if not is_directory(root):
        raise Exception(f"Invalid parent directory '{root}'")
    matches = [x for x in os.listdir(root) if filter_func(os.path.join(root, x))]
    return [os.path.join(root, m) for m in matches] if full_path else matches


def list_subdirs(dir_name, full_path=False):
    """
    Equivalent to UNIX command:
    ``find $dir_name -depth 1 -type d``

    Args:
        dir_name: Name of directory to start search.
        full_path: If True will return results as full path including `root`.

    Returns:
        list of all directories directly under 'dir_name'.
    """
    return list_all(dir_name, os.path.isdir, full_path)


def list_files(dir_name, full_path=False):
    """
    Equivalent to UNIX command:
    ``find $dir_name -depth 1 -type f``

    Args:
        dir_name: Name of directory to start search.
        full_path: If True will return results as full path including `root`.

    Returns:
        list of all files directly under 'dir_name'.
    """
    return list_all(dir_name, os.path.isfile, full_path)


def find(root, name, full_path=False):
    """Search for a file in a root directory. Equivalent to:
    ``find $root -name "$name" -depth 1``

    Args:
        root: Name of root directory for find.
        name: Name of file or directory to find directly under root directory.
        full_path: If True will return results as full path including `root`.

    Returns:
        list of matching files or directories.
    """
    path_name = os.path.join(root, name)
    return list_all(root, lambda x: x == path_name, full_path)


def mkdir(root, name=None):
    """Make directory with name "root/name", or just "root" if name is None.

    Args:
        root: Name of parent directory.
        name: Optional name of leaf directory.

    Returns:
        Path to created directory.
    """
    target = os.path.join(root, name) if name is not None else root
    try:
        os.makedirs(target, exist_ok=True)
    except OSError as e:
        # Re-raise unless the failure is "already exists and is a directory".
        if e.errno != errno.EEXIST or not os.path.isdir(target):
            raise e
    return target


def make_containing_dirs(path):
    """
    Create the base directory for a given file path if it does not exist; also creates parent
    directories.
    """
    dir_name = os.path.dirname(path)
    # Guard against a bare filename (empty dirname) and use `exist_ok=True` to
    # avoid the exists/makedirs race when multiple processes create the same dir.
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)


class TempDir:
    """Context manager that creates (and optionally chdirs into) a temp directory.

    Args:
        chdr: If True, change the working directory into the temp dir for the
            duration of the context and restore it on exit.
        remove_on_exit: If True, delete the temp dir on exit.
    """

    def __init__(self, chdr=False, remove_on_exit=True):
        self._dir = None
        self._path = None
        self._chdr = chdr
        self._remove = remove_on_exit

    def __enter__(self):
        self._path = os.path.abspath(create_tmp_dir())
        assert os.path.exists(self._path)
        if self._chdr:
            self._dir = os.path.abspath(os.getcwd())
            os.chdir(self._path)
        return self

    def __exit__(self, tp, val, traceback):
        if self._chdr and self._dir:
            os.chdir(self._dir)
            self._dir = None
        if self._remove and os.path.exists(self._path):
            shutil.rmtree(self._path)

        assert not self._remove or not os.path.exists(self._path)

    def path(self, *path):
        # When chdir'd into the temp dir, relative paths already resolve there.
        return os.path.join("./", *path) if self._chdr else os.path.join(self._path, *path)


def read_file_lines(parent_path, file_name):
    """Return the contents of the file as an array where each element is a separate line.

    Args:
        parent_path: Full path to the directory that contains the file.
        file_name: Leaf file name.

    Returns:
        All lines in the file as an array.

    """
    file_path = os.path.join(parent_path, file_name)
    with codecs.open(file_path, mode="r", encoding=ENCODING) as f:
        return f.readlines()


def read_file(parent_path, file_name):
    """Return the contents of the file.

    Args:
        parent_path: Full path to the directory that contains the file.
        file_name: Leaf file name.

    Returns:
        The contents of the file.

    """
    file_path = os.path.join(parent_path, file_name)
    with codecs.open(file_path, mode="r", encoding=ENCODING) as f:
        return f.read()


def get_file_info(path, rel_path):
    """Returns file meta data : location, size, ... etc

    Args:
        path: Path to artifact.
        rel_path: Relative path.

    Returns:
        `FileInfo` object
    """
    if is_directory(path):
        return FileInfo(rel_path, True, None)
    else:
        return FileInfo(rel_path, False, os.path.getsize(path))


def mv(target, new_parent):
    shutil.move(target, new_parent)


def write_to(filename, data):
    with codecs.open(filename, mode="w", encoding=ENCODING) as handle:
        handle.write(data)


def append_to(filename, data):
    with open(filename, "a") as handle:
        handle.write(data)


def make_tarfile(output_filename, source_dir, archive_name, custom_filter=None):
    """Create a reproducible gzipped tar of `source_dir` at `output_filename`.

    Modification timestamps are zeroed out (both in the tar members and the
    gzip header) so that the same content always produces the same archive.
    """

    # Helper for filtering out modification timestamps
    def _filter_timestamps(tar_info):
        tar_info.mtime = 0
        return tar_info if custom_filter is None else custom_filter(tar_info)

    unzipped_file_handle, unzipped_filename = tempfile.mkstemp()
    try:
        with tarfile.open(unzipped_filename, "w") as tar:
            tar.add(source_dir, arcname=archive_name, filter=_filter_timestamps)
        # When gzipping the tar, don't include the tar's filename or modification time in the
        # zipped archive (see https://docs.python.org/3/library/gzip.html#gzip.GzipFile)
        with (
            gzip.GzipFile(
                filename="", fileobj=open(output_filename, "wb"), mode="wb", mtime=0
            ) as gzipped_tar,
            open(unzipped_filename, "rb") as tar,
        ):
            gzipped_tar.write(tar.read())
    finally:
        os.close(unzipped_file_handle)


def _copy_project(src_path, dst_path=""):
    """Internal function used to copy MLflow project during development.

    Copies the content of the whole directory tree except patterns defined in .dockerignore.
    The MLflow is assumed to be accessible as a local directory in this case.

    Args:
        src_path: Path to the original MLflow project
        dst_path: MLflow will be copied here

    Returns:
        Name of the MLflow project directory.
    """

    def _docker_ignore(mlflow_root):
        docker_ignore = os.path.join(mlflow_root, ".dockerignore")
        patterns = []
        if os.path.exists(docker_ignore):
            with open(docker_ignore) as f:
                patterns = [x.strip() for x in f]

        def ignore(_, names):
            res = set()
            for p in patterns:
                res.update(set(fnmatch.filter(names, p)))
            return list(res)

        return ignore if patterns else None

    mlflow_dir = "mlflow-project"
    # check if we have project root
    assert os.path.isfile(os.path.join(src_path, "pyproject.toml")), "file not found " + str(
        os.path.abspath(os.path.join(src_path, "pyproject.toml"))
    )
    shutil.copytree(src_path, os.path.join(dst_path, mlflow_dir), ignore=_docker_ignore(src_path))
    return mlflow_dir


def _copy_file_or_tree(src, dst, dst_dir=None):
    """
    Returns:
        The path to the copied artifacts, relative to `dst`.
    """
    dst_subpath = os.path.basename(os.path.abspath(src))
    if dst_dir is not None:
        dst_subpath = os.path.join(dst_dir, dst_subpath)
    dst_path = os.path.join(dst, dst_subpath)
    if os.path.isfile(src):
        dst_dirpath = os.path.dirname(dst_path)
        if not os.path.exists(dst_dirpath):
            os.makedirs(dst_dirpath)
        shutil.copy(src=src, dst=dst_path)
    else:
        shutil.copytree(src=src, dst=dst_path, ignore=shutil.ignore_patterns("__pycache__"))
    return dst_subpath


def _get_local_project_dir_size(project_path):
    """Internal function for reporting the size of a local project directory before copying to
    destination for cli logging reporting to stdout.

    Args:
        project_path: local path of the project directory

    Returns:
        directory file sizes in KB, rounded to single decimal point for legibility
    """

    total_size = 0
    for root, _, files in os.walk(project_path):
        for f in files:
            path = os.path.join(root, f)
            total_size += os.path.getsize(path)
    return round(total_size / 1024.0, 1)


def _get_local_file_size(file):
    """
    Get the size of a local file in KB
    """
    return round(os.path.getsize(file) / 1024.0, 1)


def get_parent_dir(path):
    return os.path.abspath(os.path.join(path, os.pardir))


def relative_path_to_artifact_path(path):
    if os.path == posixpath:
        return path
    if os.path.abspath(path) == path:
        raise Exception("This method only works with relative paths.")
    return unquote(pathname2url(path))


def path_to_local_file_uri(path):
    """
    Convert local filesystem path to local file uri.
    """
    return pathlib.Path(os.path.abspath(path)).as_uri()


def path_to_local_sqlite_uri(path):
    """
    Convert local filesystem path to sqlite uri.
    """
    path = posixpath.abspath(pathname2url(os.path.abspath(path)))
    prefix = "sqlite://" if sys.platform == "win32" else "sqlite:///"
    return prefix + path


def local_file_uri_to_path(uri):
    """
    Convert URI to local filesystem path.
    No-op if the uri does not have the expected scheme.
    """
    path = uri
    if uri.startswith("file:"):
        parsed_path = urllib.parse.urlparse(uri)
        path = parsed_path.path
        # Fix for retaining server name in UNC path.
        if is_windows() and parsed_path.netloc:
            return urllib.request.url2pathname(rf"\\{parsed_path.netloc}{path}")

    return urllib.request.url2pathname(path)


def get_local_path_or_none(path_or_uri):
    """Check if the argument is a local path (no scheme or file:///) and return local path if true,
    None otherwise.
    """
    parsed_uri = urllib.parse.urlparse(path_or_uri)
    # Note precedence: local iff no scheme at all, OR "file" scheme with no netloc.
    if len(parsed_uri.scheme) == 0 or parsed_uri.scheme == "file" and len(parsed_uri.netloc) == 0:
        return local_file_uri_to_path(path_or_uri)
    else:
        return None


def download_file_using_http_uri(http_uri, download_path, chunk_size=100000000, headers=None):
    """
    Downloads a file specified using the `http_uri` to a local `download_path`. This function
    uses a `chunk_size` to ensure an OOM error is not raised a large file is downloaded.

    Note : This function is meant to download files using presigned urls from various cloud
    providers.
    """
    if headers is None:
        headers = {}
    with cloud_storage_http_request("get", http_uri, stream=True, headers=headers) as response:
        augmented_raise_for_status(response)
        with open(download_path, "wb") as output_file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if not chunk:
                    break
                output_file.write(chunk)


@dataclass(frozen=True)
class _Chunk:
    # A byte range [start, end] of the file at `path`, with its ordinal `index`.
    index: int
    start: int
    end: int
    path: str


def _yield_chunks(path, file_size, chunk_size):
    """Yield `_Chunk`s covering `file_size` bytes in `chunk_size` byte ranges."""
    num_requests = int(math.ceil(file_size / float(chunk_size)))
    for i in range(num_requests):
        range_start = i * chunk_size
        range_end = min(range_start + chunk_size - 1, file_size - 1)
        yield _Chunk(i, range_start, range_end, path)


def parallelized_download_file_using_http_uri(
    thread_pool_executor,
    http_uri,
    download_path,
    remote_file_path,
    file_size,
    uri_type,
    chunk_size,
    env,
    headers=None,
):
    """
    Downloads a file specified using the `http_uri` to a local `download_path`. This function
    sends multiple requests in parallel each specifying its own desired byte range as a header,
    then reconstructs the file from the downloaded chunks. This allows for downloads of large files
    without OOM risk.

    Note : This function is meant to download files using presigned urls from various cloud
    providers.
    Returns a dict of chunk index : exception, if one was thrown for that index.
    """

    def run_download(chunk: _Chunk):
        # Each chunk is fetched in a subprocess so a hung transfer can be
        # bounded by MLFLOW_DOWNLOAD_CHUNK_TIMEOUT.
        try:
            subprocess.run(
                [
                    sys.executable,
                    download_cloud_file_chunk.__file__,
                    "--range-start",
                    str(chunk.start),
                    "--range-end",
                    str(chunk.end),
                    "--headers",
                    json.dumps(headers or {}),
                    "--download-path",
                    download_path,
                    "--http-uri",
                    http_uri,
                ],
                text=True,
                check=True,
                capture_output=True,
                timeout=MLFLOW_DOWNLOAD_CHUNK_TIMEOUT.get(),
                env=env,
            )
        except (TimeoutExpired, CalledProcessError) as e:
            raise MlflowException(
                f"""
----- stdout -----
{e.stdout.strip()}

----- stderr -----
{e.stderr.strip()}
"""
            ) from e

    chunks = _yield_chunks(remote_file_path, file_size, chunk_size)
    # Create file if it doesn't exist or erase the contents if it does. We should do this here
    # before sending to the workers so they can each individually seek to their respective positions
    # and write chunks without overwriting.
    with open(download_path, "w"):
        pass
    if uri_type == ArtifactCredentialType.GCP_SIGNED_URL or uri_type is None:
        chunk = next(chunks)
        # GCP files could be transcoded, in which case the range header is ignored.
        # Test if this is the case by downloading one chunk and seeing if it's larger than the
        # requested size. If yes, let that be the file; if not, continue downloading more chunks.
        download_chunk(
            range_start=chunk.start,
            range_end=chunk.end,
            headers=headers,
            download_path=download_path,
            http_uri=http_uri,
        )
        downloaded_size = os.path.getsize(download_path)
        # If downloaded size was equal to the chunk size it would have been downloaded serially,
        # so we don't need to consider this here
        if downloaded_size > chunk_size:
            return {}

    futures = {thread_pool_executor.submit(run_download, chunk): chunk for chunk in chunks}
    failed_downloads = {}
    with ArtifactProgressBar.chunks(file_size, f"Downloading {download_path}", chunk_size) as pbar:
        for future in as_completed(futures):
            chunk = futures[future]
            try:
                future.result()
            except Exception as e:
                _logger.debug(
                    f"Failed to download chunk {chunk.index} for {chunk.path}: {e}. "
                    f"The download of this chunk will be retried later."
                )
                failed_downloads[chunk] = future.exception()
            else:
                pbar.update()

    return failed_downloads


def download_chunk_retries(*, chunks, http_uri, headers, download_path):
    """Retry in-process downloads of chunks that failed in the parallel pass.

    Each chunk is attempted up to `_MLFLOW_MPD_NUM_RETRIES` times with
    `_MLFLOW_MPD_RETRY_INTERVAL_SECONDS` between attempts; the last failure
    is re-raised.
    """
    num_retries = _MLFLOW_MPD_NUM_RETRIES.get()
    interval = _MLFLOW_MPD_RETRY_INTERVAL_SECONDS.get()
    for chunk in chunks:
        _logger.info(f"Retrying download of chunk {chunk.index} for {chunk.path}")
        for retry in range(num_retries):
            try:
                download_chunk(
                    range_start=chunk.start,
                    range_end=chunk.end,
                    headers=headers,
                    download_path=download_path,
                    http_uri=http_uri,
                )
                _logger.info(f"Successfully downloaded chunk {chunk.index} for {chunk.path}")
                break
            except Exception:
                if retry == num_retries - 1:
                    raise
                time.sleep(interval)


def _handle_readonly_on_windows(func, path, exc_info):
    """
    This function should not be called directly but should be passed to `onerror` of
    `shutil.rmtree` in order to reattempt the removal of a read-only file after making
    it writable on Windows.

    References:
    - https://bugs.python.org/issue19643
    - https://bugs.python.org/issue43657
    """
    exc_type, exc_value = exc_info[:2]
    should_reattempt = (
        is_windows()
        and func in (os.unlink, os.rmdir)
        and issubclass(exc_type, PermissionError)
        and exc_value.winerror == 5
    )
    if not should_reattempt:
        raise exc_value
    os.chmod(path, stat.S_IWRITE)
    func(path)


def _get_tmp_dir():
    """Return a Databricks-specific temp dir root, or None outside Databricks."""
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime

    if is_in_databricks_runtime():
        try:
            return get_databricks_local_temp_dir()
        except Exception:
            # Fall back to a REPL-scoped /tmp location below.
            pass

        if repl_id := get_repl_id():
            return os.path.join("/tmp", "repl_tmp_data", repl_id)

    return None


def create_tmp_dir():
    """Create a fresh temp directory, preferring the Databricks-aware root."""
    if directory := _get_tmp_dir():
        os.makedirs(directory, exist_ok=True)
        return tempfile.mkdtemp(dir=directory)

    return tempfile.mkdtemp()


@cache_return_value_per_process
def get_or_create_tmp_dir():
    """
    Get or create a temporary directory which will be removed once python process exit.
    """
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime

    if is_in_databricks_runtime() and get_repl_id() is not None:
        # Note: For python process attached to databricks notebook, atexit does not work.
        # The directory returned by `get_databricks_local_tmp_dir`
        # will be removed once databricks notebook detaches.
        # The temp directory is designed to be used by all kinds of applications,
        # so create a child directory "mlflow" for storing mlflow temp data.
        try:
            repl_local_tmp_dir = get_databricks_local_temp_dir()
        except Exception:
            repl_local_tmp_dir = os.path.join("/tmp", "repl_tmp_data", get_repl_id())

        tmp_dir = os.path.join(repl_local_tmp_dir, "mlflow")
        os.makedirs(tmp_dir, exist_ok=True)
    else:
        tmp_dir = tempfile.mkdtemp()
        # mkdtemp creates a directory with permission 0o700
        # For Spark UDFs, we need to make it accessible to other processes
        # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o777
        # This allows read/execute but not write for group and others
        os.chmod(tmp_dir, 0o750)
        atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)

    return tmp_dir


@cache_return_value_per_process
def get_or_create_nfs_tmp_dir():
    """
    Get or create a temporary NFS directory which will be removed once python process exit.
    """
    from mlflow.utils.databricks_utils import get_repl_id, is_in_databricks_runtime
    from mlflow.utils.nfs_on_spark import get_nfs_cache_root_dir

    nfs_root_dir = get_nfs_cache_root_dir()

    if is_in_databricks_runtime() and get_repl_id() is not None:
        # Note: In databricks, atexit hook does not work.
        # The directory returned by `get_databricks_nfs_tmp_dir`
        # will be removed once databricks notebook detaches.
        # The temp directory is designed to be used by all kinds of applications,
        # so create a child directory "mlflow" for storing mlflow temp data.
        try:
            repl_nfs_tmp_dir = get_databricks_nfs_temp_dir()
        except Exception:
            repl_nfs_tmp_dir = os.path.join(nfs_root_dir, "repl_tmp_data", get_repl_id())

        tmp_nfs_dir = os.path.join(repl_nfs_tmp_dir, "mlflow")
        os.makedirs(tmp_nfs_dir, exist_ok=True)
    else:
        tmp_nfs_dir = tempfile.mkdtemp(dir=nfs_root_dir)
        # mkdtemp creates a directory with permission 0o700
        # For Spark UDFs, we need to make it accessible to other processes
        # Use 0o750 (owner: rwx, group: r-x, others: None) instead of 0o777
        os.chmod(tmp_nfs_dir, 0o750)
        atexit.register(shutil.rmtree, tmp_nfs_dir, ignore_errors=True)

    return tmp_nfs_dir


def shutil_copytree_without_file_permissions(src_dir, dst_dir):
    """
    Copies the directory src_dir into dst_dir, without preserving filesystem permissions
    """
    for dirpath, dirnames, filenames in os.walk(src_dir):
        for dirname in dirnames:
            relative_dir_path = os.path.relpath(os.path.join(dirpath, dirname), src_dir)
            # For each directory <dirname> immediately under <dirpath>, create an equivalently-named
            # directory under the destination directory
            abs_dir_path = os.path.join(dst_dir, relative_dir_path)
            if not os.path.exists(abs_dir_path):
                os.mkdir(abs_dir_path)
        for filename in filenames:
            # For each file with name <filename> immediately under <dirpath>, copy that file to
            # the appropriate location in the destination directory
            file_path = os.path.join(dirpath, filename)
            relative_file_path = os.path.relpath(file_path, src_dir)
            abs_file_path = os.path.join(dst_dir, relative_file_path)
            shutil.copy2(file_path, abs_file_path)


def contains_path_separator(path):
    """
    Returns True if a path contains a path separator, False otherwise.
    """
    return any((sep in path) for sep in (os.path.sep, os.path.altsep) if sep is not None)


def contains_percent(path):
    """
    Returns True if a path contains a percent character, False otherwise.
    """
    return "%" in path


def read_chunk(path: os.PathLike, size: int, start_byte: int = 0) -> bytes:
    """Read a chunk of bytes from a file.

    Args:
        path: Path to the file.
        size: The size of the chunk.
        start_byte: The start byte of the chunk.

    Returns:
        The chunk of bytes.

    """
    with open(path, "rb") as f:
        if start_byte > 0:
            f.seek(start_byte)
        return f.read(size)


@contextmanager
def remove_on_error(path: os.PathLike, onerror=None):
    """A context manager that removes a file or directory if an exception is raised during
    execution.

    Args:
        path: Path to the file or directory.
        onerror: A callback function that will be called with the captured exception before
            the file or directory is removed. For example, you can use this callback to
            log the exception.

    """
    try:
        yield
    except Exception as e:
        if onerror:
            onerror(e)
        if os.path.exists(path):
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        _logger.warning(
            f"Failed to remove {path}" if os.path.exists(path) else f"Successfully removed {path}"
        )
        raise


def get_total_file_size(path: str | pathlib.Path) -> int | None:
    """Return the size of all files under given path, including files in subdirectories.

    Args:
        path: The absolute path of a local directory.

    Returns:
        size in bytes.

    """
    try:
        if isinstance(path, pathlib.Path):
            path = str(path)
        if not os.path.exists(path):
            raise MlflowException(
                message=f"The given {path} does not exist.", error_code=INVALID_PARAMETER_VALUE
            )
        if not os.path.isdir(path):
            raise MlflowException(
                message=f"The given {path} is not a directory.", error_code=INVALID_PARAMETER_VALUE
            )

        total_size = 0
        for cur_path, dirs, files in os.walk(path):
            full_paths = [os.path.join(cur_path, file) for file in files]
            total_size += sum(map(os.path.getsize, full_paths))
        return total_size
    except Exception as e:
        # Best-effort helper: size reporting must never fail the caller.
        _logger.info(f"Failed to get the total size of {path} because of error :{e}")
        return None


def write_yaml(
    root: str,
    file_name: str,
    data: dict[str, Any],
    overwrite: bool = False,
    sort_keys: bool = True,
    ensure_yaml_extension: bool = True,
) -> None:
    """
    NEVER TOUCH THIS FUNCTION. KEPT FOR BACKWARD COMPATIBILITY with
    databricks-feature-engineering<=0.10.2
    """
    import yaml

    with open(os.path.join(root, file_name), "w") as f:
        yaml.safe_dump(
            data,
            f,
            default_flow_style=False,
            allow_unicode=True,
            sort_keys=sort_keys,
        )


def read_yaml(root: str, file_name: str) -> dict[str, Any]:
    """
    NEVER TOUCH THIS FUNCTION. KEPT FOR BACKWARD COMPATIBILITY with
    databricks-feature-engineering<=0.10.2
    """
    import yaml

    with open(os.path.join(root, file_name)) as f:
        return yaml.safe_load(f)


class ExclusiveFileLock:
    """
    Exclusive file lock (only works on Unix system)
    """

    def __init__(self, path: str):
        if os.name == "nt":
            raise MlflowException("ExclusiveFileLock class does not support Windows system.")
        self.path = path
        self.fd = None

    def __enter__(self) -> None:
        # Python on Windows does not have `fcntl` module, so importing it lazily.
        import fcntl  # clint: disable=lazy-import

        # Open file (create if missing)
        self.fd = open(self.path, "w")
        # Acquire exclusive lock (blocking)
        fcntl.flock(self.fd, fcntl.LOCK_EX)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ):
        # Python on Windows does not have `fcntl` module, so importing it lazily.
        import fcntl  # clint: disable=lazy-import

        # Release lock
        fcntl.flock(self.fd, fcntl.LOCK_UN)
        self.fd.close()


def check_tarfile_security(archive_path: str) -> None:
    """
    Check the tar file content.
    If its members contain any of the following paths:
    * An absolute path.
    * A relative path that escapes the extraction directory.
    * A relative path that goes through a symlink.
    then raise an error.
    """
    with tarfile.open(archive_path, "r") as tar:
        symlink_set = set()
        for m in tar.getmembers():
            # Normalize backslashes to forward slashes before path validation to prevent
            # bypass on Windows where backslashes are treated as directory separators.
            path = posixpath.normpath(m.name.replace("\\", "/"))
            _check_path_is_safe(path)
            if m.issym():
                symlink_set.add(path)
            elif m.islnk():
                symlink_set.add(path)
                # Hard link targets are dangerous: tar.extract creates an actual hard
                # link to the target path, so validate they don't escape.
                link_target = posixpath.normpath(m.linkname.replace("\\", "/"))
                _check_path_is_safe(link_target, context=f"hard link target of {path}")
                link_parent = posixpath.dirname(path)
                resolved = posixpath.normpath(posixpath.join(link_parent, link_target))
                if resolved == ".." or resolved.startswith("../"):
                    raise MlflowException.invalid_parameter_value(
                        "Hard link target that escapes the extraction directory is not "
                        f"allowed, but got {path} -> {link_target}."
                    )
        for m in tar.getmembers():
            if not m.issym() and not m.islnk():
                path = posixpath.normpath(m.name.replace("\\", "/"))
                path_parts = path.split("/")
                # Reject any member whose path prefix is a previously-seen (sym|hard)link.
                for prefix_len in range(1, len(path_parts) + 1):
                    prefix_path = "/".join(path_parts[:prefix_len])
                    if prefix_path in symlink_set:
                        raise MlflowException.invalid_parameter_value(
                            "Destination path in the archive file can not go through a symlink, "
                            f"but got path {path}."
                        )


def _check_path_is_safe(path: str, context: str = "") -> None:
    """Raise if `path` (already slash-normalized) is absolute or escapes upward."""
    label = f" ({context})" if context else ""
    # Reject Unix absolute paths
    if path.startswith("/"):
        raise MlflowException.invalid_parameter_value(
            f"Absolute path destination in the archive file is not allowed, "
            f"but got path {path}{label}."
        )
    # Reject Windows drive-letter absolute paths (e.g., C:/...) and UNC paths (//server/...)
    if len(path) >= 3 and path[0].isalpha() and path[1:3] == ":/":
        raise MlflowException.invalid_parameter_value(
            f"Absolute path destination in the archive file is not allowed, "
            f"but got path {path}{label}."
        )
    path_parts = path.split("/")
    if path_parts[0] == "..":
        raise MlflowException.invalid_parameter_value(
            f"Escaped path destination in the archive file is not allowed, "
            f"but got path {path}{label}."
        )