# examples/databricks/multipart.py
  1  """
  2  Benchmark for multi-part upload and download of artifacts.
  3  """
  4  
  5  import hashlib
  6  import json
  7  import os
  8  import pathlib
  9  import tempfile
 10  from concurrent.futures import ThreadPoolExecutor, as_completed
 11  
 12  import pandas as pd
 13  import psutil
 14  from tqdm.auto import tqdm
 15  
 16  import mlflow
 17  from mlflow.environment_variables import (
 18      MLFLOW_ENABLE_MULTIPART_DOWNLOAD,
 19      MLFLOW_ENABLE_MULTIPART_UPLOAD,
 20  )
 21  from mlflow.utils.time import Timer
 22  
# Number of bytes in one gibibyte (2**30); used for sizing files and reporting memory.
GiB = 1024**3
 24  
 25  
 26  def show_system_info():
 27      svmem = psutil.virtual_memory()
 28      info = json.dumps(
 29          {
 30              "MLflow version": mlflow.__version__,
 31              "MPU enabled": MLFLOW_ENABLE_MULTIPART_DOWNLOAD.get(),
 32              "MPD enabled": MLFLOW_ENABLE_MULTIPART_UPLOAD.get(),
 33              "CPU count": psutil.cpu_count(),
 34              "Memory usage (total) [GiB]": svmem.total // GiB,
 35              "Memory used [GiB]": svmem.used // GiB,
 36              "Memory available [GiB]": svmem.available // GiB,
 37          },
 38          indent=2,
 39      )
 40      max_len = max(map(len, info.splitlines()))
 41      print("=" * max_len)
 42      print(info)
 43      print("=" * max_len)
 44  
 45  
 46  def md5_checksum(path):
 47      file_hash = hashlib.sha256()
 48      with open(path, "rb") as f:
 49          while chunk := f.read(1024**2):
 50              file_hash.update(chunk)
 51      return file_hash.hexdigest()
 52  
 53  
 54  def assert_checksum_equal(path1, path2):
 55      assert md5_checksum(path1) == md5_checksum(path2), f"Checksum mismatch for {path1} and {path2}"
 56  
 57  
 58  def yield_random_bytes(num_bytes):
 59      while num_bytes > 0:
 60          chunk_size = min(num_bytes, 1024**2)
 61          yield os.urandom(chunk_size)
 62          num_bytes -= chunk_size
 63  
 64  
 65  def generate_random_file(path, num_bytes):
 66      with open(path, "wb") as f:
 67          for chunk in yield_random_bytes(num_bytes):
 68              f.write(chunk)
 69  
 70  
def upload_and_download(file_size, num_files):
    """Benchmark one round trip: log ``num_files`` random files of
    ``file_size`` bytes as MLflow run artifacts, download them back, and
    verify integrity by checksum.

    Returns:
        Tuple of (upload seconds, download seconds) as measured by ``Timer``.

    Raises:
        AssertionError: if any downloaded file's checksum differs from its
            source (surfaced via ``fut.result()``).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = pathlib.Path(tmpdir)

        # Prepare files
        # Generate the source files concurrently; each is named "0".."n-1",
        # and `files` maps base name -> source path for later verification.
        src_dir = tmpdir / "src"
        src_dir.mkdir()
        files = {}
        with ThreadPoolExecutor() as pool:
            futures = []
            for i in range(num_files):
                f = src_dir / str(i)
                futures.append(pool.submit(generate_random_file, f, file_size))
                files[f.name] = f

            # result() re-raises any exception from the worker thread.
            for fut in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Generating files",
                colour="#FFA500",
            ):
                fut.result()

        # Upload
        # Only the log_artifacts call is timed, not run creation/teardown.
        with mlflow.start_run() as run:
            with Timer() as t_upload:
                mlflow.log_artifacts(str(src_dir))

        # Download
        dst_dir = tmpdir / "dst"
        dst_dir.mkdir()
        with Timer() as t_download:
            mlflow.artifacts.download_artifacts(
                artifact_uri=f"{run.info.artifact_uri}/", dst_path=dst_dir
            )

        # Verify checksums
        # Match each downloaded file back to its source by base file name.
        with ThreadPoolExecutor() as pool:
            futures = []
            for f in dst_dir.rglob("*"):
                if f.is_dir():
                    continue
                futures.append(pool.submit(assert_checksum_equal, f, files[f.name]))

            for fut in tqdm(
                as_completed(futures),
                total=len(futures),
                desc="Verifying checksums",
                colour="#FFA500",
            ):
                fut.result()

        return t_upload.elapsed, t_download.elapsed
124  
125  
def main():
    """Run the multipart benchmark several times and print summary stats."""
    # Uncomment the following lines if you're running this script outside of
    # Databricks using a personal access token:
    # mlflow.set_tracking_uri("databricks")
    # mlflow.set_experiment("/Users/<username>/benchmark")

    FILE_SIZE = 1 * GiB
    NUM_FILES = 2
    NUM_ATTEMPTS = 3

    show_system_info()
    stats = []
    for attempt in range(1, NUM_ATTEMPTS + 1):
        print(f"Attempt {attempt} / {NUM_ATTEMPTS}")
        stats.append(upload_and_download(FILE_SIZE, NUM_FILES))

    # Summarize count/mean/min/max of the timings as a markdown table.
    summary = pd.DataFrame(stats, columns=["upload [s]", "download [s]"])
    print(summary.aggregate(["count", "mean", "min", "max"]).to_markdown())
145  
146  
# Script entry point: run the benchmark only when executed directly.
if __name__ == "__main__":
    main()