check_function_signatures.py
from __future__ import annotations

import argparse
import ast
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path


def is_github_actions() -> bool:
    """Return True when the ``GITHUB_ACTIONS`` environment variable equals ``"true"``."""
    return os.environ.get("GITHUB_ACTIONS") == "true"


@dataclass
class Error:
    """A single diagnostic pointing at a location in a file."""

    file_path: Path
    line: int
    column: int
    lines: list[str]  # message fragments; joined with single spaces when formatted

    def format(self, github: bool = False) -> str:
        """
        Render the diagnostic as one line.

        With ``github=True`` a ``::warning`` workflow command is produced; otherwise the
        ``path:line:column: message`` form is used.
        """
        message = " ".join(self.lines)
        if github:
            return f"::warning file={self.file_path},line={self.line},col={self.column}::{message}"
        return f"{self.file_path}:{self.line}:{self.column}: {message}"


@dataclass
class Parameter:
    """One parameter of a function signature, as extracted from the AST."""

    name: str
    position: int | None  # None for keyword-only
    is_required: bool
    is_positional_only: bool
    is_keyword_only: bool
    lineno: int
    col_offset: int


@dataclass
class Signature:
    """A function signature split into positional and keyword-only parameters."""

    positional: list[Parameter]  # Includes positional-only and regular positional
    keyword_only: list[Parameter]
    has_var_positional: bool  # *args
    has_var_keyword: bool  # **kwargs


@dataclass
class ParameterError:
    """A backward-incompatible change, with the source location it should be reported at."""

    message: str
    param_name: str
    lineno: int
    col_offset: int


def parse_signature(args: ast.arguments) -> Signature:
    """Convert ast.arguments to a Signature dataclass for easier processing."""
    # `args.defaults` holds the defaults of the *last* N positional parameters, counted across
    # positional-only and regular ones together, so one cut-off index applies to both kinds.
    # (Positional-only parameters may have defaults too: `def f(a, b=1, /)`.)
    positional_args = [*args.posonlyargs, *args.args]
    first_optional_idx = len(positional_args) - len(args.defaults)
    num_positional_only = len(args.posonlyargs)

    positional = [
        Parameter(
            name=arg.arg,
            position=idx,
            is_required=idx < first_optional_idx,
            is_positional_only=idx < num_positional_only,
            is_keyword_only=False,
            lineno=arg.lineno,
            col_offset=arg.col_offset,
        )
        for idx, arg in enumerate(positional_args)
    ]

    # `kw_defaults` is aligned with `kwonlyargs`; a None entry means "no default".
    keyword_only = [
        Parameter(
            name=arg.arg,
            position=None,
            is_required=default is None,
            is_positional_only=False,
            is_keyword_only=True,
            lineno=arg.lineno,
            col_offset=arg.col_offset,
        )
        for arg, default in zip(args.kwonlyargs, args.kw_defaults)
    ]

    return Signature(
        positional=positional,
        keyword_only=keyword_only,
        has_var_positional=args.vararg is not None,
        has_var_keyword=args.kwarg is not None,
    )


def _make_error(
    message: str,
    param_name: str,
    location: Parameter | ast.FunctionDef | ast.AsyncFunctionDef,
) -> ParameterError:
    """Build a ParameterError positioned at *location* (anything with lineno/col_offset)."""
    return ParameterError(
        message=message,
        param_name=param_name,
        lineno=location.lineno,
        col_offset=location.col_offset,
    )


def _check_positional(
    old_sig: Signature,
    new_sig: Signature,
    new_fn: ast.FunctionDef | ast.AsyncFunctionDef,
) -> list[ParameterError]:
    """Check positional / positional-only parameters of *new_sig* against *old_sig*."""
    errors: list[ParameterError] = []

    # (a) existing parameters must line up
    for idx, old_param in enumerate(old_sig.positional):
        if idx >= len(new_sig.positional):
            # The removed parameter only exists in the old revision, so its own line/column
            # would point at an unrelated place in the new file. Report at the new function.
            errors.append(
                _make_error(
                    f"Positional param '{old_param.name}' was removed.",
                    old_param.name,
                    new_fn,
                )
            )
            continue

        new_param = new_sig.positional[idx]
        if old_param.name != new_param.name:
            errors.append(
                _make_error(
                    f"Positional param order/name changed: "
                    f"'{old_param.name}' -> '{new_param.name}'.",
                    new_param.name,
                    new_param,
                )
            )
            # Stop checking further positional params after first order/name mismatch
            break

        if not old_param.is_required and new_param.is_required:
            errors.append(
                _make_error(
                    f"Optional positional param '{old_param.name}' became required.",
                    new_param.name,
                    new_param,
                )
            )

    # (b) any extra new positional params must be optional and appended
    errors.extend(
        _make_error(
            f"New required positional param '{new_param.name}' added.",
            new_param.name,
            new_param,
        )
        for new_param in new_sig.positional[len(old_sig.positional) :]
        if new_param.is_required
    )
    return errors


def _check_keyword_only(
    old_sig: Signature,
    new_sig: Signature,
    new_fn: ast.FunctionDef | ast.AsyncFunctionDef,
) -> list[ParameterError]:
    """Check keyword-only parameters of *new_sig* against *old_sig* (order-agnostic)."""
    old_names = {p.name for p in old_sig.keyword_only}
    new_by_name = {p.name: p for p in new_sig.keyword_only}
    errors: list[ParameterError] = []

    # Iterate in declaration order (not set order) so the output is stable between runs.

    # removed or renamed -- reported at the new function, see `_check_positional`
    errors.extend(
        _make_error(f"Keyword-only param '{old_param.name}' was removed.", old_param.name, new_fn)
        for old_param in old_sig.keyword_only
        if old_param.name not in new_by_name
    )

    # optional -> required upgrades
    for old_param in old_sig.keyword_only:
        new_param = new_by_name.get(old_param.name)
        if new_param is None:
            continue
        if not old_param.is_required and new_param.is_required:
            errors.append(
                _make_error(
                    f"Keyword-only param '{old_param.name}' became required.",
                    old_param.name,
                    new_param,
                )
            )

    # new required keyword-only params
    errors.extend(
        _make_error(
            f"New required keyword-only param '{param.name}' added.",
            param.name,
            param,
        )
        for param in new_sig.keyword_only
        if param.is_required and param.name not in old_names
    )
    return errors


def check_signature_compatibility(
    old_fn: ast.FunctionDef | ast.AsyncFunctionDef,
    new_fn: ast.FunctionDef | ast.AsyncFunctionDef,
) -> list[ParameterError]:
    """
    Return the list of errors found when *new_fn* is not backward-compatible with *old_fn*.
    The list is empty if the signatures are compatible.

    Errors are positioned in the *new* source: at the offending parameter when it exists
    there, otherwise (removed parameters) at the new function definition.

    Compatibility rules
    -------------------
    • Positional / positional-only parameters
        - Cannot be reordered, renamed, or removed.
        - Adding **required** ones is breaking.
        - Adding **optional** ones is allowed only at the end.
        - Making an optional parameter required is breaking.

    • Keyword-only parameters (order does not matter)
        - Cannot be renamed or removed.
        - Making an optional parameter required is breaking.
        - Adding a required parameter is breaking; adding an optional parameter is fine.
    """
    old_sig = parse_signature(old_fn.args)
    new_sig = parse_signature(new_fn.args)
    return [
        *_check_positional(old_sig, new_sig, new_fn),
        *_check_keyword_only(old_sig, new_sig, new_fn),
    ]


def _is_private(n: str) -> bool:
    """
    Return True for underscore-prefixed names that are not dunders.

    `_helper` and name-mangled `__helper` are private; `__init__` is not.
    """
    is_dunder = n.startswith("__") and n.endswith("__")
    return n.startswith("_") and not is_dunder


class FunctionSignatureExtractor(ast.NodeVisitor):
    """
    Collect public function and method definitions keyed by dotted qualified name.

    Function bodies are not descended into, so definitions nested inside a function are
    not collected.
    """

    def __init__(self) -> None:
        self.functions: dict[str, ast.FunctionDef | ast.AsyncFunctionDef] = {}
        self.stack: list[ast.ClassDef] = []  # enclosing classes, outermost first

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self.stack.append(node)
        self.generic_visit(node)
        self.stack.pop()

    def _record(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        """Register *node* unless it, or any class enclosing it, is private."""
        if _is_private(node.name) or any(_is_private(c.name) for c in self.stack):
            return

        names = [*(c.name for c in self.stack), node.name]
        self.functions[".".join(names)] = node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._record(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._record(node)


def get_changed_python_files(base_branch: str = "master") -> list[Path]:
    """
    Return the paths reported by `git diff --name-only <base_branch>...HEAD`.

    Despite the name, the result is not filtered by extension; `compare_signatures` does
    that filtering.
    """
    # In GitHub Actions PR context, we need to fetch the base branch first
    if is_github_actions():
        # Fetch the base branch to ensure we have it locally
        subprocess.check_call(
            ["git", "fetch", "origin", f"{base_branch}:{base_branch}"],
        )

    result = subprocess.check_output(
        ["git", "diff", "--name-only", f"{base_branch}...HEAD"], text=True
    )
    files = [s.strip() for s in result.splitlines()]
    return [Path(f) for f in files if f]


def parse_functions(content: str) -> dict[str, ast.FunctionDef | ast.AsyncFunctionDef]:
    """Parse *content* as Python source and return its public functions by qualified name."""
    tree = ast.parse(content)
    extractor = FunctionSignatureExtractor()
    extractor.visit(tree)
    return extractor.functions


def get_file_content_at_revision(file_path: Path, revision: str) -> str | None:
    """
    Return the content of *file_path* at *revision* via `git show`.

    Returns None (after printing a warning to stderr) when the git command fails.
    """
    try:
        # git expects forward slashes in `<rev>:<path>` regardless of the host OS.
        return subprocess.check_output(
            ["git", "show", f"{revision}:{file_path.as_posix()}"], text=True
        )
    except subprocess.CalledProcessError as e:
        print(f"Warning: Failed to get file content at revision: {e}", file=sys.stderr)
        return None


def compare_signatures(base_branch: str = "master") -> list[Error]:
    """
    Compare function signatures of changed files against *base_branch*.

    Only non-private `.py` files under the top-level `mlflow` directory that exist on both
    sides are inspected. One `Error` is produced per incompatible parameter change.
    """
    errors: list[Error] = []
    for file_path in get_changed_python_files(base_branch):
        # Ignore non-Python files
        if file_path.suffix != ".py":
            continue

        # Ignore files not in the mlflow directory
        if file_path.parts[0] != "mlflow":
            continue

        # Ignore private modules
        if any(part.startswith("_") and part != "__init__.py" for part in file_path.parts):
            continue

        base_content = get_file_content_at_revision(file_path, base_branch)
        if base_content is None:
            # File not found in the base branch, likely added in the current branch
            continue

        if not file_path.exists():
            # File not found, likely deleted in the current branch
            continue

        # Python source defaults to UTF-8; don't depend on the platform's locale encoding.
        current_content = file_path.read_text(encoding="utf-8")
        base_functions = parse_functions(base_content)
        current_functions = parse_functions(current_content)
        # Sorted so that the report order is the same on every run.
        for func_name in sorted(base_functions.keys() & current_functions.keys()):
            base_func = base_functions[func_name]
            current_func = current_functions[func_name]
            if param_errors := check_signature_compatibility(base_func, current_func):
                # Create individual errors for each problematic parameter
                errors.extend(
                    Error(
                        file_path=file_path,
                        line=param_error.lineno,
                        column=param_error.col_offset + 1,  # AST columns are 0-based
                        lines=[
                            "[Non-blocking | Ignore if not public API]",
                            param_error.message,
                            f"This change will break existing `{func_name}` calls.",
                            "If this is not intended, please fix it.",
                        ],
                    )
                    for param_error in param_errors
                )

    return errors


@dataclass
class Args:
    """Parsed command-line arguments."""

    base_branch: str


def parse_args() -> Args:
    """Parse command-line arguments; `--base-branch` defaults to `$GITHUB_BASE_REF` or master."""
    parser = argparse.ArgumentParser(
        description="Check for breaking changes in Python function signatures"
    )
    # GITHUB_BASE_REF may be set but empty outside pull-request events; treat that as unset.
    parser.add_argument("--base-branch", default=os.environ.get("GITHUB_BASE_REF") or "master")
    args = parser.parse_args()
    return Args(base_branch=args.base_branch)


def main() -> None:
    """Run the check and print each finding; the exit status is not affected by findings."""
    args = parse_args()
    errors = compare_signatures(args.base_branch)
    for error in errors:
        print(error.format(github=is_github_actions()))


if __name__ == "__main__":
    main()