# lint.py
"""
Lightweight hook for validating code written by Claude Code.
"""

import ast
import json
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

KILL_SWITCH_ENV_VAR = "CLAUDE_LINT_HOOK_DISABLED"

# Unified-diff hunk header, e.g. "@@ -12,3 +14,2 @@ def foo():".
# Group 1 is the first line of the hunk in the new file; group 2 is the number
# of lines the hunk occupies in the new file (omitted by git when it is 1).
# Anchoring on the full header keeps any trailing section-heading text from
# being mistaken for the "+start,count" part.
_HUNK_HEADER_RE = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@")


@dataclass
class LintError:
    """A single lint finding, rendered as ``file:line:column: message``."""

    file: Path
    line: int
    column: int
    message: str

    def __str__(self) -> str:
        return f"{self.file}:{self.line}:{self.column}: {self.message}"


@dataclass
class DiffRange:
    """A range of line numbers; both ``start`` and ``end`` are inclusive."""

    start: int
    end: int

    def overlaps(self, start: int, end: int) -> bool:
        """Return True if the inclusive span ``start..end`` intersects this range."""
        return start <= self.end and self.start <= end


def parse_diff_ranges(diff_output: str) -> list[DiffRange]:
    """Parse unified diff output and extract added line ranges.

    Only hunk headers (``@@ -a,b +c,d @@``) are inspected. Each header with a
    non-zero new-file line count ``d`` yields the inclusive range
    ``c .. c + d - 1``. Hunks that only delete lines (``d == 0``) add nothing
    to the new file and are skipped.

    Args:
        diff_output: Text of a unified diff (e.g. from ``git diff -U0``).

    Returns:
        The added/changed line ranges in the new file, in diff order.
    """
    ranges: list[DiffRange] = []
    for line in diff_output.splitlines():
        match = _HUNK_HEADER_RE.match(line)
        if match is None:
            continue
        start = int(match.group(1))
        # A missing count means the hunk covers exactly one line.
        count = int(match.group(2)) if match.group(2) is not None else 1
        if count == 0:
            # Pure deletion: no line of the new file belongs to this hunk.
            continue
        # DiffRange.end is inclusive, hence the "- 1".
        ranges.append(DiffRange(start=start, end=start + count - 1))
    return ranges


def overlaps_with_diff(node: ast.AST, ranges: list[DiffRange]) -> bool:
    """Return True if ``node``'s line span intersects any of ``ranges``.

    ``node`` must carry ``lineno``/``end_lineno`` (statement and expression
    nodes do); when ``end_lineno`` is unset the node is treated as one line.
    """
    return any(r.overlaps(node.lineno, node.end_lineno or node.lineno) for r in ranges)


class Visitor(ast.NodeVisitor):
    """AST visitor that accumulates lint errors in ``self.errors``.

    The function visitors below currently only recurse into child nodes; no
    check in this class appends to ``errors`` yet.
    """

    def __init__(self, file_path: Path, diff_ranges: list[DiffRange]) -> None:
        self.file_path = file_path
        self.diff_ranges = diff_ranges
        self.errors: list[LintError] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self.generic_visit(node)


def lint(file_path: Path, source: str, diff_ranges: list[DiffRange]) -> list[LintError]:
    """Parse ``source`` and run ``Visitor`` over it.

    Args:
        file_path: Path used in error messages and as the parse filename.
        source: Python source text to check.
        diff_ranges: Line ranges handed to the visitor.

    Returns:
        The visitor's errors, or a single "Failed to parse" error (reported at
        line 0, column 0) when the source cannot be parsed.
    """
    try:
        tree = ast.parse(source, filename=str(file_path))
    except (SyntaxError, ValueError) as e:
        # ValueError: older Python versions raise it (instead of SyntaxError)
        # when the source contains null bytes.
        return [LintError(file=file_path, line=0, column=0, message=f"Failed to parse: {e}")]

    visitor = Visitor(file_path=file_path, diff_ranges=diff_ranges)
    visitor.visit(tree)
    return visitor.errors


def is_test_file(path: Path) -> bool:
    """Return True for ``test_*`` files whose first path component is ``tests``.

    A path with no components (e.g. ``Path("")``) is not a test file.
    """
    parts = path.parts
    return bool(parts) and parts[0] == "tests" and path.name.startswith("test_")


@dataclass
class HookInput:
    """The subset of the hook payload this script needs."""

    tool_name: Literal["Edit", "Write"]
    file_path: Path

    @classmethod
    def parse(cls) -> "HookInput | None":
        """Read the hook payload (JSON) from stdin.

        Returns:
            A ``HookInput``, or ``None`` when the payload is not for an
            ``Edit``/``Write`` tool call, carries no ``file_path``, or names a
            file that is not located under ``CLAUDE_PROJECT_DIR``.

        Raises:
            json.JSONDecodeError: If stdin does not contain valid JSON.
        """
        # https://code.claude.com/docs/en/hooks#posttooluse-input
        data = json.loads(sys.stdin.read())
        tool_name = data.get("tool_name")
        if tool_name not in ("Edit", "Write"):
            return None
        # Tolerate a missing or null "tool_input" instead of crashing on .get().
        tool_input = data.get("tool_input") or {}
        file_path_str = tool_input.get("file_path")
        if not file_path_str:
            return None
        file_path = Path(file_path_str)
        if project_dir := os.environ.get("CLAUDE_PROJECT_DIR"):
            try:
                file_path = file_path.relative_to(project_dir)
            except ValueError:
                # The file is not under the project directory, so it cannot be
                # one of the project's test files; skip it rather than crash.
                return None
        return cls(
            tool_name=tool_name,
            file_path=file_path,
        )


def is_tracked(file_path: Path) -> bool:
    """Return True if ``git ls-files --error-unmatch`` succeeds for ``file_path``."""
    result = subprocess.run(["git", "ls-files", "--error-unmatch", file_path], capture_output=True)
    return result.returncode == 0


def get_source_and_diff_ranges(hook_input: HookInput) -> tuple[str, list[DiffRange]]:
    """Return the file's current source and the line ranges to lint.

    Raises:
        subprocess.CalledProcessError: If ``git diff`` exits with an error.
        OSError: If the file cannot be read.
    """
    if hook_input.tool_name == "Edit" and is_tracked(hook_input.file_path):
        # For Edit on tracked files, use git diff to get only changed lines
        diff_output = subprocess.check_output(
            ["git", "--no-pager", "diff", "-U0", "HEAD", "--", hook_input.file_path],
            text=True,
        )
        diff_ranges = parse_diff_ranges(diff_output)
    else:
        # For Write or Edit on untracked files, lint the whole file
        diff_ranges = [DiffRange(start=1, end=sys.maxsize)]
    # Python source defaults to UTF-8; do not depend on the locale encoding.
    source = hook_input.file_path.read_text(encoding="utf-8")
    return source, diff_ranges


def main() -> int:
    """Run the hook and return its exit code.

    Returns:
        0 when the hook is disabled, the payload is not applicable, or no
        lint errors were found; 2 when lint errors were written to stderr.
    """
    # Kill switch: disable hook if environment variable is set
    if os.environ.get(KILL_SWITCH_ENV_VAR):
        return 0

    hook_input = HookInput.parse()
    if not hook_input:
        return 0

    # Ignore non-Python files
    if hook_input.file_path.suffix != ".py":
        return 0

    # Ignore non-test files
    if not is_test_file(hook_input.file_path):
        return 0

    source, diff_ranges = get_source_and_diff_ranges(hook_input)
    if errors := lint(hook_input.file_path, source, diff_ranges):
        error_details = "\n".join(f" - {error}" for error in errors)
        reason = (
            f"Lint errors found:\n{error_details}\n\n"
            f"To disable this hook, set {KILL_SWITCH_ENV_VAR}=1"
        )
        # Exit code 2 = blocking error. stderr is fed back to Claude.
        # See: https://code.claude.com/docs/en/hooks#hook-output
        sys.stderr.write(reason + "\n")
        return 2

    return 0


if __name__ == "__main__":
    sys.exit(main())