/ dev / check_function_signatures.py
check_function_signatures.py
  1  from __future__ import annotations
  2  
  3  import argparse
  4  import ast
  5  import os
  6  import subprocess
  7  import sys
  8  from dataclasses import dataclass
  9  from pathlib import Path
 10  
 11  
 12  def is_github_actions() -> bool:
 13      return os.environ.get("GITHUB_ACTIONS") == "true"
 14  
 15  
 16  @dataclass
 17  class Error:
 18      file_path: Path
 19      line: int
 20      column: int
 21      lines: list[str]
 22  
 23      def format(self, github: bool = False) -> str:
 24          message = " ".join(self.lines)
 25          if github:
 26              return f"::warning file={self.file_path},line={self.line},col={self.column}::{message}"
 27          else:
 28              return f"{self.file_path}:{self.line}:{self.column}: {message}"
 29  
 30  
 31  @dataclass
 32  class Parameter:
 33      name: str
 34      position: int | None  # None for keyword-only
 35      is_required: bool
 36      is_positional_only: bool
 37      is_keyword_only: bool
 38      lineno: int
 39      col_offset: int
 40  
 41  
 42  @dataclass
 43  class Signature:
 44      positional: list[Parameter]  # Includes positional-only and regular positional
 45      keyword_only: list[Parameter]
 46      has_var_positional: bool  # *args
 47      has_var_keyword: bool  # **kwargs
 48  
 49  
 50  @dataclass
 51  class ParameterError:
 52      message: str
 53      param_name: str
 54      lineno: int
 55      col_offset: int
 56  
 57  
 58  def parse_signature(args: ast.arguments) -> Signature:
 59      """Convert ast.arguments to a Signature dataclass for easier processing."""
 60      parameters_positional: list[Parameter] = []
 61      parameters_keyword_only: list[Parameter] = []
 62  
 63      # Process positional-only parameters
 64      for i, arg in enumerate(args.posonlyargs):
 65          parameters_positional.append(
 66              Parameter(
 67                  name=arg.arg,
 68                  position=i,
 69                  is_required=True,  # All positional-only are required
 70                  is_positional_only=True,
 71                  is_keyword_only=False,
 72                  lineno=arg.lineno,
 73                  col_offset=arg.col_offset,
 74              )
 75          )
 76  
 77      # Process regular positional parameters
 78      offset = len(args.posonlyargs)
 79      first_optional_idx = len(args.posonlyargs + args.args) - len(args.defaults)
 80  
 81      for i, arg in enumerate(args.args):
 82          pos = offset + i
 83          parameters_positional.append(
 84              Parameter(
 85                  name=arg.arg,
 86                  position=pos,
 87                  is_required=pos < first_optional_idx,
 88                  is_positional_only=False,
 89                  is_keyword_only=False,
 90                  lineno=arg.lineno,
 91                  col_offset=arg.col_offset,
 92              )
 93          )
 94  
 95      # Process keyword-only parameters
 96      for arg, default in zip(args.kwonlyargs, args.kw_defaults):
 97          parameters_keyword_only.append(
 98              Parameter(
 99                  name=arg.arg,
100                  position=None,
101                  is_required=default is None,
102                  is_positional_only=False,
103                  is_keyword_only=True,
104                  lineno=arg.lineno,
105                  col_offset=arg.col_offset,
106              )
107          )
108  
109      return Signature(
110          positional=parameters_positional,
111          keyword_only=parameters_keyword_only,
112          has_var_positional=args.vararg is not None,
113          has_var_keyword=args.kwarg is not None,
114      )
115  
116  
def check_signature_compatibility(
    old_fn: ast.FunctionDef | ast.AsyncFunctionDef,
    new_fn: ast.FunctionDef | ast.AsyncFunctionDef,
) -> list[ParameterError]:
    """
    Return the ways in which *new_fn* is not backward-compatible with *old_fn*.

    An empty list means the new signature is compatible.

    Compatibility rules
    -------------------
    • Positional / positional-only parameters
        - Cannot be reordered, renamed, or removed.
        - Cannot become positional-only (callers passing them by keyword break).
        - Adding **required** ones is breaking.
        - Adding **optional** ones is allowed only at the end.
        - Making an optional parameter required is breaking.

    • Keyword-only parameters (order does not matter)
        - Cannot be renamed or removed.
        - Making an optional parameter required is breaking.
        - Adding a required parameter is breaking; adding an optional parameter is fine.

    Error locations: problems with a parameter that still exists are reported at
    that parameter in *new_fn*. Removed parameters have no location in the new
    source, so they are reported at *new_fn*'s ``def`` rather than at their
    line in *old_fn*, which belongs to a different revision of the file.
    """
    old_sig = parse_signature(old_fn.args)
    new_sig = parse_signature(new_fn.args)
    errors: list[ParameterError] = []

    # Location used for parameters that no longer exist in the new source.
    removed_lineno = new_fn.lineno
    removed_col_offset = new_fn.col_offset

    # ------------------------------------------------------------------ #
    # 1. Positional / pos-only parameters
    # ------------------------------------------------------------------ #

    # (a) existing parameters must line up
    for idx, old_param in enumerate(old_sig.positional):
        if idx >= len(new_sig.positional):
            errors.append(
                ParameterError(
                    message=f"Positional param '{old_param.name}' was removed.",
                    param_name=old_param.name,
                    lineno=removed_lineno,
                    col_offset=removed_col_offset,
                )
            )
            continue

        new_param = new_sig.positional[idx]
        if old_param.name != new_param.name:
            errors.append(
                ParameterError(
                    message=(
                        f"Positional param order/name changed: "
                        f"'{old_param.name}' -> '{new_param.name}'."
                    ),
                    param_name=new_param.name,
                    lineno=new_param.lineno,
                    col_offset=new_param.col_offset,
                )
            )
            # Stop checking further positional params after first order/name mismatch
            break

        if new_param.is_positional_only and not old_param.is_positional_only:
            errors.append(
                ParameterError(
                    message=f"Positional param '{old_param.name}' became positional-only.",
                    param_name=new_param.name,
                    lineno=new_param.lineno,
                    col_offset=new_param.col_offset,
                )
            )

        if (not old_param.is_required) and new_param.is_required:
            errors.append(
                ParameterError(
                    message=f"Optional positional param '{old_param.name}' became required.",
                    param_name=new_param.name,
                    lineno=new_param.lineno,
                    col_offset=new_param.col_offset,
                )
            )

    # (b) any extra new positional params must be optional and appended
    for new_param in new_sig.positional[len(old_sig.positional) :]:
        if new_param.is_required:
            errors.append(
                ParameterError(
                    message=f"New required positional param '{new_param.name}' added.",
                    param_name=new_param.name,
                    lineno=new_param.lineno,
                    col_offset=new_param.col_offset,
                )
            )

    # ------------------------------------------------------------------ #
    # 2. Keyword-only parameters (order-agnostic)
    # ------------------------------------------------------------------ #
    # Each check walks the parameter lists in declaration order (not set order)
    # so the reported errors come out in a stable order.
    old_kw_by_name = {p.name: p for p in old_sig.keyword_only}
    new_kw_by_name = {p.name: p for p in new_sig.keyword_only}

    # removed or renamed
    for old_param in old_sig.keyword_only:
        if old_param.name not in new_kw_by_name:
            errors.append(
                ParameterError(
                    message=f"Keyword-only param '{old_param.name}' was removed.",
                    param_name=old_param.name,
                    lineno=removed_lineno,
                    col_offset=removed_col_offset,
                )
            )

    # optional -> required upgrades
    for old_param in old_sig.keyword_only:
        new_param = new_kw_by_name.get(old_param.name)
        if new_param is not None and not old_param.is_required and new_param.is_required:
            errors.append(
                ParameterError(
                    message=f"Keyword-only param '{old_param.name}' became required.",
                    param_name=old_param.name,
                    lineno=new_param.lineno,
                    col_offset=new_param.col_offset,
                )
            )

    # new required keyword-only params
    errors.extend(
        ParameterError(
            message=f"New required keyword-only param '{param.name}' added.",
            param_name=param.name,
            lineno=param.lineno,
            col_offset=param.col_offset,
        )
        for param in new_sig.keyword_only
        if param.is_required and param.name not in old_kw_by_name
    )

    return errors
245  
246      return errors
247  
248  
249  def _is_private(n: str) -> bool:
250      return n.startswith("_") and not n.startswith("__") and not n.endswith("__")
251  
252  
253  class FunctionSignatureExtractor(ast.NodeVisitor):
254      def __init__(self) -> None:
255          self.functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {}
256          self.stack: list[ast.ClassDef] = []
257  
258      def visit_ClassDef(self, node: ast.ClassDef) -> None:
259          self.stack.append(node)
260          self.generic_visit(node)
261          self.stack.pop()
262  
263      def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
264          # Is this a private function or a function in a private class?
265          # If so, skip it.
266          if _is_private(node.name) or (self.stack and _is_private(self.stack[-1].name)):
267              return
268  
269          names = [*(c.name for c in self.stack), node.name]
270          self.functions[".".join(names)] = node
271  
272      def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
273          if _is_private(node.name) or (self.stack and _is_private(self.stack[-1].name)):
274              return
275  
276          names = [*(c.name for c in self.stack), node.name]
277          self.functions[".".join(names)] = node
278  
279  
def get_changed_python_files(base_branch: str = "master") -> list[Path]:
    """List the paths that differ between *base_branch* and ``HEAD``.

    No extension filtering happens here: every non-blank path printed by
    ``git diff --name-only`` is returned, which may include deleted files.
    Callers are expected to filter the result.

    On GitHub Actions the base branch is fetched first, because the CI checkout
    may not have it available locally.
    """
    if is_github_actions():
        fetch_cmd = ["git", "fetch", "origin", f"{base_branch}:{base_branch}"]
        subprocess.check_call(fetch_cmd)

    diff_cmd = ["git", "diff", "--name-only", f"{base_branch}...HEAD"]
    diff_output = subprocess.check_output(diff_cmd, text=True)

    changed: list[Path] = []
    for raw_line in diff_output.splitlines():
        name = raw_line.strip()
        if name:
            changed.append(Path(name))
    return changed
293  
294  
def parse_functions(content: str) -> dict[str, ast.FunctionDef | ast.AsyncFunctionDef]:
    """Parse Python source text and return its public functions by dotted name.

    Raises:
        SyntaxError: if *content* is not valid Python (raised by ``ast.parse``).
    """
    collector = FunctionSignatureExtractor()
    collector.visit(ast.parse(content))
    return collector.functions
300  
301  
def get_file_content_at_revision(file_path: Path, revision: str) -> str | None:
    """Return the text of *file_path* as of git *revision*.

    Args:
        file_path: Repository-relative path of the file.
        revision: Any revision ``git show`` accepts (e.g. a branch name).

    Returns:
        The decoded file content, or None if ``git show`` fails (for example
        because the file does not exist at *revision*). A warning is printed
        to stderr in that case.
    """
    # git object names always use '/' separators; str(Path) would use '\' on
    # Windows, and git would then fail to resolve the path.
    object_spec = f"{revision}:{file_path.as_posix()}"
    try:
        # Python source files are UTF-8 by default (PEP 3120); don't depend on
        # the locale encoding that text=True would use.
        return subprocess.check_output(["git", "show", object_spec], encoding="utf-8")
    except subprocess.CalledProcessError as e:
        print(f"Warning: Failed to get file content at revision: {e}", file=sys.stderr)
        return None
308  
309  
def compare_signatures(base_branch: str = "master") -> list[Error]:
    """Report backward-incompatible signature changes relative to *base_branch*.

    Only ``.py`` files under the top-level ``mlflow`` directory are considered.
    Paths with a component starting with an underscore are skipped, except
    ``__init__.py``. For every public function present in both the base revision
    and the working tree, each problem found by ``check_signature_compatibility``
    becomes one ``Error`` located in the current file. Files that fail to parse
    in either revision are skipped with a warning on stderr.

    Returns:
        The collected errors, in a deterministic order.
    """
    errors: list[Error] = []
    for file_path in get_changed_python_files(base_branch):
        # Ignore non-Python files
        if file_path.suffix != ".py":
            continue

        # Ignore files not in the mlflow directory
        if file_path.parts[0] != "mlflow":
            continue

        # Ignore private modules
        if any(part.startswith("_") and part != "__init__.py" for part in file_path.parts):
            continue

        # File not found, likely deleted in the current branch. Checked before
        # shelling out to git, since there is nothing to compare against anyway.
        if not file_path.exists():
            continue

        base_content = get_file_content_at_revision(file_path, base_branch)
        if base_content is None:
            # File not found in the base branch, likely added in the current branch
            continue

        current_content = file_path.read_text(encoding="utf-8")
        try:
            base_functions = parse_functions(base_content)
            current_functions = parse_functions(current_content)
        except (SyntaxError, ValueError) as e:
            # Either revision may contain source this interpreter cannot parse.
            # ValueError covers e.g. null bytes on older Pythons. Skip the file
            # instead of aborting the whole check.
            print(f"Warning: Failed to parse {file_path}: {e}", file=sys.stderr)
            continue

        # Walk base functions in source order; iterating a set intersection
        # would make the output order vary between runs (string hash seeding).
        for func_name, base_func in base_functions.items():
            current_func = current_functions.get(func_name)
            if current_func is None:
                continue

            param_errors = check_signature_compatibility(base_func, current_func)
            # Create individual errors for each problematic parameter
            errors.extend(
                Error(
                    file_path=file_path,
                    line=param_error.lineno,
                    column=param_error.col_offset + 1,
                    lines=[
                        "[Non-blocking | Ignore if not public API]",
                        param_error.message,
                        f"This change will break existing `{func_name}` calls.",
                        "If this is not intended, please fix it.",
                    ],
                )
                for param_error in param_errors
            )

    return errors
358  
359  
@dataclass
class Args:
    """Parsed command-line options."""

    base_branch: str  # branch whose merge base with HEAD is diffed against


def parse_args() -> Args:
    """Parse command-line options from ``sys.argv``.

    ``--base-branch`` defaults to ``$GITHUB_BASE_REF`` and falls back to
    ``master`` when that variable is unset *or empty*. The variable can be
    present but empty outside pull-request events, and an empty base branch
    would produce an invalid ``...HEAD`` diff range.
    """
    parser = argparse.ArgumentParser(
        description="Check for breaking changes in Python function signatures"
    )
    parser.add_argument(
        "--base-branch", default=os.environ.get("GITHUB_BASE_REF") or "master"
    )
    args = parser.parse_args()
    return Args(base_branch=args.base_branch)
372  
373  
def main() -> None:
    """Script entry point: print one line per signature-compatibility warning."""
    options = parse_args()
    found = compare_signatures(options.base_branch)
    # The output flavour depends only on the environment, so decide it once.
    github_format = is_github_actions()
    for err in found:
        print(err.format(github=github_format))


if __name__ == "__main__":
    main()