update_changelog.py
import argparse
import os
import re
import subprocess
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, NamedTuple

import requests
from packaging.version import Version

# Seconds to wait on the GitHub API before giving up. `requests` applies no timeout by
# default, so without this a stalled connection would hang the release script forever.
_GITHUB_API_TIMEOUT_SECONDS = 60


def get_header_for_version(version: str) -> str:
    """Return the CHANGELOG header for ``version``, stamped with today's date (YYYY-MM-DD)."""
    return "## {} ({})".format(version, datetime.now().strftime("%Y-%m-%d"))


def extract_pr_num_from_git_log_entry(git_log_entry: str) -> int | None:
    """
    Return the PR number referenced at the end of a commit subject, e.g. ``Fix foo (#123)``.

    Only a ``(#<digits>)`` suffix at the very end of the subject is recognized; returns
    ``None`` when there is none.
    """
    m = re.search(r"\(#(\d+)\)$", git_log_entry)
    return int(m.group(1)) if m else None


def format_label(label: str) -> str:
    """
    Convert a GitHub label (e.g. ``area/model-registry``) into a display name.

    Everything up to the first ``/`` is dropped. A few keys map to custom names; all
    others are passed through ``str.capitalize``.
    """
    key = label.split("/", 1)[-1]
    return {
        "model-registry": "Model Registry",
        "uiux": "UI",
    }.get(key, key.capitalize())


class PullRequest(NamedTuple):
    """A pull request as fetched from the GitHub GraphQL API."""

    title: str
    number: int
    author: str  # GitHub login of the PR author
    labels: list[str]  # names of all labels attached to the PR

    @property
    def url(self) -> str:
        """Web URL of the PR in the mlflow/mlflow repository."""
        return f"https://github.com/mlflow/mlflow/pull/{self.number}"

    @property
    def release_note_labels(self) -> list[str]:
        """Labels with the ``rn/`` prefix; these decide which CHANGELOG section the PR goes in."""
        return [label for label in self.labels if label.startswith("rn/")]

    def __str__(self) -> str:
        # Render as a CHANGELOG entry: "[Area / Language] Title (#123, @author)".
        # Only ``area/*`` and ``language/*`` labels contribute to the bracketed prefix.
        areas = " / ".join(
            sorted(
                format_label(label)
                for label in self.labels
                if label.split("/")[0] in ("area", "language")
            )
        )
        return f"[{areas}] {self.title} (#{self.number}, @{self.author})"

    def __repr__(self) -> str:
        return str(self)


class Section(NamedTuple):
    """A titled, bulleted CHANGELOG section; renders as ``""`` when it has no items."""

    title: str
    items: list[Any]

    def __str__(self) -> str:
        if not self.items:
            return ""
        return "\n\n".join([
            self.title,
            "\n".join(f"- {item}" for item in self.items),
        ])


def is_shallow() -> bool:
    """Return True if the current git repository is a shallow clone."""
    return (
        subprocess.check_output(
            [
                "git",
                "rev-parse",
                "--is-shallow-repository",
            ],
            text=True,
        ).strip()
        == "true"
    )


def batch_fetch_prs_graphql(pr_numbers: list[int]) -> list[PullRequest]:
    """
    Batch fetch PR data using GitHub GraphQL API.

    The numbers are split into fixed-size chunks, with one GraphQL request per chunk.
    Results come back in input order. PRs that could not be resolved are omitted, and a
    warning is printed for each.
    """
    if not pr_numbers:
        return []

    # GitHub GraphQL has query size limits, so batch in chunks
    MAX_PRS_PER_QUERY = 50  # Conservative limit to avoid query size issues
    all_prs: list[PullRequest] = []

    for i in range(0, len(pr_numbers), MAX_PRS_PER_QUERY):
        chunk = pr_numbers[i : i + MAX_PRS_PER_QUERY]
        chunk_prs = _fetch_pr_chunk_graphql(chunk)
        all_prs.extend(chunk_prs)

    return all_prs


def _fetch_pr_chunk_graphql(pr_numbers: list[int]) -> list[PullRequest]:
    """
    Fetch a chunk of PRs using GraphQL.

    Every PR is requested under its own alias (``pr0``, ``pr1``, ...) in a single query
    against mlflow/mlflow. If the ``GH_TOKEN`` environment variable is set, it is sent as
    a bearer token.

    Raises:
        requests.HTTPError: If the HTTP response status indicates failure.
        requests.Timeout: If GitHub does not respond within the timeout.
        Exception: If the GraphQL response contains an ``errors`` entry.
    """
    # Build GraphQL query with aliases for each PR
    query_parts = [
        "query($owner: String!, $repo: String!) {",
        "  repository(owner: $owner, name: $repo) {",
    ]

    for i, pr_num in enumerate(pr_numbers):
        query_parts.append(f"""
    pr{i}: pullRequest(number: {pr_num}) {{
      number
      title
      author {{
        login
      }}
      labels(first: 100) {{
        nodes {{
          name
        }}
      }}
    }}""")

    query_parts.extend(["  }", "}"])
    query = "\n".join(query_parts)

    # Headers with authentication
    headers = {"Content-Type": "application/json"}
    if token := os.environ.get("GH_TOKEN"):
        headers["Authorization"] = f"Bearer {token}"
    print(f"Batch fetching {len(pr_numbers)} PRs with GraphQL...")
    resp = requests.post(
        "https://api.github.com/graphql",
        json={
            "query": query,
            "variables": {"owner": "mlflow", "repo": "mlflow"},
        },
        headers=headers,
        timeout=_GITHUB_API_TIMEOUT_SECONDS,
    )
    resp.raise_for_status()
    data = resp.json()
    if "errors" in data:
        raise Exception(f"GraphQL errors: {data['errors']}")

    # Extract PR data from response and create PullRequest objects
    repository_data = data["data"]["repository"]
    prs = []
    for i, pr_num in enumerate(pr_numbers):
        pr_info = repository_data.get(f"pr{i}")
        # A missing/null alias or a null author both cause the PR to be skipped.
        if pr_info and pr_info.get("author"):
            prs.append(
                PullRequest(
                    title=pr_info["title"],
                    number=pr_info["number"],
                    author=pr_info["author"]["login"],
                    labels=[label["name"] for label in pr_info["labels"]["nodes"]],
                )
            )
        else:
            print(f"Warning: Could not fetch data for PR #{pr_num}")

    return prs


def main(prev_version: str, release_version: str, remote: str) -> None:
    """
    Prepend a CHANGELOG.md entry for ``release_version``.

    The entry covers commits on ``{remote}/branch-<major>.<minor>`` that are not on the
    ``v{prev_version}`` tag. Their PRs are fetched from GitHub and grouped into sections
    by ``rn/*`` label. ``rn/none`` PRs are summarized per author in a final paragraph, and
    PRs authored by ``mlflow-app`` are skipped.

    Raises:
        AssertionError: If a PR has no ``rn/*`` label or has an unrecognized one.
    """
    if is_shallow():
        print("Unshallowing repository to ensure `git log` works correctly")
        subprocess.check_call(["git", "fetch", "--unshallow"])
        print("Modifying .git/config to fetch remote branches")
        subprocess.check_call([
            "git",
            "config",
            "remote.origin.fetch",
            "+refs/heads/*:refs/remotes/origin/*",
        ])
    release_tag = f"v{prev_version}"
    ver = Version(release_version)
    branch = f"branch-{ver.major}.{ver.minor}"
    subprocess.check_call(["git", "fetch", remote, "tag", release_tag])
    subprocess.check_call(["git", "fetch", remote, branch])
    # `--left-right` marks commits reachable only from the release branch with ">", and
    # `--cherry-pick` omits commits whose change is equivalent to one on the tag side.
    git_log_output = subprocess.check_output(
        [
            "git",
            "log",
            "--left-right",
            "--graph",
            "--cherry-pick",
            "--pretty=format:%s",
            f"tags/{release_tag}...{remote}/{branch}",
        ],
        text=True,
    )
    logs = [line[2:] for line in git_log_output.splitlines() if line.startswith("> ")]

    # Extract all PR numbers first
    pr_numbers = [pr_num for log in logs if (pr_num := extract_pr_num_from_git_log_entry(log))]
    # Several commits can reference the same PR. Drop duplicates so each PR is fetched and
    # listed only once. dict.fromkeys keeps the first-seen order.
    pr_numbers = list(dict.fromkeys(pr_numbers))

    prs = batch_fetch_prs_graphql(pr_numbers)
    label_to_prs = defaultdict(list)
    author_to_prs = defaultdict(list)
    unlabelled_prs = []
    for pr in prs:
        if pr.author == "mlflow-app":
            continue

        if len(pr.release_note_labels) == 0:
            unlabelled_prs.append(pr)

        for label in pr.release_note_labels:
            if label == "rn/none":
                author_to_prs[pr.author].append(pr)
            else:
                label_to_prs[label].append(pr)

    assert len(unlabelled_prs) == 0, "The following PRs need to be categorized:\n" + "\n".join(
        f"- {pr.url}" for pr in unlabelled_prs
    )

    unknown_labels = set(label_to_prs.keys()) - {
        "rn/highlight",
        "rn/feature",
        "rn/breaking-change",
        "rn/bug-fix",
        "rn/documentation",
        "rn/none",
    }
    assert len(unknown_labels) == 0, f"Unknown labels: {unknown_labels}"

    breaking_changes = Section("Breaking changes:", label_to_prs.get("rn/breaking-change", []))
    highlights = Section("Major new features:", label_to_prs.get("rn/highlight", []))
    features = Section("Features:", label_to_prs.get("rn/feature", []))
    bug_fixes = Section("Bug fixes:", label_to_prs.get("rn/bug-fix", []))
    doc_updates = Section("Documentation updates:", label_to_prs.get("rn/documentation", []))
    # One "#1, #2, @author" group per author of rn/none PRs.
    small_updates_items = [
        ", ".join([f"#{author_pr.number}" for author_pr in author_prs] + [f"@{author}"])
        for author, author_prs in author_to_prs.items()
    ]
    # Emit this paragraph only when it has content. Otherwise its header would be left
    # dangling in the CHANGELOG, because the filter below only drops empty strings.
    small_updates = (
        "Small bug fixes and documentation updates:\n\n" + "; ".join(small_updates_items)
        if small_updates_items
        else ""
    )
    sections = [
        s
        for sec in [
            get_header_for_version(release_version),
            f"MLflow {release_version} includes several major features and improvements",
            breaking_changes,
            highlights,
            features,
            bug_fixes,
            doc_updates,
            small_updates,
        ]
        if (s := str(sec).strip())
    ]
    new_changelog = "\n\n".join(sections)
    changelog_header = "# CHANGELOG"
    changelog = Path("CHANGELOG.md")
    # Strip the existing top-level header so it can be re-added above the new entry.
    old_changelog = changelog.read_text().replace(f"{changelog_header}\n\n", "", 1)
    new_changelog = "\n\n".join([
        changelog_header,
        new_changelog,
        old_changelog,
    ])
    changelog.write_text(new_changelog)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update CHANGELOG.md")
    parser.add_argument("--prev-version", required=True, help="Previous version")
    parser.add_argument("--release-version", required=True, help="MLflow version to release")
    parser.add_argument("--remote", default="origin", help="Git remote to use (default: origin)")
    args = parser.parse_args()
    main(args.prev_version, args.release_version, args.remote)