/ tests / utils / test_file_utils.py
test_file_utils.py
  1  import filecmp
  2  import hashlib
  3  import io
  4  import os
  5  import shutil
  6  import stat
  7  import tarfile
  8  from pathlib import Path
  9  
 10  import pytest
 11  from pyspark.sql import SparkSession
 12  
 13  import mlflow
 14  from mlflow.exceptions import MlflowException
 15  from mlflow.pyfunc.dbconnect_artifact_cache import extract_archive_to_dir
 16  from mlflow.utils import file_utils
 17  from mlflow.utils.file_utils import (
 18      TempDir,
 19      _copy_file_or_tree,
 20      _handle_readonly_on_windows,
 21      check_tarfile_security,
 22      get_parent_dir,
 23      get_total_file_size,
 24      local_file_uri_to_path,
 25  )
 26  from mlflow.utils.os import is_windows
 27  
 28  from tests.helper_functions import random_int
 29  from tests.projects.utils import TEST_PROJECT_DIR
 30  
 31  
 32  @pytest.fixture(scope="module")
 33  def spark_session():
 34      with SparkSession.builder.master("local[*]").getOrCreate() as session:
 35          yield session
 36  
 37  
 38  def test_mkdir(tmp_path):
 39      temp_dir = str(tmp_path)
 40      new_dir_name = f"mkdir_test_{random_int()}"
 41      file_utils.mkdir(temp_dir, new_dir_name)
 42      assert os.listdir(temp_dir) == [new_dir_name]
 43  
 44      with pytest.raises(OSError, match="bad directory"):
 45          file_utils.mkdir("/   bad directory @ name ", "ouch")
 46  
 47      # does not raise if directory exists already
 48      file_utils.mkdir(temp_dir, new_dir_name)
 49  
 50      # raises if it exists already but is a file
 51      dummy_file_path = str(tmp_path.joinpath("dummy_file"))
 52      with open(dummy_file_path, "a"):
 53          pass
 54  
 55      with pytest.raises(OSError, match="exists"):
 56          file_utils.mkdir(dummy_file_path)
 57  
 58  
 59  def test_make_tarfile(tmp_path):
 60      # Tar a local project
 61      tarfile0 = str(tmp_path.joinpath("first-tarfile"))
 62      file_utils.make_tarfile(
 63          output_filename=tarfile0, source_dir=TEST_PROJECT_DIR, archive_name="some-archive"
 64      )
 65      # Copy local project into a temp dir
 66      dst_dir = str(tmp_path.joinpath("project-directory"))
 67      shutil.copytree(TEST_PROJECT_DIR, dst_dir)
 68      # Tar the copied project
 69      tarfile1 = str(tmp_path.joinpath("second-tarfile"))
 70      file_utils.make_tarfile(
 71          output_filename=tarfile1, source_dir=dst_dir, archive_name="some-archive"
 72      )
 73      # Compare the archives & explicitly verify their SHA256 hashes match (i.e. that
 74      # changes in file modification timestamps don't affect the archive contents)
 75      assert filecmp.cmp(tarfile0, tarfile1, shallow=False)
 76      with open(tarfile0, "rb") as first_tar, open(tarfile1, "rb") as second_tar:
 77          assert (
 78              hashlib.sha256(first_tar.read()).hexdigest()
 79              == hashlib.sha256(second_tar.read()).hexdigest()
 80          )
 81      # Extract the TAR and check that its contents match the original directory
 82      extract_dir = str(tmp_path.joinpath("extracted-tar"))
 83      os.makedirs(extract_dir)
 84      with tarfile.open(tarfile0, "r:gz") as handle:
 85          handle.extractall(path=extract_dir)
 86      dir_comparison = filecmp.dircmp(os.path.join(extract_dir, "some-archive"), TEST_PROJECT_DIR)
 87      assert len(dir_comparison.left_only) == 0
 88      assert len(dir_comparison.right_only) == 0
 89      assert len(dir_comparison.diff_files) == 0
 90      assert len(dir_comparison.funny_files) == 0
 91  
 92  
 93  def test_get_parent_dir(tmp_path):
 94      child_dir = tmp_path.joinpath("dir")
 95      child_dir.mkdir()
 96      assert str(tmp_path) == get_parent_dir(str(child_dir))
 97  
 98  
 99  def test_file_copy():
100      with TempDir() as tmp:
101          file_path = tmp.path("test_file.txt")
102          copy_path = tmp.path("test_dir1/")
103          os.mkdir(copy_path)
104          with open(file_path, "a") as f:
105              f.write("testing")
106          _copy_file_or_tree(file_path, copy_path, "")
107          assert filecmp.cmp(file_path, os.path.join(copy_path, "test_file.txt"))
108  
109  
def test_dir_create():
    """Copying a file with a not-yet-existing ``dst_dir`` yields an identical copy.

    The copy is checked at the path returned by ``_copy_file_or_tree``.
    """
    with TempDir() as tmp:
        src_file = tmp.path("test_file.txt")
        missing_dir = tmp.path("test_dir2/")
        with open(src_file, "a") as fh:
            fh.write("testing")

        copied = _copy_file_or_tree(src_file, src_file, missing_dir)
        assert filecmp.cmp(src_file, copied)
118  
119  
def test_dir_copy():
    """Copying a directory reproduces it, with its contents, under the destination.

    ``_copy_file_or_tree`` places the copy beneath ``dst`` and returns its subpath,
    so the copy is located via that return value rather than assumed to be ``dst``.
    """
    with TempDir() as tmp:
        dir_path = tmp.path("test_dir1/")
        copy_path = tmp.path("test_dir2")
        os.mkdir(dir_path)
        with open(os.path.join(dir_path, "test_file.txt"), "a") as f:
            f.write("testing")

        copied_subpath = _copy_file_or_tree(dir_path, copy_path, "")
        copied_dir = os.path.join(copy_path, copied_subpath)

        # A ``dircmp`` object is always truthy, so asserting on it directly checks
        # nothing; inspect its difference lists instead.
        comparison = filecmp.dircmp(dir_path, copied_dir)
        assert comparison.common_files == ["test_file.txt"]
        assert comparison.left_only == []
        assert comparison.right_only == []
        assert comparison.diff_files == []
        assert comparison.funny_files == []
129  
130  
@pytest.mark.skipif(not is_windows(), reason="requires Windows")
def test_handle_readonly_on_windows(tmp_path):
    """``_handle_readonly_on_windows`` removes a read-only file that ``os.unlink`` refused."""
    file_path = tmp_path.joinpath("file")
    with open(file_path, "w"):
        pass

    # Make the file read-only
    os.chmod(file_path, stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH)
    # Ensure the file can't be removed
    with pytest.raises(PermissionError, match="Access is denied") as exc:
        os.unlink(file_path)

    # Pass a sys.exc_info()-style (type, value, traceback) triple. ``ExceptionInfo.tb``
    # is the raw traceback object; ``ExceptionInfo.traceback`` is pytest's wrapper.
    _handle_readonly_on_windows(
        os.unlink,
        file_path,
        (exc.type, exc.value, exc.tb),
    )
    assert not os.path.exists(file_path)
149  
150  
@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
@pytest.mark.parametrize(
    ("input_uri", "expected_path"),
    [(r"\\my_server\my_path\my_sub_path", r"\\my_server\my_path\my_sub_path")],
)
def test_local_file_uri_to_path_on_windows(input_uri, expected_path):
    """UNC-style network paths come back from ``local_file_uri_to_path`` as given."""
    converted = local_file_uri_to_path(input_uri)
    assert converted == expected_path
160  
161  
def test_shutil_copytree_without_file_permissions(tmp_path):
    """Copy an empty tree, then a populated one, into an already-existing destination."""
    src_dir = tmp_path / "src-dir"
    dst_dir = tmp_path / "dst-dir"
    for directory in (src_dir, dst_dir):
        directory.mkdir()
    copytree = file_utils.shutil_copytree_without_file_permissions

    # An empty source leaves the destination empty.
    copytree(src_dir, dst_dir)
    assert os.listdir(dst_dir) == []

    # A nested source tree is copied with its file contents intact.
    nested = src_dir / "subdir"
    nested.mkdir()
    (nested / "subdir-file.txt").write_text("testing 123")
    (src_dir / "top-level-file.txt").write_text("hi")
    copytree(src_dir, dst_dir)

    assert set(os.listdir(dst_dir)) == {"top-level-file.txt", "subdir"}
    assert set(os.listdir(dst_dir / "subdir")) == {"subdir-file.txt"}
    assert (dst_dir / "subdir" / "subdir-file.txt").read_text() == "testing 123"
    assert (dst_dir / "top-level-file.txt").read_text() == "hi"
179  
180  
def test_get_total_size_basic(tmp_path):
    """``get_total_file_size`` sums sizes recursively and returns None for non-directories."""
    subdir = tmp_path / "subdir"
    subdir.mkdir()

    sizes = {
        tmp_path / "file1.txt": 11,
        tmp_path / "file2.txt": 23,
        subdir / "file3.txt": 22,
    }
    for path, num_bytes in sizes.items():
        path.write_bytes(b"\0" * num_bytes)

    assert get_total_file_size(tmp_path) == 56
    assert get_total_file_size(subdir) == 22

    # Neither a missing path nor a regular file counts as a directory.
    assert get_total_file_size(tmp_path / "does_not_exist") is None
    assert get_total_file_size(tmp_path / "file1.txt") is None
201  
202  
def test_check_tarfile_security(tmp_path):
    """``check_tarfile_security`` rejects unsafe member paths and accepts benign symlinks.

    Each scenario writes a small gzip-compressed archive and checks that the
    security check either raises the expected ``MlflowException`` or passes.
    """
    escaped_msg = "Escaped path destination in the archive file is not allowed"
    absolute_msg = "Absolute path destination in the archive file is not allowed"
    symlink_msg = "Destination path in the archive file can not go through a symlink"

    def regular_file(name, content):
        """Build a regular-file member named ``name`` holding ``content``."""
        info = tarfile.TarInfo(name=name)
        info.size = len(content)
        return info, io.BytesIO(content)

    def link(name, target, link_type=tarfile.SYMTYPE):
        """Build a symlink (default) or hard-link member ``name`` pointing at ``target``."""
        info = tarfile.TarInfo(name=name)
        info.type = link_type
        info.linkname = target
        return info, None

    def write_tar(tar_path, *members):
        """Write ``members`` in order to a gzip-compressed tar at ``tar_path``."""
        with tarfile.open(tar_path, "w:gz") as tar:
            for info, fileobj in members:
                tar.addfile(info, fileobj)

    def assert_rejected(tar_path, match):
        with pytest.raises(MlflowException, match=match):
            check_tarfile_security(tar_path)

    # Relative path traversal
    tar1_path = str(tmp_path.joinpath("file1.tar"))
    write_tar(tar1_path, regular_file("../pwned2.txt", b"ABX"))
    assert_rejected(tar1_path, escaped_msg)

    # Symlink pointing at the parent directory, then a file written through it
    tar2_path = str(tmp_path.joinpath("file2.tar"))
    write_tar(tar2_path, link("escape", ".."), regular_file("escape/pwned.txt", b"XYZ"))
    assert_rejected(tar2_path, symlink_msg)

    # Absolute member path
    tar3_path = str(tmp_path.joinpath("file3.tar"))
    write_tar(tar3_path, regular_file("/tmp/pwned2.txt", b"ABX"))
    assert_rejected(tar3_path, absolute_msg)

    # Symlink with safe target but file going through it
    tar2b_path = str(tmp_path.joinpath("file2b.tar"))
    write_tar(
        tar2b_path, link("link_dir", "subdir"), regular_file("link_dir/pwned.txt", b"XYZ")
    )
    assert_rejected(tar2b_path, symlink_msg)

    # Backslash-based path traversal in tar (Windows tar slip / path traversal)
    tar4_path = str(tmp_path.joinpath("file4.tar"))
    write_tar(tar4_path, regular_file("..\\..\\pwned.txt", b"ABX"))
    assert_rejected(tar4_path, escaped_msg)

    # Symlinks with absolute/escaping targets are allowed (virtualenvs use them).
    # Security is enforced by _safe_extractall at extraction time.
    tar5_path = tmp_path / "file5.tar"
    write_tar(tar5_path, link("python", "/usr/bin/python3"))
    check_tarfile_security(tar5_path)  # should not raise

    tar6_path = tmp_path / "file6.tar"
    write_tar(tar6_path, link("lib", "../../shared/lib"))
    check_tarfile_security(tar6_path)  # should not raise

    # Symlink with absolute path as its own name
    tar7_path = tmp_path / "file7.tar"
    write_tar(tar7_path, link("/tmp/escape", "foo"))
    assert_rejected(tar7_path, absolute_msg)

    # Symlink whose name escapes with ..
    tar8_path = tmp_path / "file8.tar"
    write_tar(tar8_path, link("../escape", "foo"))
    assert_rejected(tar8_path, escaped_msg)

    # Hard link with escaping target (unlike symlink targets, hard-link targets are checked)
    tar9_path = tmp_path / "file9.tar"
    write_tar(tar9_path, link("legit.txt", "../../etc/passwd", tarfile.LNKTYPE))
    assert_rejected(tar9_path, escaped_msg)

    # Hard link with absolute target
    tar10_path = tmp_path / "file10.tar"
    write_tar(tar10_path, link("legit.txt", "/etc/passwd", tarfile.LNKTYPE))
    assert_rejected(tar10_path, absolute_msg)

    # Windows drive-letter absolute path
    tar11_path = tmp_path / "file11.tar"
    write_tar(tar11_path, regular_file("C:/Windows/System32/evil.dll", b"ABX"))
    assert_rejected(tar11_path, absolute_msg)
351  
352  
def test_extract_archive_to_dir_blocks_traversal(tmp_path):
    """``extract_archive_to_dir`` refuses an archive whose member escapes the destination."""
    member_name = "../../escape.txt"
    mal_tar = tmp_path / "malicious.tar.gz"
    with tarfile.open(mal_tar, "w:gz") as tar:
        info = tarfile.TarInfo(member_name)
        data = b"owned via tar traversal"
        info.size = len(data)
        tar.addfile(info, fileobj=io.BytesIO(data))

    dest = tmp_path / "extracted"
    # Where the member would land if extraction were allowed to escape: resolving
    # "../../escape.txt" against tmp_path/extracted gives tmp_path.parent/escape.txt.
    escape_target = Path(os.path.normpath(dest / member_name))
    with pytest.raises(MlflowException, match="Escaped path destination in the archive file"):
        extract_archive_to_dir(mal_tar, dest)
    assert not escape_target.exists()
367  
368  
def test_safe_extractall_blocks_symlink_escape(tmp_path):
    """``_safe_extractall`` must not write through an on-disk symlink that leaves dest_dir.

    A symlink already present inside the destination points outside it; an archive
    member addressed through that link has to be rejected rather than extracted.
    """
    from mlflow.pyfunc.dbconnect_artifact_cache import _safe_extractall

    dest = tmp_path / "extracted"
    dest.mkdir()
    (dest / "escape_link").symlink_to(tmp_path.parent)

    payload = b"escaped via filesystem symlink"
    archive = tmp_path / "symlink_escape.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        member = tarfile.TarInfo("escape_link/pwned.txt")
        member.size = len(payload)
        tar.addfile(member, fileobj=io.BytesIO(payload))

    with tarfile.open(archive, "r") as tar, pytest.raises(
        MlflowException, match="would be extracted outside"
    ):
        _safe_extractall(tar, dest)
    assert not (tmp_path.parent / "pwned.txt").exists()