# bin/install.py
  1  """
  2  Install binary tools for MLflow development.
  3  """
  4  
  5  # ruff: noqa: T201
  6  import argparse
  7  import gzip
  8  import hashlib
  9  import http.client
 10  import json
 11  import platform
 12  import shutil
 13  import subprocess
 14  import tarfile
 15  import tempfile
 16  import time
 17  import urllib.request
 18  from dataclasses import dataclass
 19  from pathlib import Path
 20  from typing import Literal
 21  from urllib.error import HTTPError, URLError
 22  
# Name of the JSON file, kept in the same directory as the installed binaries,
# that maps each tool name to the version that was last installed.
INSTALLED_VERSIONS_FILE = ".installed_versions.json"

# Type definitions
# A supported (operating system, CPU architecture) pair, as produced by
# get_platform_key() and used as the key of Tool.assets.
PlatformKey = tuple[
    Literal["linux", "darwin"],
    Literal["x86_64", "arm64"],
]
# How a downloaded asset becomes the final executable: decompress a single
# gzip file, pull the binary out of a tar archive, or use the file as-is.
ExtractType = Literal["gzip", "tar", "binary"]
 31  
 32  
 33  @dataclass
 34  class Tool:
 35      name: str
 36      version: str
 37      assets: dict[PlatformKey, tuple[str, str]]  # platform -> (url, sha256)
 38      version_args: list[str] | None = None  # Custom version check args (default: ["--version"])
 39  
 40      def get_asset(self, platform_key: PlatformKey) -> tuple[str, str] | None:
 41          return self.assets.get(platform_key)
 42  
 43      def get_version_args(self) -> list[str]:
 44          """Get version check arguments, defaulting to --version."""
 45          return self.version_args or ["--version"]
 46  
 47      def get_extract_type(self, url: str) -> ExtractType:
 48          """Infer extract type from URL file extension."""
 49          if url.endswith(".gz") and not url.endswith(".tar.gz"):
 50              return "gzip"
 51          elif url.endswith((".tar.gz", ".tgz")):
 52              return "tar"
 53          elif url.endswith(".exe") or ("/" in url and not url.split("/")[-1].count(".")):
 54              # Windows executables or files without extensions (plain binaries)
 55              return "binary"
 56          else:
 57              # Default to tar for unknown extensions
 58              return "tar"
 59  
 60  
# Tool configurations
# Each entry pins one tool to an exact version. For every supported platform
# it lists the download URL together with the SHA-256 hex digest the download
# must match. When bumping a version, update the URLs and digests with it.
TOOLS = [
    Tool(
        name="taplo",
        version="0.9.3",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/tamasfe/taplo/releases/download/0.9.3/taplo-linux-x86_64.gz",
                "889efcfa067b179fda488427d3b13ce2d679537da8b9ed8138ba415db7da2a5e",
            ),
            ("darwin", "arm64"): (
                "https://github.com/tamasfe/taplo/releases/download/0.9.3/taplo-darwin-aarch64.gz",
                "39b84d62d6a47855b2c64148cde9c9ca5721bf422b8c9fe9c92776860badde5f",
            ),
        },
    ),
    Tool(
        name="typos",
        version="1.39.2",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/crate-ci/typos/releases/download/v1.39.2/typos-v1.39.2-x86_64-unknown-linux-musl.tar.gz",
                "4acfb2123a9a295d34a411ad90af23717d06914c58023ab1a12b6605f0ce3e3c",
            ),
            ("darwin", "arm64"): (
                "https://github.com/crate-ci/typos/releases/download/v1.39.2/typos-v1.39.2-aarch64-apple-darwin.tar.gz",
                "1dac53624939bf7b638df8cd168af46532f4fbad2b512c8b092cdf1487b94612",
            ),
        },
    ),
    Tool(
        name="conftest",
        version="0.63.0",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/open-policy-agent/conftest/releases/download/v0.63.0/conftest_0.63.0_Linux_x86_64.tar.gz",
                "59b354bedf0d761fb562404a8af3015a48415636382f975a2037ca81c0c6202f",
            ),
            ("darwin", "arm64"): (
                "https://github.com/open-policy-agent/conftest/releases/download/v0.63.0/conftest_0.63.0_Darwin_arm64.tar.gz",
                "026378585ed42609f23996663c2feea9535bc19dc3909a99dabe776b7708b85c",
            ),
        },
    ),
    Tool(
        name="regal",
        version="0.36.1",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/open-policy-agent/regal/releases/download/v0.36.1/regal_Linux_x86_64",
                "75509b89de9d2fa12ac30157cc7269e7abc61e8c4c407a29ce897b681a78f8a4",
            ),
            ("darwin", "arm64"): (
                "https://github.com/open-policy-agent/regal/releases/download/v0.36.1/regal_Darwin_arm64",
                "66d1578885bf8fb7a4bd7b435a74acf8205af7fc49d6db84b6df0cddba9d7591",
            ),
        },
        # Overrides the default "--version" argument for the post-install check.
        version_args=["version"],
    ),
    Tool(
        name="buf",
        version="1.59.0",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/bufbuild/buf/releases/download/v1.59.0/buf-Linux-x86_64",
                "d7462609e3814629c642ac10f0e7e27ec7e8e21d1dd75742f4434c31619e986b",
            ),
            ("darwin", "arm64"): (
                "https://github.com/bufbuild/buf/releases/download/v1.59.0/buf-Darwin-arm64",
                "71f060640b9f1a3fce43db31eb8e8faf714a3bfbbcb70617946bdeba3aadf56b",
            ),
        },
    ),
    Tool(
        # Installed under the name "rg"; the assets come from the ripgrep releases.
        name="rg",
        version="14.1.1",
        assets={
            ("linux", "x86_64"): (
                "https://github.com/BurntSushi/ripgrep/releases/download/14.1.1/ripgrep-14.1.1-x86_64-unknown-linux-musl.tar.gz",
                "4cf9f2741e6c465ffdb7c26f38056a59e2a2544b51f7cc128ef28337eeae4d8e",
            ),
            ("darwin", "arm64"): (
                "https://github.com/BurntSushi/ripgrep/releases/download/14.1.1/ripgrep-14.1.1-aarch64-apple-darwin.tar.gz",
                "24ad76777745fbff131c8fbc466742b011f925bfa4fffa2ded6def23b5b937be",
            ),
        },
    ),
]
149  
150  
def get_platform_key() -> PlatformKey | None:
    """Return the (system, machine) key for the current host, or None if unsupported.

    The machine name is normalized first ("amd64" -> "x86_64",
    "aarch64" -> "arm64"). Only linux/x86_64 and darwin/arm64 are supported.
    """
    machine_aliases = {
        "x86_64": "x86_64",
        "amd64": "x86_64",
        "aarch64": "arm64",
        "arm64": "arm64",
    }
    supported_keys = {("linux", "x86_64"), ("darwin", "arm64")}

    os_name = platform.system().lower()
    raw_machine = platform.machine().lower()
    # Unknown machine names pass through unchanged and fail the lookup below.
    candidate = (os_name, machine_aliases.get(raw_machine, raw_machine))

    if candidate in supported_keys:
        return candidate
    return None
168  
169  
def urlopen_with_retry(
    url: str, max_retries: int = 7, base_delay: float = 1.0
) -> http.client.HTTPResponse:
    """Open ``url``, retrying transient failures with exponential backoff.

    HTTP 502/503/504 responses, dropped or reset connections, and other
    ``URLError`` failures are retried. Any other HTTP status is raised
    immediately. The wait before retry number ``k`` (starting at 0) is
    ``base_delay * 2**k`` seconds.

    Args:
        url: The URL to open.
        max_retries: Total number of attempts; must be at least 1.
        base_delay: Delay in seconds before the first retry.

    Returns:
        The open response object returned by ``urllib.request.urlopen``.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        HTTPError: For a non-retryable status, or when retries are exhausted.
        URLError, http.client.RemoteDisconnected, ConnectionResetError:
            When the last attempt still fails.
    """
    if max_retries < 1:
        # Without this guard the loop body never runs and the function would
        # silently return None instead of a response.
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        try:
            return urllib.request.urlopen(url)
        except (http.client.RemoteDisconnected, ConnectionResetError, URLError) as e:
            # HTTPError is a subclass of URLError, so HTTP failures land here too.
            if isinstance(e, HTTPError):
                if e.code not in (502, 503, 504):
                    raise
                reason = f"HTTP {e.code}"
            else:
                reason = str(e)
            if attempt >= max_retries - 1:
                raise
            delay = base_delay * (2**attempt)
            print(f"  {reason}, retrying in {delay}s... ({attempt + 1}/{max_retries})")
            time.sleep(delay)

    # The final attempt either returns or re-raises, so this is never reached.
    raise AssertionError("unreachable")
191  
192  
def download_and_verify(url: str, dest: Path, expected_sha256: str) -> None:
    """Stream ``url`` into ``dest`` and check its SHA-256 digest.

    The digest is computed while the data is written, one chunk at a time.

    Raises:
        RuntimeError: If the digest of the downloaded bytes differs from
            ``expected_sha256``.
    """
    chunk_size = 1024 * 1024
    print(f"Downloading from {url}...")
    hasher = hashlib.sha256()
    with urlopen_with_retry(url) as response, dest.open("wb") as out:
        # iter() with a b"" sentinel stops at the first empty read (end of stream).
        for chunk in iter(lambda: response.read(chunk_size), b""):
            hasher.update(chunk)
            out.write(chunk)

    actual = hasher.hexdigest()
    if actual == expected_sha256:
        return
    raise RuntimeError(
        f"SHA256 mismatch for {url}\n  expected: {expected_sha256}\n  actual:   {actual}"
    )
205  
206  
def extract_gzip(download_path: Path, dest: Path) -> None:
    """Decompress the gzip file at ``download_path`` and write the result to ``dest``."""
    with gzip.open(download_path, "rb") as compressed:
        payload = compressed.read()
    # Written only after decompression succeeded, so a bad archive never
    # creates ``dest``.
    dest.write_bytes(payload)
210  
211  
def extract_tar(download_path: Path, dest: Path, binary_name: str) -> None:
    """Extract the file named ``binary_name`` from a tar archive into ``dest``.

    A regular file whose base name equals ``binary_name`` exactly is
    preferred. This stops unrelated members that merely end with the same
    characters (for example ``complete/_rg`` when looking for ``rg``) from
    being picked. If no exact match exists, the first regular file whose
    path ends with ``binary_name`` is used, as before.

    Args:
        download_path: Path to the tar archive (any compression ``tarfile``
            can auto-detect).
        dest: Where to write the extracted file's bytes.
        binary_name: File name of the binary to look for.

    Raises:
        FileNotFoundError: If no suitable member is found in the archive.
    """
    with tarfile.open(download_path, mode="r:*") as tar:
        exact = None
        suffix_match = None
        for member in tar:
            if not member.isfile():
                continue
            # Member names in a tar archive always use "/" as the separator.
            if member.name.rsplit("/", 1)[-1] == binary_name:
                exact = member
                break
            if suffix_match is None and member.name.endswith(binary_name):
                suffix_match = member

        chosen = exact if exact is not None else suffix_match
        if chosen is not None:
            f = tar.extractfile(chosen)
            if f is not None:
                dest.write_bytes(f.read())
                return
    raise FileNotFoundError(f"Could not find {binary_name} in archive")
219                      return
220      raise FileNotFoundError(f"Could not find {binary_name} in archive")
221  
222  
def install_tool(tool: Tool, dest_dir: Path, force: bool = False) -> None:
    """Download, verify, unpack and smoke-test ``tool`` inside ``dest_dir``.

    The binary is written to ``dest_dir / tool.name``. If it is already there
    and ``force`` is false, nothing is done.

    Args:
        tool: The tool definition to install.
        dest_dir: Directory that receives the executable.
        force: Replace an existing binary instead of keeping it.

    Raises:
        RuntimeError: If the current platform is unsupported, or the download
            does not match its expected SHA-256 digest.
        FileNotFoundError: If a tar archive does not contain the binary.
        subprocess.SubprocessError, OSError: If the installed binary cannot
            run its version command successfully within 5 seconds.
    """
    binary_path = dest_dir / tool.name
    already_present = binary_path.exists()
    if already_present and not force:
        print(f"  ✓ {tool.name} already installed")
        return

    # Resolve the download before touching the existing binary, so an
    # unsupported platform does not cost the user a working tool.
    supported = ", ".join(f"{system}-{arch}" for system, arch in tool.assets)
    platform_key = get_platform_key()
    if platform_key is None:
        raise RuntimeError(
            f"Current platform is not supported. Supported platforms: {supported}"
        )

    asset = tool.get_asset(platform_key)
    if asset is None:
        system, arch = platform_key
        raise RuntimeError(
            f"Platform {system}-{arch} not supported for {tool.name}. "
            f"Supported platforms: {supported}"
        )
    url, expected_sha256 = asset

    if already_present:
        print(f"  Removing existing {tool.name}...")
        binary_path.unlink()

    with tempfile.TemporaryDirectory() as tmp_dir:
        download_path = Path(tmp_dir) / "download"
        download_and_verify(url, download_path, expected_sha256)

        extract_type = tool.get_extract_type(url)
        if extract_type == "gzip":
            extract_gzip(download_path, binary_path)
        elif extract_type == "tar":
            extract_tar(download_path, binary_path, tool.name)
        elif extract_type == "binary":
            shutil.move(download_path, binary_path)
        else:
            raise ValueError(f"Unknown extract type: {extract_type}")

    # Make executable
    binary_path.chmod(0o755)

    # Verify installation by running version command
    version_cmd = [binary_path] + tool.get_version_args()
    try:
        subprocess.check_call(version_cmd, timeout=5)
    except (subprocess.SubprocessError, OSError):
        # Remove a binary that does not run, so a later invocation does not
        # report it as "already installed".
        binary_path.unlink(missing_ok=True)
        raise
    print(f"Successfully installed {tool.name} to {binary_path}")
274  
275  
def load_installed_versions(dest_dir: Path) -> dict[str, str]:
    """Read the recorded tool versions from ``dest_dir``.

    Returns an empty mapping when the versions file is missing, cannot be
    decoded as JSON, or does not contain a JSON object. An unreadable file is
    treated the same as "nothing recorded" rather than aborting the run.

    Args:
        dest_dir: Directory that holds the versions file.

    Returns:
        The parsed JSON object from the versions file, or ``{}``.
    """
    f = dest_dir / INSTALLED_VERSIONS_FILE
    if not f.exists():
        return {}
    try:
        versions = json.loads(f.read_text())
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        print(f"  Ignoring unreadable {f.name}: {e}")
        return {}
    if not isinstance(versions, dict):
        print(f"  Ignoring {f.name}: expected a JSON object")
        return {}
    return versions
281  
282  
def save_installed_versions(dest_dir: Path, versions: dict[str, str]) -> None:
    """Write ``versions`` to the versions file in ``dest_dir``.

    The mapping is stored as JSON with sorted keys, two-space indentation and
    a trailing newline.
    """
    serialized = json.dumps(versions, indent=2, sort_keys=True)
    target = dest_dir / INSTALLED_VERSIONS_FILE
    target.write_text(f"{serialized}\n")
286  
287  
def main() -> None:
    """Command-line entry point: install the requested tools next to this script.

    With no positional arguments every tool in ``TOOLS`` is installed. A tool
    is reinstalled when ``--force-reinstall`` is given or when its recorded
    version differs from the pinned one. Recorded versions are written back
    even when an install fails part-way, so tools that were already installed
    successfully are not reinstalled on the next run.
    """
    all_tool_names = [t.name for t in TOOLS]
    parser = argparse.ArgumentParser(description="Install binary tools for MLflow development")
    parser.add_argument(
        "-f",
        "--force-reinstall",
        action="store_true",
        help="Force reinstall by removing existing tools",
    )
    parser.add_argument(
        "tools",
        nargs="*",
        metavar="TOOL",
        help=f"Tools to install (default: all). Available: {', '.join(all_tool_names)}",
    )
    args = parser.parse_args()

    # Filter tools if specific ones requested
    if args.tools:
        requested = set(args.tools)
        if invalid_tools := requested - set(all_tool_names):
            parser.error(
                f"Unknown tools: {', '.join(sorted(invalid_tools))}. "
                f"Available: {', '.join(all_tool_names)}"
            )
        tools_to_install = [t for t in TOOLS if t.name in requested]
    else:
        tools_to_install = TOOLS

    dest_dir = Path(__file__).resolve().parent
    dest_dir.mkdir(parents=True, exist_ok=True)

    installed_versions = load_installed_versions(dest_dir)
    outdated_tools = sorted(
        t.name for t in tools_to_install if installed_versions.get(t.name) != t.version
    )
    force_all = args.force_reinstall

    if force_all:
        print("Force reinstall: removing existing tools and reinstalling...")
    elif outdated_tools:
        print(f"Version changes detected for: {', '.join(outdated_tools)}")
    else:
        print("Installing tools to bin/ directory...")

    try:
        for tool in tools_to_install:
            # Force reinstall if globally forced or if this tool's version changed
            force = force_all or tool.name in outdated_tools
            print(f"\nInstalling {tool.name}...")
            # Drop the record while the install is in flight, so a failed or
            # interrupted install is retried on the next run.
            installed_versions.pop(tool.name, None)
            install_tool(tool, dest_dir, force=force)
            installed_versions[tool.name] = tool.version
    finally:
        # Persist progress even if a later tool failed.
        save_installed_versions(dest_dir, installed_versions)
    print("\nDone!")
340      print("\nDone!")
341  
342  
# Run the installer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()