/ dev / check_action_pins.py
check_action_pins.py
  1  """Validate that all remote GitHub Actions are SHA-pinned with a version comment."""
  2  
  3  import json
  4  import re
  5  import subprocess
  6  import sys
  7  from collections import defaultdict
  8  from collections.abc import Iterator
  9  from dataclasses import dataclass
 10  from pathlib import Path
 11  
 12  # Matches a `uses:` line that references a remote action (not a local `./` path).
 13  # Captures:  owner/repo[/subpath]  @  ref  [  # comment  ]
 14  _USES_RE = re.compile(
 15      r"""
 16      ^\s*-?\s*uses:\s+          # leading `- uses:` or `uses:`
 17      (?P<action>[^@\s]+)        # owner/repo[/subpath]
 18      @
 19      (?P<ref>[^\s#]+)           # ref (SHA, tag, or branch)
 20      (?:\s+\#\s*(?P<comment>\S+))?  # optional  # comment
 21      """,
 22      re.VERBOSE,
 23  )
 24  
 25  # A full 40-character hexadecimal SHA.
 26  _SHA_RE = re.compile(r"^[0-9a-f]{40}$")
 27  
 28  # Requires at least vMAJOR.MINOR.PATCH to avoid ambiguous moving tags like v4.
 29  _VERSION_COMMENT_RE = re.compile(r"^v\d+\.\d+\.\d+(?:\.\d+)*$")
 30  
 31  _CACHE_PATH = Path(".cache/action-pins.json")
 32  
 33  
 34  def _load_cache() -> dict[str, bool]:
 35      if _CACHE_PATH.exists():
 36          try:
 37              return json.loads(_CACHE_PATH.read_text())  # type: ignore[no-any-return]
 38          except (json.JSONDecodeError, OSError):
 39              pass
 40      return {}
 41  
 42  
 43  def _save_cache(cache: dict[str, bool]) -> None:
 44      try:
 45          _CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
 46          _CACHE_PATH.write_text(json.dumps(cache, indent=2, sort_keys=True))
 47      except OSError:
 48          pass
 49  
 50  
 51  def _repo_from_action(action: str) -> str:
 52      match action.split("/"):
 53          case [owner, repo, *_]:
 54              return f"{owner}/{repo}"
 55          case _:
 56              raise ValueError(f"Invalid action format: {action!r}")
 57  
 58  
 59  def _verify_sha_tag(action: str, sha: str, tag: str, cache: dict[str, bool]) -> bool | None:
 60  
 61      cache_key = f"{action}@{sha}#{tag}"
 62      if cache_key in cache:
 63          return cache[cache_key]
 64  
 65      repo = _repo_from_action(action)
 66      try:
 67          result = _resolve_tag(repo, sha, tag)
 68      except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
 69          return None
 70  
 71      cache[cache_key] = result
 72      return result
 73  
 74  
 75  def _resolve_tag(repo: str, sha: str, tag: str) -> bool:
 76      output = subprocess.check_output(
 77          ["git", "ls-remote", "--tags", f"https://github.com/{repo}.git", tag],
 78          text=True,
 79          timeout=10,
 80      )
 81      return any(line.split()[0] == sha for line in output.splitlines() if line)
 82  
 83  
 84  def _iter_files() -> Iterator[Path]:
 85      root = Path(".github")
 86      for pattern in (
 87          "workflows/*.yml",
 88          "workflows/*.yaml",
 89          "actions/**/*.yml",
 90          "actions/**/*.yaml",
 91      ):
 92          yield from root.glob(pattern)
 93  
 94  
 95  @dataclass(frozen=True, slots=True)
 96  class ActionRef:
 97      prefix: str
 98      action: str
 99      ref: str
100      comment: str | None
101  
102  
def _iter_actions(path: Path) -> Iterator[ActionRef]:
    """Yield an ``ActionRef`` for every remote ``uses:`` line in *path*.

    Lines are read as UTF-8 and numbered from 1. A line is reported when it
    matches ``_USES_RE`` and its action does not start with ``./`` (a local
    action, which has nothing to pin).
    """
    with path.open(encoding="utf-8") as handle:
        for number, text in enumerate(handle, start=1):
            found = _USES_RE.match(text)
            if found is None:
                continue
            target = found.group("action")
            if target.startswith("./"):
                continue
            yield ActionRef(
                f"{path}:{number}: {text.strip()!r}",
                target,
                found.group("ref"),
                found.group("comment"),
            )
111  
112  
def _check_action(a: ActionRef, cache: dict[str, bool]) -> str | None:
    """Validate one remote action reference and describe the first problem found.

    Checks run in this order:
    1. the ref is a full 40-character SHA;
    2. a ``# vX.Y.Z`` version comment is present;
    3. the action names an ``owner/repo``;
    4. the SHA is what that tag resolves to. This goes through
       ``_verify_sha_tag``, which reads *cache*, may query the remote, and
       stores new results in *cache*.

    Returns:
        A multi-line error message for the first failing check, or ``None``
        when the reference passes every check.
    """
    if not _SHA_RE.match(a.ref):
        return f"{a.prefix}\n  error: ref '{a.ref}' is not a 40-character SHA"

    if not a.comment or not _VERSION_COMMENT_RE.match(a.comment):
        return (
            f"{a.prefix}\n  error: missing or invalid version comment"
            f" (expected '# vX.Y.Z', got {a.comment!r})"
        )

    # A malformed reference such as `uses: checkout@<sha>` has no owner/repo
    # to query. Report it as a violation instead of letting ValueError abort
    # the whole run.
    try:
        repo = _repo_from_action(a.action)
    except ValueError:
        return f"{a.prefix}\n  error: action '{a.action}' is not of the form owner/repo[/path]"

    verified = _verify_sha_tag(a.action, a.ref, a.comment, cache)
    if verified is None:
        return (
            f"{a.prefix}\n  error: could not verify SHA against tag '{a.comment}'"
            f" for {repo} (GitHub API unavailable)"
        )
    if not verified:
        return (
            f"{a.prefix}\n  error: SHA '{a.ref}' does not match tag '{a.comment}'"
            f" for {repo}"
        )
    return None
135  
136  
def _check_version_consistency(all_action_refs: list[ActionRef]) -> Iterator[str]:
    """Yield one violation per action pinned to more than one (SHA, comment) pair.

    Actions are reported in sorted order. Under each header, every occurrence
    of that action is listed (indented, ordered by its location prefix), so the
    divergent pins can be compared side by side.
    """
    groups: dict[str, list[ActionRef]] = defaultdict(list)
    for occurrence in all_action_refs:
        groups[occurrence.action].append(occurrence)

    for name in sorted(groups):
        occurrences = groups[name]
        distinct_pins = {(o.ref, o.comment) for o in occurrences}
        if len(distinct_pins) <= 1:
            continue
        ordered = sorted(occurrences, key=lambda o: o.prefix)
        listing = "\n".join("  " + o.prefix for o in ordered)
        yield f"{name} is pinned to multiple versions:\n{listing}"
147  
148  
def main() -> int:
    """Check every remote action pin under ``.github`` and report violations.

    Each reference is first checked on its own. References that pass are then
    checked for consistent pinning across files. The verification cache is
    saved in a ``finally`` block, so it is persisted even if scanning raises.

    Returns:
        ``0`` when no violations were found; otherwise ``1``, after printing
        every violation and a count to stderr.
    """
    cache = _load_cache()
    errors: list[str] = []
    valid_refs: list[ActionRef] = []
    try:
        for workflow in _iter_files():
            for ref in _iter_actions(workflow):
                problem = _check_action(ref, cache)
                if problem:
                    errors.append(problem)
                else:
                    valid_refs.append(ref)
    finally:
        _save_cache(cache)
    errors += _check_version_consistency(valid_refs)

    if not errors:
        return 0

    print("action-pins: the following violations were found:\n", file=sys.stderr)
    for message in errors:
        print(message, file=sys.stderr)
    print(f"\n{len(errors)} violation(s) found.", file=sys.stderr)
    return 1


if __name__ == "__main__":
    sys.exit(main())