/ dev / update_changelog.py
update_changelog.py
  1  import argparse
  2  import os
  3  import re
  4  import subprocess
  5  from collections import defaultdict
  6  from datetime import datetime
  7  from pathlib import Path
  8  from typing import Any, NamedTuple
  9  
 10  import requests
 11  from packaging.version import Version
 12  
 13  
 14  def get_header_for_version(version: str) -> str:
 15      return "## {} ({})".format(version, datetime.now().strftime("%Y-%m-%d"))
 16  
 17  
 18  def extract_pr_num_from_git_log_entry(git_log_entry: str) -> int | None:
 19      m = re.search(r"\(#(\d+)\)$", git_log_entry)
 20      return int(m.group(1)) if m else None
 21  
 22  
 23  def format_label(label: str) -> str:
 24      key = label.split("/", 1)[-1]
 25      return {
 26          "model-registry": "Model Registry",
 27          "uiux": "UI",
 28      }.get(key, key.capitalize())
 29  
 30  
 31  class PullRequest(NamedTuple):
 32      title: str
 33      number: int
 34      author: str
 35      labels: list[str]
 36  
 37      @property
 38      def url(self) -> str:
 39          return f"https://github.com/mlflow/mlflow/pull/{self.number}"
 40  
 41      @property
 42      def release_note_labels(self) -> list[str]:
 43          return [l for l in self.labels if l.startswith("rn/")]
 44  
 45      def __str__(self) -> str:
 46          areas = " / ".join(
 47              sorted(
 48                  map(
 49                      format_label,
 50                      filter(lambda l: l.split("/")[0] in ("area", "language"), self.labels),
 51                  )
 52              )
 53          )
 54          return f"[{areas}] {self.title} (#{self.number}, @{self.author})"
 55  
 56      def __repr__(self) -> str:
 57          return str(self)
 58  
 59  
 60  class Section(NamedTuple):
 61      title: str
 62      items: list[Any]
 63  
 64      def __str__(self) -> str:
 65          if not self.items:
 66              return ""
 67          return "\n\n".join([
 68              self.title,
 69              "\n".join(f"- {item}" for item in self.items),
 70          ])
 71  
 72  
 73  def is_shallow() -> bool:
 74      return (
 75          subprocess.check_output(
 76              [
 77                  "git",
 78                  "rev-parse",
 79                  "--is-shallow-repository",
 80              ],
 81              text=True,
 82          ).strip()
 83          == "true"
 84      )
 85  
 86  
 87  def batch_fetch_prs_graphql(pr_numbers: list[int]) -> list[PullRequest]:
 88      """
 89      Batch fetch PR data using GitHub GraphQL API.
 90      """
 91      if not pr_numbers:
 92          return []
 93  
 94      # GitHub GraphQL has query size limits, so batch in chunks
 95      MAX_PRS_PER_QUERY = 50  # Conservative limit to avoid query size issues
 96      all_prs: list[PullRequest] = []
 97  
 98      for i in range(0, len(pr_numbers), MAX_PRS_PER_QUERY):
 99          chunk = pr_numbers[i : i + MAX_PRS_PER_QUERY]
100          chunk_prs = _fetch_pr_chunk_graphql(chunk)
101          all_prs.extend(chunk_prs)
102  
103      return all_prs
104  
105  
def _fetch_pr_chunk_graphql(pr_numbers: list[int]) -> list[PullRequest]:
    """
    Fetch a chunk of PRs using GraphQL.

    A single query is built with one aliased ``pullRequest`` field (``pr0``, ``pr1``,
    ...) per requested number and POSTed to the GitHub GraphQL endpoint for the
    ``mlflow/mlflow`` repository. When the ``GH_TOKEN`` environment variable is set,
    it is sent as a bearer token.

    Args:
        pr_numbers: PR numbers to fetch in one request.

    Returns:
        One ``PullRequest`` per number for which the response contains both the PR
        and its author, in input order. Other numbers are skipped and a warning is
        printed. An empty input returns an empty list without any request.

    Raises:
        requests.HTTPError: If the HTTP response has an error status.
        requests.Timeout: If GitHub does not answer within the timeout.
        RuntimeError: If the response body contains an ``errors`` entry.
    """
    # An empty selection set is not a valid GraphQL query, so skip the request.
    if not pr_numbers:
        return []

    # Without a timeout `requests` waits forever on a stalled connection.
    request_timeout_seconds = 30

    # Build GraphQL query with aliases for each PR
    query_parts = [
        "query($owner: String!, $repo: String!) {",
        "  repository(owner: $owner, name: $repo) {",
    ]

    for i, pr_num in enumerate(pr_numbers):
        query_parts.append(f"""
    pr{i}: pullRequest(number: {pr_num}) {{
      number
      title
      author {{
        login
      }}
      labels(first: 100) {{
        nodes {{
          name
        }}
      }}
    }}""")

    query_parts.extend(["  }", "}"])
    query = "\n".join(query_parts)

    # Headers with authentication
    headers = {"Content-Type": "application/json"}
    if token := os.environ.get("GH_TOKEN"):
        headers["Authorization"] = f"Bearer {token}"
    print(f"Batch fetching {len(pr_numbers)} PRs with GraphQL...")
    resp = requests.post(
        "https://api.github.com/graphql",
        json={
            "query": query,
            "variables": {"owner": "mlflow", "repo": "mlflow"},
        },
        headers=headers,
        timeout=request_timeout_seconds,
    )
    resp.raise_for_status()
    data = resp.json()
    if "errors" in data:
        raise RuntimeError(f"GraphQL errors: {data['errors']}")

    # Extract PR data from response and create PullRequest objects
    repository_data = data["data"]["repository"]
    prs = []
    for i, pr_num in enumerate(pr_numbers):
        # Aliases were assigned by position, so look each PR up by its index.
        pr_info = repository_data.get(f"pr{i}")
        # NOTE(review): a PR whose `author` is null is also skipped here -- confirm
        # that dropping such PRs from the changelog is intended.
        if pr_info and pr_info.get("author"):
            prs.append(
                PullRequest(
                    title=pr_info["title"],
                    number=pr_info["number"],
                    author=pr_info["author"]["login"],
                    labels=[label["name"] for label in pr_info["labels"]["nodes"]],
                )
            )
        else:
            print(f"Warning: Could not fetch data for PR #{pr_num}")

    return prs
170  
171  
def main(prev_version: str, release_version: str, remote: str) -> None:
    """
    Prepend the release notes for ``release_version`` to ``CHANGELOG.md``.

    The commits are taken from ``git log`` between the tag ``v<prev_version>`` and
    ``<remote>/branch-<major>.<minor>`` of the release. PR numbers are read from the
    commit subjects, the PRs are fetched from GitHub and grouped by their ``rn/*``
    labels into changelog sections. The new entry is written right below the
    ``# CHANGELOG`` header of ``CHANGELOG.md`` in the current working directory.

    Args:
        prev_version: Previously released version, without the leading ``v``.
        release_version: Version being released.
        remote: Name of the git remote to fetch the tag and release branch from.

    Raises:
        AssertionError: If a PR has no ``rn/*`` label, or carries an ``rn/*`` label
            that is not one of the known categories.
        subprocess.CalledProcessError: If a git command fails.
    """
    if is_shallow():
        print("Unshallowing repository to ensure `git log` works correctly")
        subprocess.check_call(["git", "fetch", "--unshallow"])
        print("Modifying .git/config to fetch remote branches")
        # NOTE(review): this always configures `origin`, even when `remote` names a
        # different remote -- confirm whether it should follow `remote`.
        subprocess.check_call([
            "git",
            "config",
            "remote.origin.fetch",
            "+refs/heads/*:refs/remotes/origin/*",
        ])
    release_tag = f"v{prev_version}"
    ver = Version(release_version)
    branch = f"branch-{ver.major}.{ver.minor}"
    subprocess.check_call(["git", "fetch", remote, "tag", release_tag])
    subprocess.check_call(["git", "fetch", remote, branch])
    git_log_output = subprocess.check_output(
        [
            "git",
            "log",
            "--left-right",
            "--graph",
            "--cherry-pick",
            "--pretty=format:%s",
            f"tags/{release_tag}...{remote}/{branch}",
        ],
        text=True,
    )
    # Keep only the lines marked "> " (presumably commits found only on the release
    # branch side of the range) and drop that two-character marker.
    logs = [line[2:] for line in git_log_output.splitlines() if line.startswith("> ")]

    # Extract all PR numbers first
    pr_numbers = [pr_num for log in logs if (pr_num := extract_pr_num_from_git_log_entry(log))]

    prs = batch_fetch_prs_graphql(pr_numbers)
    label_to_prs = defaultdict(list)
    author_to_prs = defaultdict(list)
    unlabelled_prs = []
    for pr in prs:
        # PRs authored by the `mlflow-app` account are left out of the changelog.
        if pr.author == "mlflow-app":
            continue

        if len(pr.release_note_labels) == 0:
            unlabelled_prs.append(pr)

        for label in pr.release_note_labels:
            if label == "rn/none":
                # "rn/none" PRs are only listed by number, grouped per author.
                author_to_prs[pr.author].append(pr)
            else:
                label_to_prs[label].append(pr)

    # Raised explicitly instead of via `assert` so that the checks still run under
    # `python -O`; the exception type is unchanged.
    if unlabelled_prs:
        raise AssertionError(
            "The following PRs need to be categorized:\n"
            + "\n".join(f"- {pr.url}" for pr in unlabelled_prs)
        )

    unknown_labels = set(label_to_prs.keys()) - {
        "rn/highlight",
        "rn/feature",
        "rn/breaking-change",
        "rn/bug-fix",
        "rn/documentation",
        "rn/none",
    }
    if unknown_labels:
        raise AssertionError(f"Unknown labels: {unknown_labels}")

    breaking_changes = Section("Breaking changes:", label_to_prs.get("rn/breaking-change", []))
    highlights = Section("Major new features:", label_to_prs.get("rn/highlight", []))
    features = Section("Features:", label_to_prs.get("rn/feature", []))
    bug_fixes = Section("Bug fixes:", label_to_prs.get("rn/bug-fix", []))
    doc_updates = Section("Documentation updates:", label_to_prs.get("rn/documentation", []))
    small_updates_items = [
        ", ".join([f"#{pr.number}" for pr in author_prs] + [f"@{author}"])
        for author, author_prs in author_to_prs.items()
    ]
    # Like `Section`, emit nothing when there are no items, so that no dangling
    # heading ends up in the changelog.
    small_updates = (
        "Small bug fixes and documentation updates:\n\n" + "; ".join(small_updates_items)
        if small_updates_items
        else ""
    )
    # Parts that render as empty text are dropped from the entry.
    sections = [
        s
        for sec in [
            get_header_for_version(release_version),
            f"MLflow {release_version} includes several major features and improvements",
            breaking_changes,
            highlights,
            features,
            bug_fixes,
            doc_updates,
            small_updates,
        ]
        if (s := str(sec).strip())
    ]
    new_changelog = "\n\n".join(sections)
    changelog_header = "# CHANGELOG"
    changelog = Path("CHANGELOG.md")
    # Read and write as UTF-8 explicitly: PR titles and author names may contain
    # non-ASCII characters and the platform default encoding is not always UTF-8.
    old_changelog = changelog.read_text(encoding="utf-8").replace(
        f"{changelog_header}\n\n", "", 1
    )
    new_changelog = "\n\n".join([
        changelog_header,
        new_changelog,
        old_changelog,
    ])
    changelog.write_text(new_changelog, encoding="utf-8")
272  
273  
if __name__ == "__main__":
    # Command-line entry point: parse the three options and hand them to main().
    cli = argparse.ArgumentParser(description="Update CHANGELOG.md")
    cli.add_argument("--prev-version", help="Previous version", required=True)
    cli.add_argument("--release-version", help="MLflow version to release", required=True)
    cli.add_argument("--remote", help="Git remote to use (default: origin)", default="origin")
    options = cli.parse_args()
    main(
        prev_version=options.prev_version,
        release_version=options.release_version,
        remote=options.remote,
    )