# test_file_utils.py
"""Tests for ``mlflow.utils.file_utils`` and the archive-extraction safety helpers."""

import filecmp
import hashlib
import io
import os
import shutil
import stat
import tarfile
from pathlib import Path

import pytest
from pyspark.sql import SparkSession

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.pyfunc.dbconnect_artifact_cache import extract_archive_to_dir
from mlflow.utils import file_utils
from mlflow.utils.file_utils import (
    TempDir,
    _copy_file_or_tree,
    _handle_readonly_on_windows,
    check_tarfile_security,
    get_parent_dir,
    get_total_file_size,
    local_file_uri_to_path,
)
from mlflow.utils.os import is_windows

from tests.helper_functions import random_int
from tests.projects.utils import TEST_PROJECT_DIR


@pytest.fixture(scope="module")
def spark_session():
    """Yield a module-scoped local Spark session, closed when the module finishes."""
    with SparkSession.builder.master("local[*]").getOrCreate() as session:
        yield session


def test_mkdir(tmp_path):
    """``mkdir`` creates the directory, tolerates an existing one, and raises on bad targets."""
    temp_dir = str(tmp_path)
    new_dir_name = f"mkdir_test_{random_int()}"
    file_utils.mkdir(temp_dir, new_dir_name)
    assert os.listdir(temp_dir) == [new_dir_name]

    with pytest.raises(OSError, match="bad directory"):
        file_utils.mkdir("/ bad directory @ name ", "ouch")

    # does not raise if directory exists already
    file_utils.mkdir(temp_dir, new_dir_name)

    # raises if it exists already but is a file
    dummy_file_path = str(tmp_path.joinpath("dummy_file"))
    with open(dummy_file_path, "a"):
        pass

    with pytest.raises(OSError, match="exists"):
        file_utils.mkdir(dummy_file_path)


def test_make_tarfile(tmp_path):
    """Archives of a project and of its copy are byte-identical and extract to the original."""
    # Tar a local project
    tarfile0 = str(tmp_path.joinpath("first-tarfile"))
    file_utils.make_tarfile(
        output_filename=tarfile0, source_dir=TEST_PROJECT_DIR, archive_name="some-archive"
    )
    # Copy local project into a temp dir
    dst_dir = str(tmp_path.joinpath("project-directory"))
    shutil.copytree(TEST_PROJECT_DIR, dst_dir)
    # Tar the copied project
    tarfile1 = str(tmp_path.joinpath("second-tarfile"))
    file_utils.make_tarfile(
        output_filename=tarfile1, source_dir=dst_dir, archive_name="some-archive"
    )
    # Compare the archives & explicitly verify their SHA256 hashes match (i.e. that
    # changes in file modification timestamps don't affect the archive contents)
    assert filecmp.cmp(tarfile0, tarfile1, shallow=False)
    with open(tarfile0, "rb") as first_tar, open(tarfile1, "rb") as second_tar:
        assert (
            hashlib.sha256(first_tar.read()).hexdigest()
            == hashlib.sha256(second_tar.read()).hexdigest()
        )
    # Extract the TAR and check that its contents match the original directory
    extract_dir = str(tmp_path.joinpath("extracted-tar"))
    os.makedirs(extract_dir)
    with tarfile.open(tarfile0, "r:gz") as handle:
        handle.extractall(path=extract_dir)
    dir_comparison = filecmp.dircmp(os.path.join(extract_dir, "some-archive"), TEST_PROJECT_DIR)
    assert len(dir_comparison.left_only) == 0
    assert len(dir_comparison.right_only) == 0
    assert len(dir_comparison.diff_files) == 0
    assert len(dir_comparison.funny_files) == 0


def test_get_parent_dir(tmp_path):
    """``get_parent_dir`` of a child directory is the directory that contains it."""
    child_dir = tmp_path.joinpath("dir")
    child_dir.mkdir()
    assert str(tmp_path) == get_parent_dir(str(child_dir))


def test_file_copy():
    """Copying a file into an existing directory places it there under its own name."""
    with TempDir() as tmp:
        file_path = tmp.path("test_file.txt")
        copy_path = tmp.path("test_dir1/")
        os.mkdir(copy_path)
        with open(file_path, "a") as f:
            f.write("testing")
        _copy_file_or_tree(file_path, copy_path, "")
        assert filecmp.cmp(file_path, os.path.join(copy_path, "test_file.txt"))


def test_dir_create():
    """The path returned by ``_copy_file_or_tree`` points at a faithful copy of the file."""
    with TempDir() as tmp:
        file_path = tmp.path("test_file.txt")
        create_dir = tmp.path("test_dir2/")
        with open(file_path, "a") as f:
            f.write("testing")
        name = _copy_file_or_tree(file_path, file_path, create_dir)
        assert filecmp.cmp(file_path, name)


def test_dir_copy():
    """Copying a directory reproduces its contents under the destination."""
    with TempDir() as tmp:
        dir_path = tmp.path("test_dir1/")
        copy_path = tmp.path("test_dir2")
        os.mkdir(dir_path)
        with open(os.path.join(dir_path, "test_file.txt"), "a") as f:
            f.write("testing")
        _copy_file_or_tree(dir_path, copy_path, "")
        # A ``dircmp`` object is always truthy, so asserting on it directly can never
        # fail; inspect its result attributes instead.
        # NOTE(review): the copy is expected under its basename inside ``copy_path``,
        # the same layout ``test_file_copy`` relies on for files -- confirm against
        # ``_copy_file_or_tree``.
        comparison = filecmp.dircmp(dir_path, os.path.join(copy_path, "test_dir1"))
        assert comparison.left_only == []
        assert comparison.right_only == []
        assert comparison.diff_files == []
        assert comparison.funny_files == []
        assert comparison.common_files == ["test_file.txt"]


@pytest.mark.skipif(not is_windows(), reason="requires Windows")
def test_handle_readonly_on_windows(tmp_path):
    """``_handle_readonly_on_windows`` removes a read-only file that ``os.unlink`` could not."""
    tmp_path = tmp_path.joinpath("file")
    with open(tmp_path, "w"):
        pass

    # Make the file read-only
    os.chmod(tmp_path, stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH)
    # Ensure the file can't be removed
    with pytest.raises(PermissionError, match="Access is denied") as exc:
        os.unlink(tmp_path)

    _handle_readonly_on_windows(
        os.unlink,
        tmp_path,
        (exc.type, exc.value, exc.traceback),
    )
    assert not os.path.exists(tmp_path)


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_path"),
    [
        (r"\\my_server\my_path\my_sub_path", r"\\my_server\my_path\my_sub_path"),
    ],
)
def test_local_file_uri_to_path_on_windows(input_uri, expected_path):
    """A UNC path is returned unchanged by ``local_file_uri_to_path``."""
    assert local_file_uri_to_path(input_uri) == expected_path


def test_shutil_copytree_without_file_permissions(tmp_path):
    """The permission-less copytree copies empty and nested directories with their contents."""
    src_dir = tmp_path.joinpath("src-dir")
    src_dir.mkdir()
    dst_dir = tmp_path.joinpath("dst-dir")
    dst_dir.mkdir()
    # Test copying empty directory
    mlflow.utils.file_utils.shutil_copytree_without_file_permissions(src_dir, dst_dir)
    assert len(os.listdir(dst_dir)) == 0
    # Test copying directory with contents
    src_dir.joinpath("subdir").mkdir()
    src_dir.joinpath("subdir/subdir-file.txt").write_text("testing 123")
    src_dir.joinpath("top-level-file.txt").write_text("hi")
    mlflow.utils.file_utils.shutil_copytree_without_file_permissions(src_dir, dst_dir)
    assert set(os.listdir(dst_dir)) == {"top-level-file.txt", "subdir"}
    assert set(os.listdir(dst_dir.joinpath("subdir"))) == {"subdir-file.txt"}
    assert dst_dir.joinpath("subdir/subdir-file.txt").read_text() == "testing 123"
    assert dst_dir.joinpath("top-level-file.txt").read_text() == "hi"


def test_get_total_size_basic(tmp_path):
    """``get_total_file_size`` sums nested file sizes; missing paths and plain files give None."""
    subdir = tmp_path.joinpath("subdir")
    subdir.mkdir()

    def generate_file(path, size_in_bytes):
        with path.open("wb") as fp:
            fp.write(b"\0" * size_in_bytes)

    file_size_map = {"file1.txt": 11, "file2.txt": 23}
    for name, size in file_size_map.items():
        generate_file(tmp_path.joinpath(name), size)
    generate_file(subdir.joinpath("file3.txt"), 22)
    # 11 + 23 at the top level plus 22 in the sub-directory
    assert get_total_file_size(tmp_path) == 56
    assert get_total_file_size(subdir) == 22

    path_not_exists = tmp_path.joinpath("does_not_exist")
    assert get_total_file_size(path_not_exists) is None

    path_file = tmp_path.joinpath("file1.txt")
    assert get_total_file_size(path_file) is None


def test_check_tarfile_security(tmp_path):
    """``check_tarfile_security`` rejects unsafe member names and link targets.

    Covers escaping and absolute member names, members routed through a symlink,
    backslash traversal, hard links with unsafe targets and Windows drive-letter
    paths, and confirms that bare symlinks with absolute or escaping targets pass.
    """

    def create_tar_with_escaped_path(tar_path: str, escaped_path: str, content: bytes) -> None:
        """Create tar with path traversal entry."""
        with tarfile.open(tar_path, "w:gz") as tar:
            # Add traversal file
            data = io.BytesIO(content)
            info = tarfile.TarInfo(name=escaped_path)
            info.size = len(content)
            tar.addfile(info, data)

    tar1_path = str(tmp_path.joinpath("file1.tar"))
    create_tar_with_escaped_path(tar1_path, "../pwned2.txt", b"ABX")
    with pytest.raises(
        MlflowException, match="Escaped path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar1_path)

    def create_tar_with_symlink(
        tar_path: str, link_name: str, link_target: str, file_via_link: str, content: bytes
    ) -> None:
        """Create tar with symlink that points outside, then file through symlink."""
        with tarfile.open(tar_path, "w:gz") as tar:
            # First: create a symlink pointing to parent directory
            link_info = tarfile.TarInfo(name=link_name)
            link_info.type = tarfile.SYMTYPE
            link_info.linkname = link_target
            tar.addfile(link_info)
            # Second: create a file that goes through the symlink
            data = io.BytesIO(content)
            file_info = tarfile.TarInfo(name=file_via_link)
            file_info.size = len(content)
            tar.addfile(file_info, data)

    tar2_path = str(tmp_path.joinpath("file2.tar"))
    create_tar_with_symlink(
        tar2_path,
        link_name="escape",
        link_target="..",
        file_via_link="escape/pwned.txt",
        content=b"XYZ",
    )
    with pytest.raises(
        MlflowException,
        match="Destination path in the archive file can not go through a symlink",
    ):
        check_tarfile_security(tar2_path)

    def create_tar_with_abs_path(tar_path: str, abs_path: str, content: bytes) -> None:
        """Create tar with an absolute-path entry."""
        with tarfile.open(tar_path, "w:gz") as tar:
            # Add absolute-path file
            data = io.BytesIO(content)
            info = tarfile.TarInfo(name=abs_path)
            info.size = len(content)
            tar.addfile(info, data)

    tar3_path = str(tmp_path.joinpath("file3.tar"))
    create_tar_with_abs_path(tar3_path, "/tmp/pwned2.txt", b"ABX")
    with pytest.raises(
        MlflowException, match="Absolute path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar3_path)

    # Symlink with safe target but file going through it
    tar2b_path = str(tmp_path.joinpath("file2b.tar"))
    create_tar_with_symlink(
        tar2b_path,
        link_name="link_dir",
        link_target="subdir",
        file_via_link="link_dir/pwned.txt",
        content=b"XYZ",
    )
    with pytest.raises(
        MlflowException, match="Destination path in the archive file can not go through a symlink"
    ):
        check_tarfile_security(tar2b_path)

    # Backslash-based path traversal in tar (Windows tar slip / path traversal)
    tar4_path = str(tmp_path.joinpath("file4.tar"))
    create_tar_with_escaped_path(tar4_path, "..\\..\\pwned.txt", b"ABX")
    with pytest.raises(
        MlflowException, match="Escaped path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar4_path)

    def create_tar_with_symlink_only(tar_path: Path, link_name: str, link_target: str) -> None:
        """Create tar whose only member is a symlink."""
        with tarfile.open(tar_path, "w:gz") as tar:
            link_info = tarfile.TarInfo(name=link_name)
            link_info.type = tarfile.SYMTYPE
            link_info.linkname = link_target
            tar.addfile(link_info)

    # Symlinks with absolute/escaping targets are allowed (virtualenvs use them).
    # Security is enforced by _safe_extractall at extraction time.
    tar5_path = tmp_path / "file5.tar"
    create_tar_with_symlink_only(tar5_path, "python", "/usr/bin/python3")
    check_tarfile_security(tar5_path)  # should not raise

    tar6_path = tmp_path / "file6.tar"
    create_tar_with_symlink_only(tar6_path, "lib", "../../shared/lib")
    check_tarfile_security(tar6_path)  # should not raise

    # Symlink with absolute path as its own name
    tar7_path = tmp_path / "file7.tar"
    create_tar_with_symlink_only(tar7_path, "/tmp/escape", "foo")
    with pytest.raises(
        MlflowException, match="Absolute path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar7_path)

    # Symlink whose name escapes with ..
    tar8_path = tmp_path / "file8.tar"
    create_tar_with_symlink_only(tar8_path, "../escape", "foo")
    with pytest.raises(
        MlflowException, match="Escaped path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar8_path)

    # Hard link with escaping target
    def create_tar_with_hardlink(tar_path: Path, name: str, linkname: str) -> None:
        """Create tar whose only member is a hard link."""
        with tarfile.open(tar_path, "w:gz") as tar:
            info = tarfile.TarInfo(name=name)
            info.type = tarfile.LNKTYPE
            info.linkname = linkname
            tar.addfile(info)

    tar9_path = tmp_path / "file9.tar"
    create_tar_with_hardlink(tar9_path, "legit.txt", "../../etc/passwd")
    with pytest.raises(
        MlflowException, match="Escaped path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar9_path)

    # Hard link with absolute target
    tar10_path = tmp_path / "file10.tar"
    create_tar_with_hardlink(tar10_path, "legit.txt", "/etc/passwd")
    with pytest.raises(
        MlflowException, match="Absolute path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar10_path)

    # Windows drive-letter absolute path
    tar11_path = tmp_path / "file11.tar"
    create_tar_with_escaped_path(tar11_path, "C:/Windows/System32/evil.dll", b"ABX")
    with pytest.raises(
        MlflowException, match="Absolute path destination in the archive file is not allowed"
    ):
        check_tarfile_security(tar11_path)


def test_extract_archive_to_dir_blocks_traversal(tmp_path):
    """``extract_archive_to_dir`` raises on a ``..`` member and writes nothing outside."""
    # Test that check_tarfile_security blocks path traversal
    mal_tar = tmp_path / "malicious.tar.gz"
    with tarfile.open(mal_tar, "w:gz") as tar:
        info = tarfile.TarInfo("../../escape.txt")
        data = b"owned via tar traversal"
        info.size = len(data)
        tar.addfile(info, fileobj=io.BytesIO(data))

    dest = tmp_path / "extracted"
    escape_target = tmp_path.parent.parent / "escape.txt"
    with pytest.raises(MlflowException, match="Escaped path destination in the archive file"):
        extract_archive_to_dir(mal_tar, dest)
    assert not escape_target.exists()


def test_safe_extractall_blocks_symlink_escape(tmp_path):
    """Test that _safe_extractall blocks extraction when a filesystem symlink
    inside dest_dir would cause a member to resolve outside dest_dir.
    """
    from mlflow.pyfunc.dbconnect_artifact_cache import _safe_extractall

    dest = tmp_path / "extracted"
    dest.mkdir()
    # Create a symlink inside dest_dir pointing outside
    escape_link = dest / "escape_link"
    escape_link.symlink_to(tmp_path.parent)

    # Create a tar with a file that goes through the symlink
    mal_tar = tmp_path / "symlink_escape.tar.gz"
    with tarfile.open(mal_tar, "w:gz") as tar:
        info = tarfile.TarInfo("escape_link/pwned.txt")
        data = b"escaped via filesystem symlink"
        info.size = len(data)
        tar.addfile(info, fileobj=io.BytesIO(data))

    with tarfile.open(mal_tar, "r") as tar:
        with pytest.raises(MlflowException, match="would be extracted outside"):
            _safe_extractall(tar, dest)
    assert not (tmp_path.parent / "pwned.txt").exists()