# install.py
1 """ 2 Install binary tools for MLflow development. 3 """ 4 5 # ruff: noqa: T201 6 import argparse 7 import gzip 8 import hashlib 9 import http.client 10 import json 11 import platform 12 import shutil 13 import subprocess 14 import tarfile 15 import tempfile 16 import time 17 import urllib.request 18 from dataclasses import dataclass 19 from pathlib import Path 20 from typing import Literal 21 from urllib.error import HTTPError, URLError 22 23 INSTALLED_VERSIONS_FILE = ".installed_versions.json" 24 25 # Type definitions 26 PlatformKey = tuple[ 27 Literal["linux", "darwin"], 28 Literal["x86_64", "arm64"], 29 ] 30 ExtractType = Literal["gzip", "tar", "binary"] 31 32 33 @dataclass 34 class Tool: 35 name: str 36 version: str 37 assets: dict[PlatformKey, tuple[str, str]] # platform -> (url, sha256) 38 version_args: list[str] | None = None # Custom version check args (default: ["--version"]) 39 40 def get_asset(self, platform_key: PlatformKey) -> tuple[str, str] | None: 41 return self.assets.get(platform_key) 42 43 def get_version_args(self) -> list[str]: 44 """Get version check arguments, defaulting to --version.""" 45 return self.version_args or ["--version"] 46 47 def get_extract_type(self, url: str) -> ExtractType: 48 """Infer extract type from URL file extension.""" 49 if url.endswith(".gz") and not url.endswith(".tar.gz"): 50 return "gzip" 51 elif url.endswith((".tar.gz", ".tgz")): 52 return "tar" 53 elif url.endswith(".exe") or ("/" in url and not url.split("/")[-1].count(".")): 54 # Windows executables or files without extensions (plain binaries) 55 return "binary" 56 else: 57 # Default to tar for unknown extensions 58 return "tar" 59 60 61 # Tool configurations 62 TOOLS = [ 63 Tool( 64 name="taplo", 65 version="0.9.3", 66 assets={ 67 ("linux", "x86_64"): ( 68 "https://github.com/tamasfe/taplo/releases/download/0.9.3/taplo-linux-x86_64.gz", 69 "889efcfa067b179fda488427d3b13ce2d679537da8b9ed8138ba415db7da2a5e", 70 ), 71 ("darwin", "arm64"): ( 72 
"https://github.com/tamasfe/taplo/releases/download/0.9.3/taplo-darwin-aarch64.gz", 73 "39b84d62d6a47855b2c64148cde9c9ca5721bf422b8c9fe9c92776860badde5f", 74 ), 75 }, 76 ), 77 Tool( 78 name="typos", 79 version="1.39.2", 80 assets={ 81 ("linux", "x86_64"): ( 82 "https://github.com/crate-ci/typos/releases/download/v1.39.2/typos-v1.39.2-x86_64-unknown-linux-musl.tar.gz", 83 "4acfb2123a9a295d34a411ad90af23717d06914c58023ab1a12b6605f0ce3e3c", 84 ), 85 ("darwin", "arm64"): ( 86 "https://github.com/crate-ci/typos/releases/download/v1.39.2/typos-v1.39.2-aarch64-apple-darwin.tar.gz", 87 "1dac53624939bf7b638df8cd168af46532f4fbad2b512c8b092cdf1487b94612", 88 ), 89 }, 90 ), 91 Tool( 92 name="conftest", 93 version="0.63.0", 94 assets={ 95 ("linux", "x86_64"): ( 96 "https://github.com/open-policy-agent/conftest/releases/download/v0.63.0/conftest_0.63.0_Linux_x86_64.tar.gz", 97 "59b354bedf0d761fb562404a8af3015a48415636382f975a2037ca81c0c6202f", 98 ), 99 ("darwin", "arm64"): ( 100 "https://github.com/open-policy-agent/conftest/releases/download/v0.63.0/conftest_0.63.0_Darwin_arm64.tar.gz", 101 "026378585ed42609f23996663c2feea9535bc19dc3909a99dabe776b7708b85c", 102 ), 103 }, 104 ), 105 Tool( 106 name="regal", 107 version="0.36.1", 108 assets={ 109 ("linux", "x86_64"): ( 110 "https://github.com/open-policy-agent/regal/releases/download/v0.36.1/regal_Linux_x86_64", 111 "75509b89de9d2fa12ac30157cc7269e7abc61e8c4c407a29ce897b681a78f8a4", 112 ), 113 ("darwin", "arm64"): ( 114 "https://github.com/open-policy-agent/regal/releases/download/v0.36.1/regal_Darwin_arm64", 115 "66d1578885bf8fb7a4bd7b435a74acf8205af7fc49d6db84b6df0cddba9d7591", 116 ), 117 }, 118 version_args=["version"], 119 ), 120 Tool( 121 name="buf", 122 version="1.59.0", 123 assets={ 124 ("linux", "x86_64"): ( 125 "https://github.com/bufbuild/buf/releases/download/v1.59.0/buf-Linux-x86_64", 126 "d7462609e3814629c642ac10f0e7e27ec7e8e21d1dd75742f4434c31619e986b", 127 ), 128 ("darwin", "arm64"): ( 129 
"https://github.com/bufbuild/buf/releases/download/v1.59.0/buf-Darwin-arm64", 130 "71f060640b9f1a3fce43db31eb8e8faf714a3bfbbcb70617946bdeba3aadf56b", 131 ), 132 }, 133 ), 134 Tool( 135 name="rg", 136 version="14.1.1", 137 assets={ 138 ("linux", "x86_64"): ( 139 "https://github.com/BurntSushi/ripgrep/releases/download/14.1.1/ripgrep-14.1.1-x86_64-unknown-linux-musl.tar.gz", 140 "4cf9f2741e6c465ffdb7c26f38056a59e2a2544b51f7cc128ef28337eeae4d8e", 141 ), 142 ("darwin", "arm64"): ( 143 "https://github.com/BurntSushi/ripgrep/releases/download/14.1.1/ripgrep-14.1.1-aarch64-apple-darwin.tar.gz", 144 "24ad76777745fbff131c8fbc466742b011f925bfa4fffa2ded6def23b5b937be", 145 ), 146 }, 147 ), 148 ] 149 150 151 def get_platform_key() -> PlatformKey | None: 152 system = platform.system().lower() 153 machine = platform.machine().lower() 154 155 # Normalize machine architecture 156 if machine in ["x86_64", "amd64"]: 157 machine = "x86_64" 158 elif machine in ["aarch64", "arm64"]: 159 machine = "arm64" 160 161 # Return if it's a supported platform combination 162 if system == "linux" and machine == "x86_64": 163 return ("linux", "x86_64") 164 elif system == "darwin" and machine == "arm64": 165 return ("darwin", "arm64") 166 167 return None 168 169 170 def urlopen_with_retry( 171 url: str, max_retries: int = 7, base_delay: float = 1.0 172 ) -> http.client.HTTPResponse: 173 """Open a URL with retry logic for transient HTTP errors (e.g., 503).""" 174 for attempt in range(max_retries): 175 try: 176 return urllib.request.urlopen(url) 177 except HTTPError as e: 178 if e.code in (502, 503, 504) and attempt < max_retries - 1: 179 delay = base_delay * (2**attempt) 180 print(f" HTTP {e.code}, retrying in {delay}s... ({attempt + 1}/{max_retries})") 181 time.sleep(delay) 182 else: 183 raise 184 except (http.client.RemoteDisconnected, ConnectionResetError, URLError) as e: 185 if attempt < max_retries - 1: 186 delay = base_delay * (2**attempt) 187 print(f" {e}, retrying in {delay}s... 
({attempt + 1}/{max_retries})") 188 time.sleep(delay) 189 else: 190 raise 191 192 193 def download_and_verify(url: str, dest: Path, expected_sha256: str) -> None: 194 print(f"Downloading from {url}...") 195 sha256 = hashlib.sha256() 196 with urlopen_with_retry(url) as response, dest.open("wb") as f: 197 while chunk := response.read(1024 * 1024): 198 sha256.update(chunk) 199 f.write(chunk) 200 actual = sha256.hexdigest() 201 if actual != expected_sha256: 202 raise RuntimeError( 203 f"SHA256 mismatch for {url}\n expected: {expected_sha256}\n actual: {actual}" 204 ) 205 206 207 def extract_gzip(download_path: Path, dest: Path) -> None: 208 with gzip.open(download_path, "rb") as gz: 209 dest.write_bytes(gz.read()) 210 211 212 def extract_tar(download_path: Path, dest: Path, binary_name: str) -> None: 213 with tarfile.open(download_path, mode="r:*") as tar: 214 for member in tar: 215 if member.isfile() and member.name.endswith(binary_name): 216 f = tar.extractfile(member) 217 if f is not None: 218 dest.write_bytes(f.read()) 219 return 220 raise FileNotFoundError(f"Could not find {binary_name} in archive") 221 222 223 def install_tool(tool: Tool, dest_dir: Path, force: bool = False) -> None: 224 # Check if tool already exists 225 binary_path = dest_dir / tool.name 226 if binary_path.exists(): 227 if not force: 228 print(f" ✓ {tool.name} already installed") 229 return 230 else: 231 print(f" Removing existing {tool.name}...") 232 binary_path.unlink() 233 234 platform_key = get_platform_key() 235 236 if platform_key is None: 237 supported = [f"{os}-{arch}" for os, arch in tool.assets.keys()] 238 raise RuntimeError( 239 f"Current platform is not supported. Supported platforms: {', '.join(supported)}" 240 ) 241 242 asset = tool.get_asset(platform_key) 243 if asset is None: 244 os, arch = platform_key 245 supported = [f"{os}-{arch}" for os, arch in tool.assets.keys()] 246 raise RuntimeError( 247 f"Platform {os}-{arch} not supported for {tool.name}. 
" 248 f"Supported platforms: {', '.join(supported)}" 249 ) 250 url, expected_sha256 = asset 251 252 binary_path = dest_dir / tool.name 253 with tempfile.TemporaryDirectory() as tmp_dir: 254 download_path = Path(tmp_dir) / "download" 255 download_and_verify(url, download_path, expected_sha256) 256 257 extract_type = tool.get_extract_type(url) 258 if extract_type == "gzip": 259 extract_gzip(download_path, binary_path) 260 elif extract_type == "tar": 261 extract_tar(download_path, binary_path, tool.name) 262 elif extract_type == "binary": 263 shutil.move(download_path, binary_path) 264 else: 265 raise ValueError(f"Unknown extract type: {extract_type}") 266 267 # Make executable 268 binary_path.chmod(0o755) 269 270 # Verify installation by running version command 271 version_cmd = [binary_path] + tool.get_version_args() 272 subprocess.check_call(version_cmd, timeout=5) 273 print(f"Successfully installed {tool.name} to {binary_path}") 274 275 276 def load_installed_versions(dest_dir: Path) -> dict[str, str]: 277 f = dest_dir / INSTALLED_VERSIONS_FILE 278 if f.exists(): 279 return json.loads(f.read_text()) 280 return {} 281 282 283 def save_installed_versions(dest_dir: Path, versions: dict[str, str]) -> None: 284 f = dest_dir / INSTALLED_VERSIONS_FILE 285 f.write_text(json.dumps(versions, indent=2, sort_keys=True) + "\n") 286 287 288 def main() -> None: 289 all_tool_names = [t.name for t in TOOLS] 290 parser = argparse.ArgumentParser(description="Install binary tools for MLflow development") 291 parser.add_argument( 292 "-f", 293 "--force-reinstall", 294 action="store_true", 295 help="Force reinstall by removing existing tools", 296 ) 297 parser.add_argument( 298 "tools", 299 nargs="*", 300 metavar="TOOL", 301 help=f"Tools to install (default: all). 
Available: {', '.join(all_tool_names)}", 302 ) 303 args = parser.parse_args() 304 305 # Filter tools if specific ones requested 306 if args.tools: 307 if invalid_tools := set(args.tools) - set(all_tool_names): 308 parser.error( 309 f"Unknown tools: {', '.join(sorted(invalid_tools))}. " 310 f"Available: {', '.join(all_tool_names)}" 311 ) 312 tools_to_install = [t for t in TOOLS if t.name in args.tools] 313 else: 314 tools_to_install = TOOLS 315 316 dest_dir = Path(__file__).resolve().parent 317 dest_dir.mkdir(parents=True, exist_ok=True) 318 319 installed_versions = load_installed_versions(dest_dir) 320 outdated_tools = sorted( 321 t.name for t in tools_to_install if installed_versions.get(t.name) != t.version 322 ) 323 force_all = args.force_reinstall 324 325 if force_all: 326 print("Force reinstall: removing existing tools and reinstalling...") 327 elif outdated_tools: 328 print(f"Version changes detected for: {', '.join(outdated_tools)}") 329 else: 330 print("Installing tools to bin/ directory...") 331 332 for tool in tools_to_install: 333 # Force reinstall if globally forced or if this tool's version changed 334 force = force_all or tool.name in outdated_tools 335 print(f"\nInstalling {tool.name}...") 336 install_tool(tool, dest_dir, force=force) 337 installed_versions[tool.name] = tool.version 338 339 save_installed_versions(dest_dir, installed_versions) 340 print("\nDone!") 341 342 343 if __name__ == "__main__": 344 main()