/ tests / store / artifact / test_cli.py
test_cli.py
  1  import json
  2  import pathlib
  3  from unittest import mock
  4  
  5  import pytest
  6  from click.testing import CliRunner
  7  
  8  import mlflow
  9  from mlflow.entities import FileInfo
 10  from mlflow.store.artifact.cli import _file_infos_to_json, download_artifacts
 11  from mlflow.tracking.artifact_utils import _download_artifact_from_uri
 12  
 13  
 14  @pytest.fixture
 15  def run_with_artifact(tmp_path):
 16      artifact_path = "test"
 17      artifact_content = "content"
 18      local_path = tmp_path.joinpath("file.txt")
 19      local_path.write_text(artifact_content)
 20      with mlflow.start_run() as run:
 21          mlflow.log_artifact(local_path, artifact_path)
 22  
 23      return (run, artifact_path, artifact_content)
 24  
 25  
 26  def test_file_info_to_json():
 27      file_infos = [
 28          FileInfo("/my/file", False, 123),
 29          FileInfo("/my/dir", True, None),
 30      ]
 31      info_str = _file_infos_to_json(file_infos)
 32      assert json.loads(info_str) == [
 33          {"path": "/my/file", "is_dir": False, "file_size": 123},
 34          {"path": "/my/dir", "is_dir": True},
 35      ]
 36  
 37  
 38  def test_download_from_uri():
 39      class TestArtifactRepo:
 40          def __init__(self, scheme):
 41              self.scheme = scheme
 42  
 43          def download_artifacts(self, artifact_path, **kwargs):
 44              return (self.scheme, artifact_path)
 45  
 46      def test_get_artifact_repository(artifact_uri, tracking_uri=None, registry_uri=None):
 47          return TestArtifactRepo(artifact_uri)
 48  
 49      pairs = [
 50          ("path", ("", "path")),
 51          ("path/", ("path", "")),
 52          ("/path", ("/", "path")),
 53          ("/path/", ("/path", "")),
 54          ("path/to/dir", ("path/to", "dir")),
 55          ("file:", ("file:", "")),
 56          ("file:path", ("file:", "path")),
 57          ("file:path/", ("file:path", "")),
 58          ("file:path/to/dir", ("file:path/to", "dir")),
 59          ("file:/", ("file:///", "")),
 60          ("file:/path", ("file:///", "path")),
 61          ("file:/path/", ("file:///path", "")),
 62          ("file:/path/to/dir", ("file:///path/to", "dir")),
 63          ("file:///", ("file:///", "")),
 64          ("file:///path", ("file:///", "path")),
 65          ("file:///path/", ("file:///path", "")),
 66          ("file:///path/to/dir", ("file:///path/to", "dir")),
 67          ("s3://", ("s3:", "")),
 68          ("s3://path", ("s3://path", "")),  # path is netloc in this case
 69          ("s3://path/", ("s3://path/", "")),
 70          ("s3://path/to/", ("s3://path/to", "")),
 71          ("s3://path/to", ("s3://path/", "to")),
 72          ("s3://path/to/dir", ("s3://path/to", "dir")),
 73      ]
 74      with mock.patch(
 75          "mlflow.tracking.artifact_utils.get_artifact_repository"
 76      ) as get_artifact_repo_mock:
 77          get_artifact_repo_mock.side_effect = test_get_artifact_repository
 78  
 79          for uri, expected_result in pairs:
 80              actual_result = _download_artifact_from_uri(uri)
 81              assert expected_result == actual_result
 82  
 83  
 84  def _run_download_artifact_command(args) -> pathlib.Path:
 85      """
 86      Args:
 87          command: An `mlflow artifacts` command list.
 88  
 89      Returns:
 90          Path to the downloaded artifact.
 91      """
 92      runner = CliRunner()
 93      resp = runner.invoke(download_artifacts, args=args, catch_exceptions=False)
 94      assert resp.exit_code == 0
 95      download_output_path = resp.stdout.rstrip().split("\n")[-1]
 96      return next(pathlib.Path(download_output_path).iterdir())
 97  
 98  
 99  def test_download_artifacts_with_uri(run_with_artifact):
100      run, artifact_path, artifact_content = run_with_artifact
101      run_uri = f"runs:/{run.info.run_id}/{artifact_path}"
102      actual_uri = str(pathlib.PurePosixPath(run.info.artifact_uri) / artifact_path)
103      for uri in (run_uri, actual_uri):
104          downloaded_content = _run_download_artifact_command(["-u", uri]).read_text()
105          assert downloaded_content == artifact_content
106  
107      # Check for backwards compatibility with preexisting behavior in MLflow <= 1.24.0 where
108      # specifying `artifact_uri` and `artifact_path` together did not throw an exception (unlike
109      # `mlflow.artifacts.download_artifacts()`) and instead used `artifact_uri` while ignoring
110      # `run_id` and `artifact_path`
111      downloaded_content = _run_download_artifact_command([
112          "-u",
113          uri,
114          "--run-id",
115          "bad",
116          "--artifact-path",
117          "bad",
118      ]).read_text()
119      assert downloaded_content == artifact_content
120  
121  
def test_download_artifacts_with_run_id_and_path(run_with_artifact):
    """
    Download the logged artifact by passing `--run-id` and `--artifact-path` (no URI)
    and check that the downloaded file holds the logged content.
    """
    run, artifact_path, expected_content = run_with_artifact
    cli_args = ["--run-id", run.info.run_id, "--artifact-path", artifact_path]
    downloaded_file = _run_download_artifact_command(cli_args)
    assert downloaded_file.read_text() == expected_content
131  
132  
@pytest.mark.parametrize("dst_subdir_path", [None, "doesnt_exist_yet"])
def test_download_artifacts_with_dst_path(run_with_artifact, tmp_path, dst_subdir_path):
    """
    Download with `-d` pointing at `tmp_path` itself or at a subdirectory of it, and
    check that the reported download path starts with the requested destination.
    """
    run, artifact_path, _ = run_with_artifact
    if dst_subdir_path:
        dst_path = tmp_path / dst_subdir_path
    else:
        dst_path = tmp_path
    destination = str(dst_path)
    cli_args = ["-u", f"runs:/{run.info.run_id}/{artifact_path}", "-d", destination]
    downloaded_file_path = _run_download_artifact_command(cli_args)
    assert str(downloaded_file_path).startswith(destination)