"""check_patch_prs.py

Verify that every PR labeled ``v<version>`` has been cherry-picked onto the
corresponding ``branch-<major>.<minor>`` release branch, and generate
cherry-pick shell script(s) for any that are missing.
"""

import argparse
import concurrent.futures
import functools
import itertools
import os
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any

import requests
from packaging.version import Version

# GitHub rejects rebase merges with more than 100 commits; stay safely under.
MAX_COMMITS_PER_SCRIPT = 90

# Seconds before an HTTP request is abandoned — prevents the script from
# hanging forever on a stalled connection.
REQUEST_TIMEOUT = 30


def chunk_list(lst: list[str], size: int) -> list[list[str]]:
    """Split ``lst`` into consecutive chunks of at most ``size`` items."""
    return [lst[i : i + size] for i in range(0, len(lst), size)]


@functools.cache  # resolve the token once; avoids spawning `gh` per request
def get_token() -> str | None:
    """
    Return a GitHub token from the GH_TOKEN env var or the `gh` CLI.

    Returns None when neither source yields a token (requests then go
    unauthenticated, subject to stricter rate limits).
    """
    if token := os.environ.get("GH_TOKEN"):
        return token
    try:
        token = subprocess.check_output(
            ["gh", "auth", "token"], text=True, stderr=subprocess.DEVNULL
        ).strip()
        if token:
            return token
    except (subprocess.CalledProcessError, FileNotFoundError):
        # `gh` not installed or not authenticated — fall through to None.
        pass
    return None


def get_headers() -> dict[str, str]:
    """Return Authorization headers for the GitHub API, or {} if no token."""
    if token := get_token():
        return {"Authorization": f"token {token}"}
    return {}


def validate_version(version: str) -> None:
    """
    Validate that the version has a micro version component.
    Raises ValueError if the version is invalid.
    """
    parsed_version = Version(version)
    if len(parsed_version.release) != 3:
        raise ValueError(
            f"Invalid version: '{version}'. "
            "Version must be in the format <major>.<minor>.<micro> (e.g., '2.10.0')"
        )


def get_release_branch(version: str) -> str:
    """Map a version like '2.10.1' to its release branch name 'branch-2.10'."""
    major_minor_version = ".".join(version.split(".")[:2])
    return f"branch-{major_minor_version}"


@dataclass(frozen=True)
class Commit:
    # Full commit SHA on the branch.
    sha: str
    # PR number parsed from the squash-merge commit subject ("... (#1234)").
    pr_num: int
    # Committer date (ISO 8601) — used for chronological ordering.
    date: str


def get_commit_count(branch: str, since: str) -> int:
    """
    Get the total count of commits in the branch since the given date using
    the GraphQL API.

    Raises ValueError if the branch does not exist.
    """
    query = """
    query($branch: String!, $since: GitTimestamp!) {
      repository(owner: "mlflow", name: "mlflow") {
        ref(qualifiedName: $branch) {
          target {
            ... on Commit {
              history(since: $since) {
                totalCount
              }
            }
          }
        }
      }
    }
    """
    response = requests.post(
        "https://api.github.com/graphql",
        json={"query": query, "variables": {"branch": branch, "since": since}},
        headers=get_headers(),
        timeout=REQUEST_TIMEOUT,
    )
    response.raise_for_status()
    data = response.json()
    ref = data["data"]["repository"]["ref"]
    if ref is None:
        raise ValueError(f"Branch '{branch}' not found")
    total_count: int = ref["target"]["history"]["totalCount"]
    return total_count


def get_commits(branch: str) -> list[Commit]:
    """
    Get the commits in the release branch via GitHub API (last 90 days).
    Returns commits sorted by date (oldest first). Commits whose subject does
    not end with a PR reference "(#N)" are skipped.
    """
    per_page = 100
    pr_rgx = re.compile(r".+\s+\(#(\d+)\)$")
    since = (datetime.now(timezone.utc) - timedelta(days=90)).isoformat()

    # Get total commit count first so pages can be fetched in parallel.
    total_count = get_commit_count(branch, since)
    if total_count == 0:
        print(f"No commits found in {branch} since {since}")
        return []

    total_pages = (total_count + per_page - 1) // per_page
    print(f"Total commits: {total_count}, fetching {total_pages} page(s)...")

    def fetch_page(page: int) -> list[Commit]:
        print(f"Fetching page {page}/{total_pages}...")
        params: dict[str, str | int] = {
            "sha": branch,
            "per_page": per_page,
            "page": page,
            "since": since,
        }
        response = requests.get(
            "https://api.github.com/repos/mlflow/mlflow/commits",
            params=params,
            headers=get_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        commits = []
        for item in response.json():
            msg = item["commit"]["message"].split("\n")[0]
            if m := pr_rgx.search(msg):
                # Use committer date (not author date) because cherry-picked
                # commits retain the original author date but get a new
                # committer date.
                date = item["commit"]["committer"]["date"]
                commits.append(Commit(sha=item["sha"], pr_num=int(m.group(1)), date=date))
        return commits

    # Fetch all pages in parallel. executor.map preserves order.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(fetch_page, range(1, total_pages + 1))

        return sorted(itertools.chain.from_iterable(results), key=lambda c: c.date)


# NOTE(review): this dataclass is not referenced anywhere in this file;
# kept for backward compatibility in case external code imports it.
@dataclass(frozen=True)
class PR:
    pr_num: int
    merged: bool


def is_closed(pr: dict[str, Any]) -> bool:
    """True for a PR that was closed without being merged."""
    return pr["state"] == "closed" and pr["pull_request"]["merged_at"] is None


def fetch_patch_prs(version: str) -> dict[int, bool]:
    """
    Fetch PRs labeled with `v{version}` from the MLflow repository.

    Returns a mapping of PR number -> merged flag. Closed-but-unmerged PRs
    are excluded.
    """
    label = f"v{version}"
    per_page = 100
    page = 1
    pulls: list[dict[str, Any]] = []
    while True:
        response = requests.get(
            f'https://api.github.com/search/issues?q=is:pr+repo:mlflow/mlflow+label:"{label}"&per_page={per_page}&page={page}',
            headers=get_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        data = response.json()
        # Exclude closed PRs that are not merged
        pulls.extend(pr for pr in data["items"] if not is_closed(pr))
        # A short page means we've reached the last page of results.
        if len(data["items"]) < per_page:
            break
        page += 1

    return {pr["number"]: pr["pull_request"].get("merged_at") is not None for pr in pulls}


def main(version: str, dry_run: bool) -> None:
    """
    Report patch PRs missing from the release branch and emit cherry-pick
    script(s) for them.

    Exits 0 when nothing is missing (or in dry-run mode); exits 1 when PRs
    are missing and dry-run is disabled, so CI fails.
    """
    validate_version(version)
    release_branch = get_release_branch(version)
    commits = get_commits(release_branch)
    patch_prs = fetch_patch_prs(version)
    # NOTE(review): everything below runs only when PRs are missing; original
    # indentation was lost in transit — confirm this matches upstream intent.
    if not_cherry_picked := set(patch_prs) - {c.pr_num for c in commits}:
        print(f"The following patch PRs are not cherry-picked to {release_branch}:")
        for idx, pr_num in enumerate(sorted(not_cherry_picked)):
            merged = patch_prs[pr_num]
            url = f"https://github.com/mlflow/mlflow/pull/{pr_num} (merged: {merged})"
            line = f"  {idx + 1}. {url}"
            if not merged:
                line = f"\033[91m{line}\033[0m"  # Red color using ANSI escape codes
            print(line)

        # Find the SHAs to cherry-pick from master, in chronological order.
        master_commits = get_commits("master")
        cherry_picks = [c.sha for c in master_commits if c.pr_num in not_cherry_picked]

        # Split into chunks if needed
        chunks = chunk_list(cherry_picks, MAX_COMMITS_PER_SCRIPT)

        # Print warning if splitting
        if len(chunks) > 1:
            print(
                f"\n⚠️ WARNING: {len(cherry_picks)} commits will be split into "
                f"{len(chunks)} scripts."
            )
            print("Create one PR per script and merge them sequentially:")
            print("  file PR 1 → merge PR 1 → pull release branch → file PR 2 → merge PR 2 → ...")
            print("This is required to stay under GitHub's 100-commit rebase merge limit.\n")

        print("\n# Steps to cherry-pick the patch PRs:")
        print(
            f"1. Make sure your local master and {release_branch} branches are synced with "
            "upstream."
        )
        print(f"2. Cut a new branch from {release_branch} (e.g. {release_branch}-cherry-picks).")

        # Generate script(s)
        tmp_dir = Path(tempfile.gettempdir())
        script_paths: list[tuple[Path, int]] = []
        for i, chunk in enumerate(chunks, 1):
            if len(chunks) == 1:
                script_path = tmp_dir / "cherry-pick.sh"
            else:
                script_path = tmp_dir / f"cherry-pick-{i}.sh"

            script_content = f"""\
#!/usr/bin/env bash
# Cherry-picks for v{version} -> {release_branch}
# Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}
# Commits: {len(chunk)} of {len(cherry_picks)} total
#
# If conflicts occur, resolve them and run:
#   git cherry-pick --continue

set -euo pipefail

# Guard: Prevent running on master branch
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [[ "$current_branch" == "master" ]]; then
  echo "ERROR: This script must not be run on the master branch."
  echo "Please checkout a release branch (e.g., {release_branch}) or a branch derived from it."
  exit 1
fi

git cherry-pick {" ".join(chunk)}
"""
            script_path.write_text(script_content)
            script_path.chmod(0o755)  # make the generated script executable
            script_paths.append((script_path, len(chunk)))

        if len(chunks) == 1:
            print("3. Run the cherry-pick script on the new branch:\n")
            print(f"Cherry-pick script written to: {script_paths[0][0]}")
            print(f"\n4. File a PR against {release_branch}.")
        else:
            print("3. For each script (in order):")
            print(f"   a. Create a new branch from {release_branch}")
            print("   b. Run the script")
            print(f"   c. File a PR against {release_branch}")
            print(f"   d. After merge, pull {release_branch} from remote before the next script\n")
            print("   Scripts:")
            for script_path, commit_count in script_paths:
                print(f"     {script_path} ({commit_count} commits)")

        # Dry runs always succeed; otherwise fail so CI flags the missing PRs.
        sys.exit(0 if dry_run else 1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", required=True, help="The version to release")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=os.environ.get("DRY_RUN", "true").lower() == "true",
        help="Dry run mode (default: True, can be set via DRY_RUN env var)",
    )
    parser.add_argument(
        "--no-dry-run",
        action="store_false",
        dest="dry_run",
        help="Disable dry run mode",
    )
    args = parser.parse_args()
    main(args.version, args.dry_run)