/ tests / tracking / test_artifact_utils.py
test_artifact_utils.py
  1  import os
  2  from unittest import mock
  3  from unittest.mock import ANY
  4  from uuid import UUID
  5  
  6  import pytest
  7  
  8  import mlflow
  9  from mlflow.exceptions import MlflowException
 10  from mlflow.tracking.artifact_utils import (
 11      _download_artifact_from_uri,
 12      _upload_artifact_to_uri,
 13      _upload_artifacts_to_databricks,
 14  )
 15  
 16  
 17  def test_artifact_can_be_downloaded_from_absolute_uri_successfully(tmp_path):
 18      artifact_file_name = "artifact.txt"
 19      artifact_text = "Sample artifact text"
 20      local_artifact_path = tmp_path.joinpath(artifact_file_name)
 21      local_artifact_path.write_text(artifact_text)
 22  
 23      logged_artifact_path = "artifact"
 24      with mlflow.start_run():
 25          mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_path)
 26          artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_path)
 27  
 28      downloaded_artifact_path = os.path.join(
 29          _download_artifact_from_uri(artifact_uri), artifact_file_name
 30      )
 31      assert downloaded_artifact_path != local_artifact_path
 32      assert downloaded_artifact_path != logged_artifact_path
 33      with open(downloaded_artifact_path) as f:
 34          assert f.read() == artifact_text
 35  
 36  
 37  def test_download_artifact_from_absolute_uri_persists_data_to_specified_output_directory(tmp_path):
 38      artifact_file_name = "artifact.txt"
 39      artifact_text = "Sample artifact text"
 40      local_artifact_path = tmp_path.joinpath(artifact_file_name)
 41      local_artifact_path.write_text(artifact_text)
 42  
 43      logged_artifact_subdir = "logged_artifact"
 44      with mlflow.start_run():
 45          mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
 46          artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)
 47  
 48      artifact_output_path = tmp_path.joinpath("artifact_output")
 49      artifact_output_path.mkdir()
 50      _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
 51      assert logged_artifact_subdir in os.listdir(artifact_output_path)
 52      assert artifact_file_name in os.listdir(
 53          os.path.join(artifact_output_path, logged_artifact_subdir)
 54      )
 55      with open(os.path.join(artifact_output_path, logged_artifact_subdir, artifact_file_name)) as f:
 56          assert f.read() == artifact_text
 57  
 58  
 59  def test_download_artifact_with_special_characters_in_file_name_and_path(tmp_path):
 60      artifact_file_name = " artifact_ with! special  characters.txt"
 61      artifact_sub_dir = " path with ! special  characters"
 62      artifact_text = "Sample artifact text"
 63      local_sub_path = tmp_path.joinpath(artifact_sub_dir)
 64      local_sub_path.mkdir()
 65  
 66      local_artifact_path = os.path.join(local_sub_path, artifact_file_name)
 67      with open(local_artifact_path, "w") as out:
 68          out.write(artifact_text)
 69  
 70      logged_artifact_subdir = "logged_artifact"
 71      with mlflow.start_run():
 72          mlflow.log_artifact(local_path=local_artifact_path, artifact_path=logged_artifact_subdir)
 73          artifact_uri = mlflow.get_artifact_uri(artifact_path=logged_artifact_subdir)
 74  
 75      artifact_output_path = tmp_path.joinpath("artifact output path!")
 76      artifact_output_path.mkdir()
 77      _download_artifact_from_uri(artifact_uri=artifact_uri, output_path=artifact_output_path)
 78      assert logged_artifact_subdir in os.listdir(artifact_output_path)
 79      assert artifact_file_name in os.listdir(
 80          os.path.join(artifact_output_path, logged_artifact_subdir)
 81      )
 82      with open(os.path.join(artifact_output_path, logged_artifact_subdir, artifact_file_name)) as f:
 83          assert f.read() == artifact_text
 84  
 85  
 86  def test_download_artifact_invalid_uri_model_id():
 87      with pytest.raises(
 88          MlflowException,
 89          match="Invalid uri `m-dummy` is passed. Maybe you meant 'models:/m-dummy'?",
 90      ):
 91          _download_artifact_from_uri("m-dummy")
 92  
 93  
 94  def test_upload_artifacts_to_databricks():
 95      import_root = "mlflow.tracking.artifact_utils"
 96      with (
 97          mock.patch(import_root + "._download_artifact_from_uri") as download_mock,
 98          mock.patch(import_root + ".DbfsRestArtifactRepository") as repo_mock,
 99      ):
100          new_source = _upload_artifacts_to_databricks(
101              "dbfs:/original/sourcedir/",
102              "runid12345",
103              "databricks://tracking",
104              "databricks://registry:ws",
105          )
106          download_mock.assert_called_once_with("dbfs://tracking@databricks/original/sourcedir/", ANY)
107          repo_mock.assert_called_once_with(
108              "dbfs://registry:ws@databricks/databricks/mlflow/tmp-external-source/"
109          )
110          assert new_source == "dbfs:/databricks/mlflow/tmp-external-source/runid12345/sourcedir"
111  
112  
def test_upload_artifacts_to_databricks_no_run_id():
    """Check the URIs used when uploading artifacts to Databricks with `None` as the run ID.

    `uuid.uuid4` is patched to a fixed value, and the test expects that value's hex form
    to appear in the returned source URI in place of a run ID.
    """
    module = "mlflow.tracking.artifact_utils"
    uuid_hex = "4f746cdcc0374da2808917e81bb53323"
    with (
        mock.patch(f"{module}._download_artifact_from_uri") as mocked_download,
        mock.patch(f"{module}.DbfsRestArtifactRepository") as mocked_repo,
        mock.patch("uuid.uuid4", return_value=UUID(uuid_hex)),
    ):
        result = _upload_artifacts_to_databricks(
            "dbfs:/original/sourcedir/", None, "databricks://tracking:ws", "databricks://registry"
        )

        expected_download_uri = "dbfs://tracking:ws@databricks/original/sourcedir/"
        expected_repo_uri = "dbfs://registry@databricks/databricks/mlflow/tmp-external-source/"
        mocked_download.assert_called_once_with(expected_download_uri, ANY)
        mocked_repo.assert_called_once_with(expected_repo_uri)
        assert result == f"dbfs:/databricks/mlflow/tmp-external-source/{uuid_hex}/sourcedir"
133  
134  
def test_upload_artifacts_to_uri(tmp_path):
    """Upload a local file to a `runs:/` URI and read it back through a download.

    A run is created first (logging one metric), then the file is uploaded to the run's
    root artifact URI and the downloaded copy is compared with the original text.
    """
    file_name = "artifact.txt"
    expected_text = "Sample artifact text"
    source_file = tmp_path.joinpath(file_name)
    source_file.write_text(expected_text)

    with mlflow.start_run() as run:
        mlflow.log_metric("coolness", 1)

    run_root_uri = f"runs:/{run.info.run_id}/"
    _upload_artifact_to_uri(source_file, run_root_uri)

    download_root = _download_artifact_from_uri(run_root_uri)
    with open(os.path.join(download_root, file_name)) as f:
        assert f.read() == expected_text