# multipart.py
"""
Benchmark for multi-part upload and download of artifacts.
"""

import hashlib
import json
import os
import pathlib
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import psutil
from tqdm.auto import tqdm

import mlflow
from mlflow.environment_variables import (
    MLFLOW_ENABLE_MULTIPART_DOWNLOAD,
    MLFLOW_ENABLE_MULTIPART_UPLOAD,
)
from mlflow.utils.time import Timer

# Bytes per gibibyte; used for file sizing and memory reporting.
GiB = 1024**3


def show_system_info():
    """Print a banner describing the benchmark environment.

    Reports the MLflow version, whether multi-part upload (MPU) and
    multi-part download (MPD) are enabled, and the host's CPU/memory stats.
    """
    svmem = psutil.virtual_memory()
    info = json.dumps(
        {
            "MLflow version": mlflow.__version__,
            # Fixed: these two labels previously reported the *opposite*
            # environment variables (MPU showed the download flag and
            # vice versa).
            "MPU enabled": MLFLOW_ENABLE_MULTIPART_UPLOAD.get(),
            "MPD enabled": MLFLOW_ENABLE_MULTIPART_DOWNLOAD.get(),
            "CPU count": psutil.cpu_count(),
            "Memory usage (total) [GiB]": svmem.total // GiB,
            "Memory used [GiB]": svmem.used // GiB,
            "Memory available [GiB]": svmem.available // GiB,
        },
        indent=2,
    )
    # Frame the JSON with '=' rules as wide as its longest line.
    max_len = max(map(len, info.splitlines()))
    print("=" * max_len)
    print(info)
    print("=" * max_len)


def sha256_checksum(path):
    """Return the SHA-256 hex digest of the file at ``path``.

    Reads in 1 MiB chunks so arbitrarily large files can be hashed
    without loading them fully into memory.
    """
    file_hash = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(1024**2):
            file_hash.update(chunk)
    return file_hash.hexdigest()


# Backward-compatible alias: the function was historically (and misleadingly)
# named ``md5_checksum`` even though it computes SHA-256.
md5_checksum = sha256_checksum


def assert_checksum_equal(path1, path2):
    """Raise ``AssertionError`` unless the two files have equal SHA-256 digests."""
    assert sha256_checksum(path1) == sha256_checksum(path2), (
        f"Checksum mismatch for {path1} and {path2}"
    )


def yield_random_bytes(num_bytes):
    """Yield ``num_bytes`` of cryptographically random data in <= 1 MiB chunks."""
    while num_bytes > 0:
        chunk_size = min(num_bytes, 1024**2)
        yield os.urandom(chunk_size)
        num_bytes -= chunk_size


def generate_random_file(path, num_bytes):
    """Write ``num_bytes`` of random data to ``path``, streaming chunk by chunk."""
    with open(path, "wb") as f:
        for chunk in yield_random_bytes(num_bytes):
            f.write(chunk)


def upload_and_download(file_size, num_files):
    """Benchmark one round trip of artifact upload and download.

    Generates ``num_files`` random files of ``file_size`` bytes each, logs
    them as artifacts in a new MLflow run, downloads them back, and verifies
    every downloaded file's checksum against its source.

    Returns:
        A ``(upload_seconds, download_seconds)`` tuple.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = pathlib.Path(tmpdir)

        # Prepare files in parallel; file generation is I/O-bound.
        src_dir = tmpdir / "src"
        src_dir.mkdir()
        files = {}
        with ThreadPoolExecutor() as pool:
            futures = []
            for i in range(num_files):
                f = src_dir / str(i)
                futures.append(pool.submit(generate_random_file, f, file_size))
                files[f.name] = f

            for fut in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Generating files",
                colour="#FFA500",
            ):
                fut.result()  # propagate any generation error

        # Upload
        with mlflow.start_run() as run:
            with Timer() as t_upload:
                mlflow.log_artifacts(str(src_dir))

            # Download
            dst_dir = tmpdir / "dst"
            dst_dir.mkdir()
            with Timer() as t_download:
                mlflow.artifacts.download_artifacts(
                    artifact_uri=f"{run.info.artifact_uri}/", dst_path=dst_dir
                )

            # Verify checksums of every downloaded file against its source.
            with ThreadPoolExecutor() as pool:
                futures = []
                for f in dst_dir.rglob("*"):
                    if f.is_dir():
                        continue
                    futures.append(pool.submit(assert_checksum_equal, f, files[f.name]))

                for fut in tqdm(
                    as_completed(futures),
                    total=len(futures),
                    desc="Verifying checksums",
                    colour="#FFA500",
                ):
                    fut.result()  # propagate any checksum mismatch

        return t_upload.elapsed, t_download.elapsed


def main():
    # Uncomment the following lines if you're running this script outside of Databricks
    # using a personal access token:
    # mlflow.set_tracking_uri("databricks")
    # mlflow.set_experiment("/Users/<username>/benchmark")

    FILE_SIZE = 1 * GiB
    NUM_FILES = 2
    NUM_ATTEMPTS = 3

    show_system_info()
    stats = []
    for i in range(NUM_ATTEMPTS):
        print(f"Attempt {i + 1} / {NUM_ATTEMPTS}")
        stats.append(upload_and_download(FILE_SIZE, NUM_FILES))

    df = pd.DataFrame(stats, columns=["upload [s]", "download [s]"])
    # show mean, min, max in markdown table
    print(df.aggregate(["count", "mean", "min", "max"]).to_markdown())


if __name__ == "__main__":
    main()