test_artifact_utils.py
import os
import re
from unittest import mock
from unittest.mock import ANY
from uuid import UUID

import pytest

import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking.artifact_utils import (
    _download_artifact_from_uri,
    _upload_artifact_to_uri,
    _upload_artifacts_to_databricks,
)


def test_artifact_can_be_downloaded_from_absolute_uri_successfully(tmp_path):
    """Log a text file in a run, download it by its artifact URI, and check the contents."""
    artifact_file_name = "artifact.txt"
    artifact_text = "Sample artifact text"
    local_artifact_path = tmp_path.joinpath(artifact_file_name)
    local_artifact_path.write_text(artifact_text)

    logged_artifact_path = "artifact"
    with mlflow.start_run():
        mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_path)
        artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_path)

    downloaded_artifact_path = os.path.join(
        _download_artifact_from_uri(artifact_uri), artifact_file_name
    )
    # `downloaded_artifact_path` is a str while `local_artifact_path` is a pathlib.Path;
    # comparing them directly is always unequal, so compare string forms to make the
    # "not the original file" check meaningful.
    assert downloaded_artifact_path != os.fspath(local_artifact_path)
    assert downloaded_artifact_path != logged_artifact_path
    with open(downloaded_artifact_path) as f:
        assert f.read() == artifact_text


def test_download_artifact_from_absolute_uri_persists_data_to_specified_output_directory(tmp_path):
    """Download a logged artifact directory into an explicit ``output_path`` and verify it.

    The logged sub-directory must appear inside the output directory and contain the
    original file with unchanged contents.
    """
    artifact_file_name = "artifact.txt"
    artifact_text = "Sample artifact text"
    local_artifact_path = tmp_path.joinpath(artifact_file_name)
    local_artifact_path.write_text(artifact_text)

    logged_artifact_subdir = "logged_artifact"
    with mlflow.start_run():
        mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)

    artifact_output_path = tmp_path.joinpath("artifact_output")
    artifact_output_path.mkdir()
    _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
    assert logged_artifact_subdir in os.listdir(artifact_output_path)
    assert artifact_file_name in os.listdir(
        os.path.join(artifact_output_path, logged_artifact_subdir)
    )
    with open(os.path.join(artifact_output_path, logged_artifact_subdir, artifact_file_name)) as f:
        assert f.read() == artifact_text


def test_download_artifact_with_special_characters_in_file_name_and_path(tmp_path):
    """Same as the output-directory test, but with spaces and ``!`` in every local path."""
    artifact_file_name = " artifact_ with! special characters.txt"
    artifact_sub_dir = " path with ! special characters"
    artifact_text = "Sample artifact text"
    local_sub_path = tmp_path.joinpath(artifact_sub_dir)
    local_sub_path.mkdir()

    local_artifact_path = os.path.join(local_sub_path, artifact_file_name)
    with open(local_artifact_path, "w") as out:
        out.write(artifact_text)

    logged_artifact_subdir = "logged_artifact"
    with mlflow.start_run():
        mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)

    artifact_output_path = tmp_path.joinpath("artifact output path!")
    artifact_output_path.mkdir()
    _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
    assert logged_artifact_subdir in os.listdir(artifact_output_path)
    assert artifact_file_name in os.listdir(
        os.path.join(artifact_output_path, logged_artifact_subdir)
    )
    with open(os.path.join(artifact_output_path, logged_artifact_subdir, artifact_file_name)) as f:
        assert f.read() == artifact_text


def test_download_artifact_invalid_uri_model_id():
    """A bare ``m-...`` string must raise an ``MlflowException`` suggesting ``models:/``."""
    expected_message = "Invalid uri `m-dummy` is passed. Maybe you meant 'models:/m-dummy'?"
    # `match` is a regular expression searched in the exception text; escape the literal
    # so that `.` and the trailing `?` are matched verbatim rather than as metacharacters.
    with pytest.raises(MlflowException, match=re.escape(expected_message)):
        _download_artifact_from_uri("m-dummy")


def test_upload_artifacts_to_databricks():
    """Check the URIs used for download/upload and the returned source when a run ID is given.

    Both the download helper and ``DbfsRestArtifactRepository`` are patched, so only
    the arguments they are called with and the returned path are verified.
    """
    import_root = "mlflow.tracking.artifact_utils"
    with (
        mock.patch(import_root + "._download_artifact_from_uri") as download_mock,
        mock.patch(import_root + ".DbfsRestArtifactRepository") as repo_mock,
    ):
        new_source = _upload_artifacts_to_databricks(
            "dbfs:/original/sourcedir/",
            "runid12345",
            "databricks://tracking",
            "databricks://registry:ws",
        )
        download_mock.assert_called_once_with("dbfs://tracking@databricks/original/sourcedir/", ANY)
        repo_mock.assert_called_once_with(
            "dbfs://registry:ws@databricks/databricks/mlflow/tmp-external-source/"
        )
        assert new_source == "dbfs:/databricks/mlflow/tmp-external-source/runid12345/sourcedir"


def test_upload_artifacts_to_databricks_no_run_id():
    """Check the same flow when the run ID is ``None``.

    ``uuid.uuid4`` is patched to a fixed value, and the test expects that value's hex
    string to appear in the returned source path in place of a run ID.
    """
    import_root = "mlflow.tracking.artifact_utils"
    with (
        mock.patch(import_root + "._download_artifact_from_uri") as download_mock,
        mock.patch(import_root + ".DbfsRestArtifactRepository") as repo_mock,
        mock.patch("uuid.uuid4", return_value=UUID("4f746cdcc0374da2808917e81bb53323")),
    ):
        new_source = _upload_artifacts_to_databricks(
            "dbfs:/original/sourcedir/", None, "databricks://tracking:ws", "databricks://registry"
        )
        download_mock.assert_called_once_with(
            "dbfs://tracking:ws@databricks/original/sourcedir/", ANY
        )
        repo_mock.assert_called_once_with(
            "dbfs://registry@databricks/databricks/mlflow/tmp-external-source/"
        )
        assert (
            new_source == "dbfs:/databricks/mlflow/tmp-external-source/"
            "4f746cdcc0374da2808917e81bb53323/sourcedir"
        )


def test_upload_artifacts_to_uri(tmp_path):
    """Upload a local file to a ``runs:/<run_id>/`` URI and read it back via download."""
    artifact_file_name = "artifact.txt"
    artifact_text = "Sample artifact text"
    local_artifact_path = tmp_path.joinpath(artifact_file_name)
    local_artifact_path.write_text(artifact_text)

    with mlflow.start_run() as run:
        mlflow.log_metric("coolness", 1)

    artifact_uri = f"runs:/{run.info.run_id}/"
    _upload_artifact_to_uri(local_artifact_path, artifact_uri)
    downloaded_artifact_path = os.path.join(
        _download_artifact_from_uri(artifact_uri), artifact_file_name
    )
    with open(downloaded_artifact_path) as f:
        assert f.read() == artifact_text