# test_cli.py
import json
import pathlib
from unittest import mock

import pytest
from click.testing import CliRunner

import mlflow
from mlflow.entities import FileInfo
from mlflow.store.artifact.cli import _file_infos_to_json, download_artifacts
from mlflow.tracking.artifact_utils import _download_artifact_from_uri


@pytest.fixture
def run_with_artifact(tmp_path):
    """Log a single text file under the ``test`` artifact path of a new run.

    Returns:
        A ``(run, artifact_path, artifact_content)`` tuple describing the logged artifact.
    """
    artifact_path = "test"
    artifact_content = "content"
    local_path = tmp_path.joinpath("file.txt")
    local_path.write_text(artifact_content)
    with mlflow.start_run() as run:
        mlflow.log_artifact(local_path, artifact_path)

    return (run, artifact_path, artifact_content)


def test_file_info_to_json():
    file_infos = [
        FileInfo("/my/file", False, 123),
        FileInfo("/my/dir", True, None),
    ]
    info_str = _file_infos_to_json(file_infos)
    # An entry whose file size is None is expected to omit the "file_size" key entirely.
    assert json.loads(info_str) == [
        {"path": "/my/file", "is_dir": False, "file_size": 123},
        {"path": "/my/dir", "is_dir": True},
    ]


def test_download_from_uri():
    """Check how ``_download_artifact_from_uri`` splits a URI.

    The fake repository echoes back the URI it was constructed with and the path it
    was asked to download, so each result is ``(repository_uri, artifact_path)``.
    """

    class FakeArtifactRepo:
        def __init__(self, scheme):
            self.scheme = scheme

        def download_artifacts(self, artifact_path, **kwargs):
            return (self.scheme, artifact_path)

    def _fake_get_artifact_repository(artifact_uri, tracking_uri=None, registry_uri=None):
        return FakeArtifactRepo(artifact_uri)

    pairs = [
        ("path", ("", "path")),
        ("path/", ("path", "")),
        ("/path", ("/", "path")),
        ("/path/", ("/path", "")),
        ("path/to/dir", ("path/to", "dir")),
        ("file:", ("file:", "")),
        ("file:path", ("file:", "path")),
        ("file:path/", ("file:path", "")),
        ("file:path/to/dir", ("file:path/to", "dir")),
        ("file:/", ("file:///", "")),
        ("file:/path", ("file:///", "path")),
        ("file:/path/", ("file:///path", "")),
        ("file:/path/to/dir", ("file:///path/to", "dir")),
        ("file:///", ("file:///", "")),
        ("file:///path", ("file:///", "path")),
        ("file:///path/", ("file:///path", "")),
        ("file:///path/to/dir", ("file:///path/to", "dir")),
        ("s3://", ("s3:", "")),
        ("s3://path", ("s3://path", "")),  # path is netloc in this case
        ("s3://path/", ("s3://path/", "")),
        ("s3://path/to/", ("s3://path/to", "")),
        ("s3://path/to", ("s3://path/", "to")),
        ("s3://path/to/dir", ("s3://path/to", "dir")),
    ]
    with mock.patch(
        "mlflow.tracking.artifact_utils.get_artifact_repository",
        side_effect=_fake_get_artifact_repository,
    ):
        for uri, expected_result in pairs:
            actual_result = _download_artifact_from_uri(uri)
            # Report the offending URI so a failure in this table is easy to locate.
            assert actual_result == expected_result, uri


def _run_download_artifact_command(args) -> pathlib.Path:
    """
    Invoke the ``download_artifacts`` CLI command and locate what it downloaded.

    Args:
        args: Command-line arguments passed to the ``download_artifacts`` command.

    Returns:
        Path of the first entry yielded when iterating the directory whose path the
        command printed on its last output line (the callers download a single file).
    """
    runner = CliRunner()
    resp = runner.invoke(download_artifacts, args=args, catch_exceptions=False)
    # Surface the CLI output in the assertion message so a non-zero exit is diagnosable.
    assert resp.exit_code == 0, resp.output
    # The download location is read from the final non-blank line of stdout.
    download_output_path = resp.stdout.rstrip().splitlines()[-1]
    return next(pathlib.Path(download_output_path).iterdir())


def test_download_artifacts_with_uri(run_with_artifact):
    """Download via ``-u`` using both a ``runs:/`` URI and the resolved artifact URI."""
    run, artifact_path, artifact_content = run_with_artifact
    run_uri = f"runs:/{run.info.run_id}/{artifact_path}"
    actual_uri = str(pathlib.PurePosixPath(run.info.artifact_uri) / artifact_path)
    for uri in (run_uri, actual_uri):
        downloaded_content = _run_download_artifact_command(["-u", uri]).read_text()
        assert downloaded_content == artifact_content

        # Check for backwards compatibility with preexisting behavior in MLflow <= 1.24.0 where
        # specifying `artifact_uri` and `artifact_path` together did not throw an exception (unlike
        # `mlflow.artifacts.download_artifacts()`) and instead used `artifact_uri` while ignoring
        # `run_id` and `artifact_path`
        downloaded_content = _run_download_artifact_command([
            "-u",
            uri,
            "--run-id",
            "bad",
            "--artifact-path",
            "bad",
        ]).read_text()
        assert downloaded_content == artifact_content


def test_download_artifacts_with_run_id_and_path(run_with_artifact):
    """Download by ``--run-id`` plus ``--artifact-path`` instead of a URI."""
    run, artifact_path, artifact_content = run_with_artifact
    downloaded_content = _run_download_artifact_command([
        "--run-id",
        run.info.run_id,
        "--artifact-path",
        artifact_path,
    ]).read_text()
    assert downloaded_content == artifact_content


@pytest.mark.parametrize("dst_subdir_path", [None, "doesnt_exist_yet"])
def test_download_artifacts_with_dst_path(run_with_artifact, tmp_path, dst_subdir_path):
    """The ``-d`` destination is honored, including a directory that does not exist yet."""
    run, artifact_path, _ = run_with_artifact
    artifact_uri = f"runs:/{run.info.run_id}/{artifact_path}"
    dst_path = tmp_path / dst_subdir_path if dst_subdir_path else tmp_path
    downloaded_file_path = _run_download_artifact_command(["-u", artifact_uri, "-d", str(dst_path)])
    assert str(downloaded_file_path).startswith(str(dst_path))