/ tests / tracking / test_mlflow_artifacts.py
test_mlflow_artifacts.py
  1  import cgi
  2  import os
  3  import pathlib
  4  import subprocess
  5  import tempfile
  6  from contextlib import contextmanager
  7  from io import BytesIO
  8  from typing import NamedTuple
  9  
 10  import pytest
 11  import requests
 12  
 13  import mlflow
 14  from mlflow import MlflowClient
 15  from mlflow.artifacts import download_artifacts
 16  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
 17  from mlflow.utils.os import is_windows
 18  
 19  from tests.helper_functions import LOCALHOST, get_safe_port, kill_process_tree
 20  from tests.tracking.integration_test_utils import _await_server_up_or_die
 21  
 22  
 23  @contextmanager
 24  def _launch_server(host, port, backend_store_uri, default_artifact_root, artifacts_destination):
 25      extra_cmd = [] if is_windows() else ["--gunicorn-opts", "--log-level debug"]
 26      cmd = [
 27          "mlflow",
 28          "server",
 29          "--host",
 30          host,
 31          "--port",
 32          str(port),
 33          "--backend-store-uri",
 34          backend_store_uri,
 35          "--default-artifact-root",
 36          default_artifact_root,
 37          "--artifacts-destination",
 38          artifacts_destination,
 39          *extra_cmd,
 40      ]
 41      with subprocess.Popen(cmd) as process:
 42          try:
 43              _await_server_up_or_die(port)
 44              yield process
 45          finally:
 46              kill_process_tree(process.pid)
 47  
 48  
class ArtifactsServer(NamedTuple):
    """Details of the MLflow server started by the ``artifacts_server`` fixture."""

    # URI of the backend store the server was launched with
    backend_store_uri: str
    # Default artifact root the server was launched with
    default_artifact_root: str
    # Local directory where the server stores proxied artifacts
    artifacts_destination: str
    # Base HTTP URL of the server
    url: str
    # Handle of the running server subprocess
    process: subprocess.Popen
 55  
 56  
 57  @pytest.fixture(scope="module")
 58  def artifacts_server():
 59      with tempfile.TemporaryDirectory() as tmpdir:
 60          port = get_safe_port()
 61          backend_store_uri = f"sqlite:///{os.path.join(tmpdir, 'mlruns.db')}"
 62          artifacts_destination = os.path.join(tmpdir, "mlartifacts")
 63          url = f"http://{LOCALHOST}:{port}"
 64          default_artifact_root = f"{url}/api/2.0/mlflow-artifacts/artifacts"
 65          # Initialize the database before launching the server process
 66          s = SqlAlchemyStore(backend_store_uri, default_artifact_root)
 67          s.engine.dispose()
 68          with _launch_server(
 69              LOCALHOST,
 70              port,
 71              backend_store_uri,
 72              default_artifact_root,
 73              ("file:///" + artifacts_destination if is_windows() else artifacts_destination),
 74          ) as process:
 75              yield ArtifactsServer(
 76                  backend_store_uri, default_artifact_root, artifacts_destination, url, process
 77              )
 78  
 79  
 80  def read_file(path):
 81      with open(path) as f:
 82          return f.read()
 83  
 84  
 85  def upload_file(path, url, headers=None):
 86      with open(path, "rb") as f:
 87          requests.put(url, data=f, headers=headers).raise_for_status()
 88  
 89  
 90  def download_file(url, local_path, headers=None):
 91      with requests.get(url, stream=True, headers=headers) as r:
 92          r.raise_for_status()
 93          assert r.headers["X-Content-Type-Options"] == "nosniff"
 94          assert "Content-Type" in r.headers
 95          assert "Content-Disposition" in r.headers
 96          with open(local_path, "wb") as f:
 97              for chunk in r.iter_content(chunk_size=8192):
 98                  f.write(chunk)
 99          return r
100  
101  
def test_mlflow_artifacts_rest_apis(artifacts_server, tmp_path):
    """Exercise upload, download and listing through the mlflow-artifacts REST API."""
    default_artifact_root = artifacts_server.default_artifact_root
    artifacts_destination = artifacts_server.artifacts_destination

    # Upload artifacts
    file_a = tmp_path.joinpath("a.txt")
    file_a.write_text("0")
    upload_file(file_a, f"{default_artifact_root}/a.txt")
    stored_a = os.path.join(artifacts_destination, "a.txt")
    assert os.path.exists(stored_a)
    assert read_file(stored_a) == "0"

    file_b = tmp_path.joinpath("b.txt")
    file_b.write_text("1")
    upload_file(file_b, f"{default_artifact_root}/dir/b.txt")
    stored_b = os.path.join(artifacts_destination, "dir", "b.txt")
    # Check existence explicitly: asserting the joined path string itself is
    # always truthy and verifies nothing.
    assert os.path.exists(stored_b)
    assert read_file(stored_b) == "1"

    # Download artifacts
    local_dir = tmp_path.joinpath("folder")
    local_dir.mkdir()
    local_path_a = local_dir.joinpath("a.txt")
    download_file(f"{default_artifact_root}/a.txt", local_path_a)
    assert read_file(local_path_a) == "0"

    local_path_b = local_dir.joinpath("b.txt")
    download_file(f"{default_artifact_root}/dir/b.txt", local_path_b)
    assert read_file(local_path_b) == "1"

    # List artifacts
    resp = requests.get(default_artifact_root)
    resp.raise_for_status()
    assert resp.json() == {
        "files": [
            {"path": "a.txt", "is_dir": False, "file_size": 1},
            {"path": "dir", "is_dir": True},
        ]
    }
    resp = requests.get(default_artifact_root, params={"path": "dir"})
    resp.raise_for_status()
    assert resp.json() == {"files": [{"path": "b.txt", "is_dir": False, "file_size": 1}]}
140  
141  
def test_log_artifact(artifacts_server, tmp_path):
    """Check that ``mlflow.log_artifact`` stores files under the artifacts destination."""
    url = artifacts_server.url
    artifacts_destination = artifacts_server.artifacts_destination
    mlflow.set_tracking_uri(url)

    # Use a distinct name rather than rebinding the ``tmp_path`` fixture.
    local_file = tmp_path.joinpath("a.txt")
    local_file.write_text("0")

    def stored_path(run, *relative_parts):
        # Take the experiment id from the run instead of hard-coding "0", so the
        # test does not depend on which experiment happens to be active.
        return os.path.join(
            artifacts_destination,
            run.info.experiment_id,
            run.info.run_id,
            "artifacts",
            *relative_parts,
        )

    with mlflow.start_run() as run:
        mlflow.log_artifact(local_file)

    dest_path = stored_path(run, local_file.name)
    assert os.path.exists(dest_path)
    assert read_file(dest_path) == "0"

    with mlflow.start_run() as run:
        mlflow.log_artifact(local_file, artifact_path="artifact_path")

    dest_path = stored_path(run, "artifact_path", local_file.name)
    assert os.path.exists(dest_path)
    assert read_file(dest_path) == "0"
170  
171  
def test_log_artifacts(artifacts_server, tmp_path):
    """Check that ``mlflow.log_artifacts`` uploads a directory tree, with and without a prefix."""
    mlflow.set_tracking_uri(artifacts_server.url)

    # Local layout: a.txt and dir/b.txt
    (tmp_path / "a.txt").write_text("0")
    subdir = tmp_path / "dir"
    subdir.mkdir()
    (subdir / "b.txt").write_text("1")

    with mlflow.start_run() as run:
        mlflow.log_artifacts(tmp_path)

    client = MlflowClient()
    run_id = run.info.run_id
    top_level = [entry.path for entry in client.list_artifacts(run_id)]
    assert sorted(top_level) == ["a.txt", "dir"]
    nested = [entry.path for entry in client.list_artifacts(run_id, "dir")]
    assert nested == ["dir/b.txt"]

    # With `artifact_path`
    with mlflow.start_run() as run:
        mlflow.log_artifacts(tmp_path, artifact_path="artifact_path")

    run_id = run.info.run_id
    top_level = [entry.path for entry in client.list_artifacts(run_id)]
    assert top_level == ["artifact_path"]
    under_prefix = [entry.path for entry in client.list_artifacts(run_id, "artifact_path")]
    assert sorted(under_prefix) == ["artifact_path/a.txt", "artifact_path/dir"]
    under_dir = [entry.path for entry in client.list_artifacts(run_id, "artifact_path/dir")]
    assert under_dir == ["artifact_path/dir/b.txt"]
200  
201  
def test_list_artifacts(artifacts_server, tmp_path):
    """Check ``MlflowClient.list_artifacts`` before and after logging artifacts."""
    mlflow.set_tracking_uri(artifacts_server.url)

    file_a = tmp_path / "a.txt"
    file_a.write_text("0")
    file_b = tmp_path / "b.txt"
    file_b.write_text("1")

    client = MlflowClient()
    with mlflow.start_run() as run:
        # Nothing has been logged yet for this run.
        assert client.list_artifacts(run.info.run_id) == []
        mlflow.log_artifact(file_a)
        mlflow.log_artifact(file_b, "dir")

    run_id = run.info.run_id
    root_paths = [entry.path for entry in client.list_artifacts(run_id)]
    assert sorted(root_paths) == ["a.txt", "dir"]
    dir_paths = [entry.path for entry in client.list_artifacts(run_id, "dir")]
    assert dir_paths == ["dir/b.txt"]
220  
221  
def test_download_artifacts(artifacts_server, tmp_path):
    """Check ``download_artifacts`` for the run root and for a sub-directory."""
    mlflow.set_tracking_uri(artifacts_server.url)

    file_a = tmp_path / "a.txt"
    file_a.write_text("0")
    file_b = tmp_path / "b.txt"
    file_b.write_text("1")

    with mlflow.start_run() as run:
        mlflow.log_artifact(file_a)
        mlflow.log_artifact(file_b, "dir")

    run_id = run.info.run_id

    # Download everything logged to the run.
    root_download = download_artifacts(run_id=run_id, artifact_path="")
    assert sorted(os.listdir(root_download)) == ["a.txt", "dir"]
    assert read_file(os.path.join(root_download, "a.txt")) == "0"

    # Download only the "dir" sub-directory.
    dir_download = download_artifacts(run_id=run_id, artifact_path="dir")
    assert os.listdir(dir_download) == ["b.txt"]
    assert read_file(os.path.join(dir_download, "b.txt")) == "1"
240  
241  
def is_github_actions():
    """Return True when the ``GITHUB_ACTIONS`` environment variable is set."""
    return os.environ.get("GITHUB_ACTIONS") is not None
244  
245  
@pytest.mark.skipif(is_windows(), reason="This example doesn't work on Windows")
def test_mlflow_artifacts_example(tmp_path):
    """Run the ``examples/mlflow_artifacts`` docker compose example via a bash script.

    The script records any failing command through an ERR trap, still runs
    the log dump and teardown, and exits non-zero at the end if anything
    failed.
    """
    root = pathlib.Path(mlflow.__file__).parents[1]
    # On GitHub Actions, remove generated images to save disk space
    if is_github_actions():
        rmi_option = "--rmi all"
    else:
        rmi_option = ""
    cmd = f"""
err=0
trap 'err=1' ERR
./build.sh
docker compose run -v ${{PWD}}/example.py:/app/example.py client python example.py
docker compose logs
docker compose down {rmi_option} --volumes --remove-orphans
test $err = 0
"""
    script_path = tmp_path.joinpath("test.sh")
    script_path.write_text(cmd)
    example_dir = os.path.join(root, "examples", "mlflow_artifacts")
    subprocess.run(["bash", script_path], check=True, cwd=example_dir)
267  
268  
def test_rest_tracking_api_list_artifacts_with_proxied_artifacts(artifacts_server, tmp_path):
    """Check the tracking ``artifacts/list`` REST endpoint against proxied artifacts."""

    def list_artifacts_via_rest_api(url, run_id, path=None):
        # Only send the "path" parameter when a non-empty path is requested.
        params = {"run_id": run_id}
        if path:
            params["path"] = path
        resp = requests.get(url, params=params)
        resp.raise_for_status()
        return resp.json()

    url = artifacts_server.url
    mlflow.set_tracking_uri(url)
    api = f"{url}/api/2.0/mlflow/artifacts/list"

    tmp_path_a = tmp_path.joinpath("a.txt")
    tmp_path_a.write_text("0")
    tmp_path_b = tmp_path.joinpath("b.txt")
    tmp_path_b.write_text("1")
    mlflow.set_experiment("rest_list_api_test")
    with mlflow.start_run() as run:
        mlflow.log_artifact(tmp_path_a)
        mlflow.log_artifact(tmp_path_b, "dir")

    list_artifacts_response = list_artifacts_via_rest_api(url=api, run_id=run.info.run_id)
    assert list_artifacts_response.get("files") == [
        {"path": "a.txt", "is_dir": False, "file_size": 1},
        {"path": "dir", "is_dir": True},
    ]
    assert list_artifacts_response.get("root_uri") == run.info.artifact_uri

    nested_list_artifacts_response = list_artifacts_via_rest_api(
        url=api, run_id=run.info.run_id, path="dir"
    )
    assert nested_list_artifacts_response.get("files") == [
        {"path": "dir/b.txt", "is_dir": False, "file_size": 1},
    ]
    # Check the nested response here; re-checking the top-level response
    # would leave the nested root_uri unverified.
    assert nested_list_artifacts_response.get("root_uri") == run.info.artifact_uri
305  
306  
def test_rest_get_artifact_api_proxied_with_artifacts(artifacts_server, tmp_path):
    """Check that ``/get-artifact`` returns the content of a logged artifact."""
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    source_file = tmp_path / "a.txt"
    source_file.write_text("abcdefg")

    mlflow.set_experiment("rest_get_artifact_api_test")
    with mlflow.start_run() as run:
        mlflow.log_artifact(source_file)

    query = {"run_id": run.info.run_id, "path": "a.txt"}
    response = requests.get(url=f"{url}/get-artifact", params=query)
    response.raise_for_status()
    assert response.text == "abcdefg"
322  
323  
def test_rest_get_model_version_artifact_api_proxied_artifact_root(artifacts_server):
    """Check ``/model-versions/get-artifact`` for a model version with a proxied source."""
    url = artifacts_server.url

    # Place the artifact directly in the server's artifacts destination.
    artifact_file = pathlib.Path(artifacts_server.artifacts_destination, "a.txt")
    artifact_file.parent.mkdir(exist_ok=True, parents=True)
    artifact_file.write_text("abcdefg")

    name = "GetModelVersionTest"
    registry_client = MlflowClient(artifacts_server.backend_store_uri)
    registry_client.create_registered_model(name)
    # An artifact root with scheme http, https, or mlflow-artifacts is a proxied artifact root
    registry_client.create_model_version(name, "mlflow-artifacts:", 1)

    query = {"name": name, "version": "1", "path": "a.txt"}
    response = requests.get(url=f"{url}/model-versions/get-artifact", params=query)
    response.raise_for_status()
    assert response.text == "abcdefg"
342  
343  
@pytest.mark.parametrize(
    ("filename", "expected_mime_type"),
    [
        ("a.txt", "text/plain"),
        ("b.pkl", "application/octet-stream"),
        ("c.png", "image/png"),
        ("d.pdf", "application/pdf"),
        ("MLmodel", "text/plain"),
        ("mlproject", "text/plain"),
    ],
)
def test_mime_type_for_download_artifacts_api(
    artifacts_server, tmp_path, filename, expected_mime_type
):
    """Check filename and MIME type headers on both artifact download endpoints."""
    default_artifact_root = artifacts_server.default_artifact_root
    url = artifacts_server.url
    artifact_url = f"{default_artifact_root}/dir/{filename}"

    # Round-trip an empty file through the mlflow-artifacts endpoint.
    test_file = tmp_path.joinpath(filename)
    test_file.touch()
    upload_file(test_file, artifact_url)
    download_response = download_file(artifact_url, test_file)

    download_headers = download_response.headers
    _, disposition_params = cgi.parse_header(download_headers["Content-Disposition"])
    assert disposition_params["filename"] == filename
    assert download_headers["Content-Type"] == expected_mime_type

    # Same checks through the tracking server's /get-artifact endpoint.
    mlflow.set_tracking_uri(url)
    with mlflow.start_run() as run:
        mlflow.log_artifact(test_file)
    artifact_response = requests.get(
        url=f"{url}/get-artifact",
        params={"run_id": run.info.run_id, "path": filename},
    )
    artifact_response.raise_for_status()

    artifact_headers = artifact_response.headers
    _, disposition_params = cgi.parse_header(artifact_headers["Content-Disposition"])
    assert disposition_params["filename"] == filename
    assert artifact_headers["Content-Type"] == expected_mime_type
    assert artifact_headers["X-Content-Type-Options"] == "nosniff"
380  
381  
def test_rest_get_artifact_api_log_image(artifacts_server):
    """Check that images logged with ``mlflow.log_image`` are served by ``/get-artifact``."""
    import numpy as np
    from PIL import Image

    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run() as run:
        mlflow.log_image(image, key="dog", step=20, timestamp=100, synchronous=True)

    run_id = run.info.run_id
    artifact_list_response = requests.get(
        url=f"{url}/ajax-api/2.0/mlflow/artifacts/list",
        params={"path": "images", "run_id": run_id},
    )
    artifact_list_response.raise_for_status()

    listed_paths = [entry["path"] for entry in artifact_list_response.json()["files"]]
    for path in listed_paths:
        get_artifact_response = requests.get(
            url=f"{url}/get-artifact", params={"run_id": run_id, "path": path}
        )
        get_artifact_response.raise_for_status()
        disposition = get_artifact_response.headers["Content-Disposition"]
        assert "attachment; filename=dog+step+20+timestamp+100" in disposition
        if not path.endswith("png"):
            continue
        # Files ending in "png" must decode back to the logged pixel array.
        decoded = Image.open(BytesIO(get_artifact_response.content))
        loaded_image = np.asarray(decoded, dtype=np.uint8)
        np.testing.assert_array_equal(loaded_image, image)
414              np.testing.assert_array_equal(loaded_image, image)
415  
416  
@pytest.mark.parametrize(
    ("filename", "requested_mime_type", "responded_mime_type"),
    [
        ("b.pkl", "text/html", "application/octet-stream"),
        ("c.png", "text/html", "image/png"),
        ("d.pdf", "text/html", "application/pdf"),
    ],
)
def test_server_overrides_requested_mime_type(
    artifacts_server, tmp_path, filename, requested_mime_type, responded_mime_type
):
    """Check that the response Content-Type is not driven by the request's Accept header."""
    artifact_url = f"{artifacts_server.default_artifact_root}/dir/{filename}"

    test_file = tmp_path.joinpath(filename)
    test_file.touch()
    upload_file(test_file, artifact_url)

    accept_header = {"Accept": requested_mime_type}
    download_response = download_file(artifact_url, test_file, headers=accept_header)

    response_headers = download_response.headers
    _, disposition_params = cgi.parse_header(response_headers["Content-Disposition"])
    assert disposition_params["filename"] == filename
    assert response_headers["Content-Type"] == responded_mime_type