"""test_mlflow_artifacts.py

Integration tests that launch an ``mlflow server`` subprocess configured with
``--artifacts-destination`` and exercise artifact upload, download and listing
through the mlflow-artifacts REST endpoints, the tracking REST API and the
MLflow Python client.
"""

# NOTE(review): the `cgi` module is deprecated (PEP 594) and removed in Python 3.13;
# the `cgi.parse_header` calls below will need a replacement (e.g. `email.message`)
# before these tests can run on 3.13+.
import cgi
import os
import pathlib
import subprocess
import tempfile
from contextlib import contextmanager
from io import BytesIO
from typing import NamedTuple

import pytest
import requests

import mlflow
from mlflow import MlflowClient
from mlflow.artifacts import download_artifacts
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
from mlflow.utils.os import is_windows

from tests.helper_functions import LOCALHOST, get_safe_port, kill_process_tree
from tests.tracking.integration_test_utils import _await_server_up_or_die


@contextmanager
def _launch_server(host, port, backend_store_uri, default_artifact_root, artifacts_destination):
    """Start ``mlflow server`` as a subprocess and yield the ``Popen`` handle.

    Blocks until ``_await_server_up_or_die(port)`` returns, and always kills the
    server's process tree when the context exits.
    """
    # Gunicorn options are only passed off Windows.
    extra_cmd = [] if is_windows() else ["--gunicorn-opts", "--log-level debug"]
    cmd = [
        "mlflow",
        "server",
        "--host",
        host,
        "--port",
        str(port),
        "--backend-store-uri",
        backend_store_uri,
        "--default-artifact-root",
        default_artifact_root,
        "--artifacts-destination",
        artifacts_destination,
        *extra_cmd,
    ]
    with subprocess.Popen(cmd) as process:
        try:
            _await_server_up_or_die(port)
            yield process
        finally:
            kill_process_tree(process.pid)


class ArtifactsServer(NamedTuple):
    """Connection details of the server started by the ``artifacts_server`` fixture."""

    # SQLite URI of the tracking backend store.
    backend_store_uri: str
    # mlflow-artifacts REST endpoint under `url`, used as the default artifact root.
    default_artifact_root: str
    # Local directory where the server stores proxied artifacts.
    artifacts_destination: str
    # Base URL of the server, e.g. "http://<LOCALHOST>:<port>".
    url: str
    # Handle to the running server subprocess.
    process: subprocess.Popen


@pytest.fixture(scope="module")
def artifacts_server():
    """Yield an ``ArtifactsServer`` shared by every test in this module.

    All state (SQLite database and artifact destination) lives in a temporary
    directory that is removed when the module's tests finish.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        port = get_safe_port()
        backend_store_uri = f"sqlite:///{os.path.join(tmpdir, 'mlruns.db')}"
        artifacts_destination = os.path.join(tmpdir, "mlartifacts")
        url = f"http://{LOCALHOST}:{port}"
        default_artifact_root = f"{url}/api/2.0/mlflow-artifacts/artifacts"
        # Initialize the database before launching the server process
        s = SqlAlchemyStore(backend_store_uri, default_artifact_root)
        s.engine.dispose()
        with _launch_server(
            LOCALHOST,
            port,
            backend_store_uri,
            default_artifact_root,
            # On Windows the destination is handed to the server as a file:/// URI.
            ("file:///" + artifacts_destination if is_windows() else artifacts_destination),
        ) as process:
            yield ArtifactsServer(
                backend_store_uri, default_artifact_root, artifacts_destination, url, process
            )


def read_file(path):
    """Return the text contents of ``path``."""
    with open(path) as f:
        return f.read()


def upload_file(path, url, headers=None):
    """PUT the bytes of ``path`` to ``url``; raise on a non-2xx response."""
    with open(path, "rb") as f:
        requests.put(url, data=f, headers=headers).raise_for_status()


def download_file(url, local_path, headers=None):
    """Stream ``url`` into ``local_path`` and return the response.

    Asserts that the response carries ``X-Content-Type-Options: nosniff`` and
    has both ``Content-Type`` and ``Content-Disposition`` headers. The returned
    response's body has already been consumed; only its headers are meant to
    be inspected by callers.
    """
    with requests.get(url, stream=True, headers=headers) as r:
        r.raise_for_status()
        assert r.headers["X-Content-Type-Options"] == "nosniff"
        assert "Content-Type" in r.headers
        assert "Content-Disposition" in r.headers
        with open(local_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        return r


def test_mlflow_artifacts_rest_apis(artifacts_server, tmp_path):
    default_artifact_root = artifacts_server.default_artifact_root
    artifacts_destination = artifacts_server.artifacts_destination

    # Upload artifacts
    file_a = tmp_path.joinpath("a.txt")
    file_a.write_text("0")
    upload_file(file_a, f"{default_artifact_root}/a.txt")
    assert os.path.exists(os.path.join(artifacts_destination, "a.txt"))
    assert read_file(os.path.join(artifacts_destination, "a.txt")) == "0"

    file_b = tmp_path.joinpath("b.txt")
    file_b.write_text("1")
    upload_file(file_b, f"{default_artifact_root}/dir/b.txt")
    # Must check existence; asserting the joined path string alone is always true.
    assert os.path.exists(os.path.join(artifacts_destination, "dir", "b.txt"))
    assert read_file(os.path.join(artifacts_destination, "dir", "b.txt")) == "1"

    # Download artifacts
    local_dir = tmp_path.joinpath("folder")
    local_dir.mkdir()
    local_path_a = local_dir.joinpath("a.txt")
    download_file(f"{default_artifact_root}/a.txt", local_path_a)
    assert read_file(local_path_a) == "0"

    local_path_b = local_dir.joinpath("b.txt")
    download_file(f"{default_artifact_root}/dir/b.txt", local_path_b)
    assert read_file(local_path_b) == "1"

    # List artifacts
    resp = requests.get(default_artifact_root)
    resp.raise_for_status()
    assert resp.json() == {
        "files": [
            {"path": "a.txt", "is_dir": False, "file_size": 1},
            {"path": "dir", "is_dir": True},
        ]
    }
    resp = requests.get(default_artifact_root, params={"path": "dir"})
    resp.raise_for_status()
    assert resp.json() == {"files": [{"path": "b.txt", "is_dir": False, "file_size": 1}]}


def test_log_artifact(artifacts_server, tmp_path):
    url = artifacts_server.url
    artifacts_destination = artifacts_server.artifacts_destination
    mlflow.set_tracking_uri(url)

    # Use a distinct name rather than rebinding the `tmp_path` fixture.
    artifact_file = tmp_path.joinpath("a.txt")
    artifact_file.write_text("0")

    with mlflow.start_run() as run:
        mlflow.log_artifact(artifact_file)

    # NOTE(review): hard-codes the default experiment ID; this relies on no earlier
    # test in the module having called `mlflow.set_experiment`.
    experiment_id = "0"
    run_artifact_root = os.path.join(
        artifacts_destination, experiment_id, run.info.run_id, "artifacts"
    )
    dest_path = os.path.join(run_artifact_root, artifact_file.name)
    assert os.path.exists(dest_path)
    assert read_file(dest_path) == "0"

    with mlflow.start_run() as run:
        mlflow.log_artifact(artifact_file, artifact_path="artifact_path")

    run_artifact_root = os.path.join(
        artifacts_destination, experiment_id, run.info.run_id, "artifacts"
    )
    dest_path = os.path.join(run_artifact_root, "artifact_path", artifact_file.name)
    assert os.path.exists(dest_path)
    assert read_file(dest_path) == "0"


def test_log_artifacts(artifacts_server, tmp_path):
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    tmp_path.joinpath("a.txt").write_text("0")
    d = tmp_path.joinpath("dir")
    d.mkdir()
    d.joinpath("b.txt").write_text("1")

    with mlflow.start_run() as run:
        mlflow.log_artifacts(tmp_path)

    client = MlflowClient()
    artifacts = [a.path for a in client.list_artifacts(run.info.run_id)]
    assert sorted(artifacts) == ["a.txt", "dir"]
    artifacts = [a.path for a in client.list_artifacts(run.info.run_id, "dir")]
    assert artifacts == ["dir/b.txt"]

    # With `artifact_path`
    with mlflow.start_run() as run:
        mlflow.log_artifacts(tmp_path, artifact_path="artifact_path")

    artifacts = [a.path for a in client.list_artifacts(run.info.run_id)]
    assert artifacts == ["artifact_path"]
    artifacts = [a.path for a in client.list_artifacts(run.info.run_id, "artifact_path")]
    assert sorted(artifacts) == ["artifact_path/a.txt", "artifact_path/dir"]
    artifacts = [a.path for a in client.list_artifacts(run.info.run_id, "artifact_path/dir")]
    assert artifacts == ["artifact_path/dir/b.txt"]


def test_list_artifacts(artifacts_server, tmp_path):
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    tmp_path_a = tmp_path.joinpath("a.txt")
    tmp_path_a.write_text("0")
    tmp_path_b = tmp_path.joinpath("b.txt")
    tmp_path_b.write_text("1")
    client = MlflowClient()
    with mlflow.start_run() as run:
        assert client.list_artifacts(run.info.run_id) == []
        mlflow.log_artifact(tmp_path_a)
        mlflow.log_artifact(tmp_path_b, "dir")

    artifacts = [a.path for a in client.list_artifacts(run.info.run_id)]
    assert sorted(artifacts) == ["a.txt", "dir"]
    artifacts = [a.path for a in client.list_artifacts(run.info.run_id, "dir")]
    assert artifacts == ["dir/b.txt"]


def test_download_artifacts(artifacts_server, tmp_path):
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    tmp_path_a = tmp_path.joinpath("a.txt")
    tmp_path_a.write_text("0")
    tmp_path_b = tmp_path.joinpath("b.txt")
    tmp_path_b.write_text("1")
    with mlflow.start_run() as run:
        mlflow.log_artifact(tmp_path_a)
        mlflow.log_artifact(tmp_path_b, "dir")

    dest_path = download_artifacts(run_id=run.info.run_id, artifact_path="")
    assert sorted(os.listdir(dest_path)) == ["a.txt", "dir"]
    assert read_file(os.path.join(dest_path, "a.txt")) == "0"
    dest_path = download_artifacts(run_id=run.info.run_id, artifact_path="dir")
    assert os.listdir(dest_path) == ["b.txt"]
    assert read_file(os.path.join(dest_path, "b.txt")) == "1"


def is_github_actions():
    """Return True when the ``GITHUB_ACTIONS`` environment variable is set."""
    return "GITHUB_ACTIONS" in os.environ


@pytest.mark.skipif(is_windows(), reason="This example doesn't work on Windows")
def test_mlflow_artifacts_example(tmp_path):
    root = pathlib.Path(mlflow.__file__).parents[1]
    # On GitHub Actions, remove generated images to save disk space
    rmi_option = "--rmi all" if is_github_actions() else ""
    # The ERR trap records a failure without aborting, so logs are always printed and
    # containers always torn down; the final `test` then fails the script if needed.
    cmd = f"""
err=0
trap 'err=1' ERR
./build.sh
docker compose run -v ${{PWD}}/example.py:/app/example.py client python example.py
docker compose logs
docker compose down {rmi_option} --volumes --remove-orphans
test $err = 0
"""
    script_path = tmp_path.joinpath("test.sh")
    script_path.write_text(cmd)
    subprocess.run(
        ["bash", script_path],
        check=True,
        cwd=os.path.join(root, "examples", "mlflow_artifacts"),
    )


def test_rest_tracking_api_list_artifacts_with_proxied_artifacts(artifacts_server, tmp_path):
    def list_artifacts_via_rest_api(url, run_id, path=None):
        # Only send `path` when it is truthy, matching a root-level listing otherwise.
        params = {"run_id": run_id}
        if path:
            params["path"] = path
        resp = requests.get(url, params=params)
        resp.raise_for_status()
        return resp.json()

    url = artifacts_server.url
    mlflow.set_tracking_uri(url)
    api = f"{url}/api/2.0/mlflow/artifacts/list"

    tmp_path_a = tmp_path.joinpath("a.txt")
    tmp_path_a.write_text("0")
    tmp_path_b = tmp_path.joinpath("b.txt")
    tmp_path_b.write_text("1")
    mlflow.set_experiment("rest_list_api_test")
    with mlflow.start_run() as run:
        mlflow.log_artifact(tmp_path_a)
        mlflow.log_artifact(tmp_path_b, "dir")

    list_artifacts_response = list_artifacts_via_rest_api(url=api, run_id=run.info.run_id)
    assert list_artifacts_response.get("files") == [
        {"path": "a.txt", "is_dir": False, "file_size": 1},
        {"path": "dir", "is_dir": True},
    ]
    assert list_artifacts_response.get("root_uri") == run.info.artifact_uri

    nested_list_artifacts_response = list_artifacts_via_rest_api(
        url=api, run_id=run.info.run_id, path="dir"
    )
    assert nested_list_artifacts_response.get("files") == [
        {"path": "dir/b.txt", "is_dir": False, "file_size": 1},
    ]
    # Verify the nested response's own root_uri (not the top-level response again).
    assert nested_list_artifacts_response.get("root_uri") == run.info.artifact_uri


def test_rest_get_artifact_api_proxied_with_artifacts(artifacts_server, tmp_path):
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)
    tmp_path_a = tmp_path.joinpath("a.txt")
    tmp_path_a.write_text("abcdefg")

    mlflow.set_experiment("rest_get_artifact_api_test")
    with mlflow.start_run() as run:
        mlflow.log_artifact(tmp_path_a)

    get_artifact_response = requests.get(
        url=f"{url}/get-artifact", params={"run_id": run.info.run_id, "path": "a.txt"}
    )
    get_artifact_response.raise_for_status()
    assert get_artifact_response.text == "abcdefg"


def test_rest_get_model_version_artifact_api_proxied_artifact_root(artifacts_server):
    url = artifacts_server.url
    # Write the artifact directly into the server's destination directory.
    artifact_file = pathlib.Path(artifacts_server.artifacts_destination, "a.txt")
    artifact_file.parent.mkdir(exist_ok=True, parents=True)
    artifact_file.write_text("abcdefg")

    name = "GetModelVersionTest"
    mlflow_client = MlflowClient(artifacts_server.backend_store_uri)
    mlflow_client.create_registered_model(name)
    # An artifact root with scheme http, https, or mlflow-artifacts is a proxied artifact root
    mlflow_client.create_model_version(name, "mlflow-artifacts:", 1)

    get_model_version_artifact_response = requests.get(
        url=f"{url}/model-versions/get-artifact",
        params={"name": name, "version": "1", "path": "a.txt"},
    )
    get_model_version_artifact_response.raise_for_status()
    assert get_model_version_artifact_response.text == "abcdefg"


@pytest.mark.parametrize(
    ("filename", "expected_mime_type"),
    [
        ("a.txt", "text/plain"),
        ("b.pkl", "application/octet-stream"),
        ("c.png", "image/png"),
        ("d.pdf", "application/pdf"),
        ("MLmodel", "text/plain"),
        ("mlproject", "text/plain"),
    ],
)
def test_mime_type_for_download_artifacts_api(
    artifacts_server, tmp_path, filename, expected_mime_type
):
    default_artifact_root = artifacts_server.default_artifact_root
    url = artifacts_server.url
    test_file = tmp_path.joinpath(filename)
    test_file.touch()
    upload_file(test_file, f"{default_artifact_root}/dir/{filename}")
    download_response = download_file(f"{default_artifact_root}/dir/{filename}", test_file)

    _, params = cgi.parse_header(download_response.headers["Content-Disposition"])
    assert params["filename"] == filename
    assert download_response.headers["Content-Type"] == expected_mime_type

    # The same file served through the tracking server's /get-artifact endpoint.
    mlflow.set_tracking_uri(url)
    with mlflow.start_run() as run:
        mlflow.log_artifact(test_file)
    artifact_response = requests.get(
        url=f"{url}/get-artifact", params={"run_id": run.info.run_id, "path": filename}
    )
    artifact_response.raise_for_status()
    _, params = cgi.parse_header(artifact_response.headers["Content-Disposition"])
    assert params["filename"] == filename
    assert artifact_response.headers["Content-Type"] == expected_mime_type
    assert artifact_response.headers["X-Content-Type-Options"] == "nosniff"


def test_rest_get_artifact_api_log_image(artifacts_server):
    url = artifacts_server.url
    mlflow.set_tracking_uri(url)

    import numpy as np
    from PIL import Image

    image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

    with mlflow.start_run() as run:
        mlflow.log_image(image, key="dog", step=20, timestamp=100, synchronous=True)

    artifact_list_response = requests.get(
        url=f"{url}/ajax-api/2.0/mlflow/artifacts/list",
        params={"path": "images", "run_id": run.info.run_id},
    )
    artifact_list_response.raise_for_status()

    files = artifact_list_response.json()["files"]
    # Guard against the loop below passing vacuously when nothing was listed.
    assert files
    for file in files:
        path = file["path"]
        get_artifact_response = requests.get(
            url=f"{url}/get-artifact", params={"run_id": run.info.run_id, "path": path}
        )
        get_artifact_response.raise_for_status()
        assert (
            "attachment; filename=dog+step+20+timestamp+100"
            in get_artifact_response.headers["Content-Disposition"]
        )
        if path.endswith("png"):
            loaded_image = np.asarray(
                Image.open(BytesIO(get_artifact_response.content)), dtype=np.uint8
            )
            np.testing.assert_array_equal(loaded_image, image)


@pytest.mark.parametrize(
    ("filename", "requested_mime_type", "responded_mime_type"),
    [
        ("b.pkl", "text/html", "application/octet-stream"),
        ("c.png", "text/html", "image/png"),
        ("d.pdf", "text/html", "application/pdf"),
    ],
)
def test_server_overrides_requested_mime_type(
    artifacts_server, tmp_path, filename, requested_mime_type, responded_mime_type
):
    default_artifact_root = artifacts_server.default_artifact_root
    test_file = tmp_path.joinpath(filename)
    test_file.touch()
    upload_file(
        test_file,
        f"{default_artifact_root}/dir/{filename}",
    )
    # Ask for text/html; the server is expected to ignore the Accept header.
    download_response = download_file(
        f"{default_artifact_root}/dir/{filename}",
        test_file,
        headers={"Accept": requested_mime_type},
    )

    _, params = cgi.parse_header(download_response.headers["Content-Disposition"])
    assert params["filename"] == filename
    assert download_response.headers["Content-Type"] == responded_mime_type