# dev/check_patch_prs.py
  1  import argparse
  2  import concurrent.futures
  3  import itertools
  4  import os
  5  import re
  6  import subprocess
  7  import sys
  8  import tempfile
  9  from dataclasses import dataclass
 10  from datetime import datetime, timedelta, timezone
 11  from pathlib import Path
 12  from typing import Any
 13  
 14  import requests
 15  from packaging.version import Version
 16  
 17  MAX_COMMITS_PER_SCRIPT = 90
 18  
 19  
 20  def chunk_list(lst: list[str], size: int) -> list[list[str]]:
 21      return [lst[i : i + size] for i in range(0, len(lst), size)]
 22  
 23  
 24  def get_token() -> str | None:
 25      if token := os.environ.get("GH_TOKEN"):
 26          return token
 27      try:
 28          token = subprocess.check_output(
 29              ["gh", "auth", "token"], text=True, stderr=subprocess.DEVNULL
 30          ).strip()
 31          if token:
 32              return token
 33      except (subprocess.CalledProcessError, FileNotFoundError):
 34          pass
 35      return None
 36  
 37  
 38  def get_headers() -> dict[str, str]:
 39      if token := get_token():
 40          return {"Authorization": f"token {token}"}
 41      return {}
 42  
 43  
 44  def validate_version(version: str) -> None:
 45      """
 46      Validate that the version has a micro version component.
 47      Raises ValueError if the version is invalid.
 48      """
 49      parsed_version = Version(version)
 50      if len(parsed_version.release) != 3:
 51          raise ValueError(
 52              f"Invalid version: '{version}'. "
 53              "Version must be in the format <major>.<minor>.<micro> (e.g., '2.10.0')"
 54          )
 55  
 56  
 57  def get_release_branch(version: str) -> str:
 58      major_minor_version = ".".join(version.split(".")[:2])
 59      return f"branch-{major_minor_version}"
 60  
 61  
@dataclass(frozen=True)
class Commit:
    """A commit on a branch, tied to the PR that produced it."""

    # Full commit SHA.
    sha: str
    # PR number parsed from the "(#1234)" suffix of the commit subject line.
    pr_num: int
    # Committer date string from the GitHub API; used to sort commits.
    date: str
 67  
 68  
 69  def get_commit_count(branch: str, since: str) -> int:
 70      """
 71      Get the total count of commits in the branch since the given date using GraphQL API.
 72      """
 73      query = """
 74      query($branch: String!, $since: GitTimestamp!) {
 75        repository(owner: "mlflow", name: "mlflow") {
 76          ref(qualifiedName: $branch) {
 77            target {
 78              ... on Commit {
 79                history(since: $since) {
 80                  totalCount
 81                }
 82              }
 83            }
 84          }
 85        }
 86      }
 87      """
 88      response = requests.post(
 89          "https://api.github.com/graphql",
 90          json={"query": query, "variables": {"branch": branch, "since": since}},
 91          headers=get_headers(),
 92      )
 93      response.raise_for_status()
 94      data = response.json()
 95      ref = data["data"]["repository"]["ref"]
 96      if ref is None:
 97          raise ValueError(f"Branch '{branch}' not found")
 98      total_count: int = ref["target"]["history"]["totalCount"]
 99      return total_count
100  
101  
def get_commits(branch: str) -> list[Commit]:
    """
    Get the commits in the release branch via GitHub API (last 90 days).
    Returns commits sorted by date (oldest first).

    Only commits whose subject line ends with a "(#<number>)" suffix are
    returned; commits without a PR-number suffix are skipped.
    """
    per_page = 100
    # Matches subject lines ending with a "(#1234)" PR-number suffix.
    pr_rgx = re.compile(r".+\s+\(#(\d+)\)$")
    since = (datetime.now(timezone.utc) - timedelta(days=90)).isoformat()

    # Get total commit count first
    total_count = get_commit_count(branch, since)
    if total_count == 0:
        print(f"No commits found in {branch} since {since}")
        return []

    # Ceiling division: number of pages needed at `per_page` commits per page.
    total_pages = (total_count + per_page - 1) // per_page
    print(f"Total commits: {total_count}, fetching {total_pages} page(s)...")

    def fetch_page(page: int) -> list[Commit]:
        # Fetch one page of the REST commits listing and keep only commits
        # whose subject carries a PR number.
        print(f"Fetching page {page}/{total_pages}...")
        params: dict[str, str | int] = {
            "sha": branch,
            "per_page": per_page,
            "page": page,
            "since": since,
        }
        response = requests.get(
            "https://api.github.com/repos/mlflow/mlflow/commits",
            params=params,
            headers=get_headers(),
        )
        response.raise_for_status()
        commits = []
        for item in response.json():
            # Only the first line of the commit message (the subject) is matched.
            msg = item["commit"]["message"].split("\n")[0]
            if m := pr_rgx.search(msg):
                # Use committer date (not author date) because cherry-picked commits
                # retain the original author date but get a new committer date.
                date = item["commit"]["committer"]["date"]
                commits.append(Commit(sha=item["sha"], pr_num=int(m.group(1)), date=date))
        return commits

    # Fetch all pages in parallel. executor.map preserves order.
    # Note: `results` is consumed after the `with` block; the context exit
    # waits for all submitted work to finish, so the futures are complete.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(fetch_page, range(1, total_pages + 1))

    return sorted(itertools.chain.from_iterable(results), key=lambda c: c.date)
149  
150  
@dataclass(frozen=True)
class PR:
    """A patch PR and whether it has been merged."""

    # NOTE(review): this dataclass is not referenced anywhere in this file
    # (`fetch_patch_prs` returns a plain dict) — confirm whether it is used
    # elsewhere or is dead code.

    # PR number on GitHub.
    pr_num: int
    # True if the PR has been merged.
    merged: bool
155  
156  
def is_closed(pr: dict[str, Any]) -> bool:
    """Return True for a PR that was closed without ever being merged."""
    if pr["state"] != "closed":
        return False
    return pr["pull_request"]["merged_at"] is None
159  
160  
def fetch_patch_prs(version: str) -> dict[int, bool]:
    """Fetch PRs labeled `v{version}` from mlflow/mlflow via the search API.

    Returns a mapping of PR number to whether the PR has been merged,
    excluding PRs that were closed without being merged.
    """
    label = f"v{version}"
    per_page = 100
    merged_by_number: dict[int, bool] = {}
    for page in itertools.count(1):
        response = requests.get(
            f'https://api.github.com/search/issues?q=is:pr+repo:mlflow/mlflow+label:"{label}"&per_page={per_page}&page={page}',
            headers=get_headers(),
        )
        response.raise_for_status()
        items = response.json()["items"]
        for pr in items:
            # Exclude closed PRs that are not merged
            if not is_closed(pr):
                merged_by_number[pr["number"]] = pr["pull_request"].get("merged_at") is not None
        # A short page means we have reached the last page of results.
        if len(items) < per_page:
            break
    return merged_by_number
183  
184  
def main(version: str, dry_run: bool) -> None:
    """
    Check that every PR labeled `v{version}` has been cherry-picked to the
    release branch, and generate cherry-pick script(s) for any that have not.

    Exits with status 1 (or 0 in dry-run mode) when PRs are missing from the
    release branch; returns normally (exit 0) when nothing is missing.
    """
    validate_version(version)
    release_branch = get_release_branch(version)
    commits = get_commits(release_branch)
    patch_prs = fetch_patch_prs(version)
    # PRs labeled for this patch release with no matching commit on the branch.
    if not_cherry_picked := set(patch_prs) - {c.pr_num for c in commits}:
        print(f"The following patch PRs are not cherry-picked to {release_branch}:")
        for idx, pr_num in enumerate(sorted(not_cherry_picked)):
            merged = patch_prs[pr_num]
            url = f"https://github.com/mlflow/mlflow/pull/{pr_num} (merged: {merged})"
            line = f"  {idx + 1}. {url}"
            if not merged:
                line = f"\033[91m{line}\033[0m"  # Red color using ANSI escape codes
            print(line)

        # Find the SHAs of the missing PRs on master, in committer-date order,
        # so they can be cherry-picked in the order they landed.
        master_commits = get_commits("master")
        cherry_picks = [c.sha for c in master_commits if c.pr_num in not_cherry_picked]

        # Split into chunks if needed
        chunks = chunk_list(cherry_picks, MAX_COMMITS_PER_SCRIPT)

        # Print warning if splitting
        if len(chunks) > 1:
            print(
                f"\n⚠️  WARNING: {len(cherry_picks)} commits will be split into "
                f"{len(chunks)} scripts."
            )
            print("Create one PR per script and merge them sequentially:")
            print("  file PR 1 → merge PR 1 → pull release branch → file PR 2 → merge PR 2 → ...")
            print("This is required to stay under GitHub's 100-commit rebase merge limit.\n")

        print("\n# Steps to cherry-pick the patch PRs:")
        print(
            f"1. Make sure your local master and {release_branch} branches are synced with "
            "upstream."
        )
        print(f"2. Cut a new branch from {release_branch} (e.g. {release_branch}-cherry-picks).")

        # Generate script(s)
        tmp_dir = Path(tempfile.gettempdir())
        # Each entry is (path to generated script, number of commits in it).
        script_paths: list[tuple[Path, int]] = []
        for i, chunk in enumerate(chunks, 1):
            # A single chunk gets an unnumbered filename; multiple chunks are
            # numbered so they can be run in order.
            if len(chunks) == 1:
                script_path = tmp_dir / "cherry-pick.sh"
            else:
                script_path = tmp_dir / f"cherry-pick-{i}.sh"

            script_content = f"""\
#!/usr/bin/env bash
# Cherry-picks for v{version} -> {release_branch}
# Generated: {datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")}
# Commits: {len(chunk)} of {len(cherry_picks)} total
#
# If conflicts occur, resolve them and run:
#   git cherry-pick --continue

set -euo pipefail

# Guard: Prevent running on master branch
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [[ "$current_branch" == "master" ]]; then
    echo "ERROR: This script must not be run on the master branch."
    echo "Please checkout a release branch (e.g., {release_branch}) or a branch derived from it."
    exit 1
fi

git cherry-pick {" ".join(chunk)}
"""
            script_path.write_text(script_content)
            # Make the generated script executable.
            script_path.chmod(0o755)
            script_paths.append((script_path, len(chunk)))

        if len(chunks) == 1:
            print("3. Run the cherry-pick script on the new branch:\n")
            print(f"Cherry-pick script written to: {script_paths[0][0]}")
            print(f"\n4. File a PR against {release_branch}.")
        else:
            print("3. For each script (in order):")
            print(f"   a. Create a new branch from {release_branch}")
            print("   b. Run the script")
            print(f"   c. File a PR against {release_branch}")
            print(f"   d. After merge, pull {release_branch} from remote before the next script\n")
            print("   Scripts:")
            for script_path, commit_count in script_paths:
                print(f"     {script_path} ({commit_count} commits)")

        # Dry-run reports success even with missing PRs; otherwise fail the run.
        sys.exit(0 if dry_run else 1)
272  
273  
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", required=True, help="The version to release")
    # Dry run defaults to True unless the DRY_RUN env var is set to a value
    # other than "true" (case-insensitive).
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=os.environ.get("DRY_RUN", "true").lower() == "true",
        help="Dry run mode (default: True, can be set via DRY_RUN env var)",
    )
    # Explicit override: forces dry_run=False regardless of env/default.
    parser.add_argument(
        "--no-dry-run",
        action="store_false",
        dest="dry_run",
        help="Disable dry run mode",
    )
    args = parser.parse_args()
    main(args.version, args.dry_run)