# .claude/hooks/lint.py
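"""
Lightweight Claude Code PostToolUse hook that lints Python test files
(``tests/**/test_*.py``) after Claude Code writes or edits them.

Lint errors are written to stderr and the hook exits with code 2, which blocks
and feeds the errors back to Claude. Set CLAUDE_LINT_HOOK_DISABLED=1 to
disable the hook.
"""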
  4  
import ast
import json
import os
import re
import subprocess
import sys
import tokenize
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, TextIO
 14  
 15  KILL_SWITCH_ENV_VAR = "CLAUDE_LINT_HOOK_DISABLED"
 16  
 17  
 18  @dataclass
 19  class LintError:
 20      file: Path
 21      line: int
 22      column: int
 23      message: str
 24  
 25      def __str__(self) -> str:
 26          return f"{self.file}:{self.line}:{self.column}: {self.message}"
 27  
 28  
 29  @dataclass
 30  class DiffRange:
 31      start: int
 32      end: int
 33  
 34      def overlaps(self, start: int, end: int) -> bool:
 35          return start <= self.end and self.start <= end
 36  
 37  
 38  def parse_diff_ranges(diff_output: str) -> list[DiffRange]:
 39      """Parse unified diff output and extract added line ranges."""
 40      ranges: list[DiffRange] = []
 41      for line in diff_output.splitlines():
 42          if line.startswith("@@ "):
 43              if match := re.search(r"\+(\d+)(?:,(\d+))?", line):
 44                  start = int(match.group(1))
 45                  count = int(match.group(2)) if match.group(2) else 1
 46                  ranges.append(DiffRange(start=start, end=start + count))
 47      return ranges
 48  
 49  
 50  def overlaps_with_diff(node: ast.AST, ranges: list[DiffRange]) -> bool:
 51      return any(r.overlaps(node.lineno, node.end_lineno or node.lineno) for r in ranges)
 52  
 53  
 54  class Visitor(ast.NodeVisitor):
 55      def __init__(self, file_path: Path, diff_ranges: list[DiffRange]) -> None:
 56          self.file_path = file_path
 57          self.diff_ranges = diff_ranges
 58          self.errors: list[LintError] = []
 59  
 60      def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
 61          self.generic_visit(node)
 62  
 63      def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
 64          self.generic_visit(node)
 65  
 66  
 67  def lint(file_path: Path, source: str, diff_ranges: list[DiffRange]) -> list[LintError]:
 68      try:
 69          tree = ast.parse(source, filename=str(file_path))
 70      except SyntaxError as e:
 71          return [LintError(file=file_path, line=0, column=0, message=f"Failed to parse: {e}")]
 72  
 73      visitor = Visitor(file_path=file_path, diff_ranges=diff_ranges)
 74      visitor.visit(tree)
 75      return visitor.errors
 76  
 77  
 78  def is_test_file(path: Path) -> bool:
 79      return path.parts[0] == "tests" and path.name.startswith("test_")
 80  
 81  
 82  @dataclass
 83  class HookInput:
 84      tool_name: Literal["Edit", "Write"]
 85      file_path: Path
 86  
 87      @classmethod
 88      def parse(cls) -> "HookInput | None":
 89          # https://code.claude.com/docs/en/hooks#posttooluse-input
 90          data = json.loads(sys.stdin.read())
 91          tool_name = data.get("tool_name")
 92          tool_input = data.get("tool_input")
 93          if tool_name not in ("Edit", "Write"):
 94              return None
 95          file_path_str = tool_input.get("file_path")
 96          if not file_path_str:
 97              return None
 98          file_path = Path(file_path_str)
 99          if project_dir := os.environ.get("CLAUDE_PROJECT_DIR"):
100              file_path = file_path.relative_to(project_dir)
101          return cls(
102              tool_name=tool_name,
103              file_path=file_path,
104          )
105  
106  
def is_tracked(file_path: Path) -> bool:
    """Return True if git tracks *file_path* in the current working directory's repo.

    Runs ``git ls-files --error-unmatch``. Any non-zero exit (untracked file,
    not inside a git repository) yields False. A missing ``git`` executable
    also yields False, so callers fall back to linting the whole file.
    """
    try:
        result = subprocess.run(
            # "--" stops a path beginning with "-" from being parsed as an option.
            ["git", "ls-files", "--error-unmatch", "--", file_path],
            capture_output=True,
        )
    except FileNotFoundError:
        # git is not installed / not on PATH: treat the file as untracked.
        return False
    return result.returncode == 0
110  
111  
def get_source_and_diff_ranges(hook_input: HookInput) -> tuple[str, list[DiffRange]]:
    """Read the edited file and decide which of its lines should be linted.

    For an ``Edit`` of a git-tracked file, the ranges are the lines added
    relative to ``HEAD`` according to ``git diff -U0``. This covers every
    uncommitted change to the file, not only the latest edit. In all other
    cases a single range covering the whole file is used: for ``Write``, for
    untracked files, and when that diff cannot be produced (e.g. a repository
    with no commits yet).

    Returns:
        ``(source, diff_ranges)``, where ``source`` is the decoded file text.
    """
    whole_file = [DiffRange(start=1, end=sys.maxsize)]
    if hook_input.tool_name == "Edit" and is_tracked(hook_input.file_path):
        try:
            diff_output = subprocess.check_output(
                ["git", "--no-pager", "diff", "-U0", "HEAD", "--", hook_input.file_path],
                # Only the ASCII "@@" headers are parsed, so undecodable bytes in
                # changed content lines must not crash the hook.
                encoding="utf-8",
                errors="replace",
                stderr=subprocess.DEVNULL,
            )
        except subprocess.CalledProcessError:
            # E.g. no HEAD yet: lint the whole file rather than crash.
            diff_ranges = whole_file
        else:
            diff_ranges = parse_diff_ranges(diff_output)
    else:
        # For Write or Edit on untracked files, lint the whole file.
        diff_ranges = whole_file
    # tokenize.open decodes using the file's PEP 263 coding cookie / BOM
    # (UTF-8 by default) instead of the platform's locale encoding.
    with tokenize.open(hook_input.file_path) as source_file:
        source = source_file.read()
    return source, diff_ranges
125  
126  
def main() -> int:
    """Run the hook and return the process exit code.

    Returns 0 (allow) in any of these cases: the kill switch is set, the
    payload is not an Edit/Write naming a file, the file is not a Python test
    file, or linting finds nothing. Otherwise it writes a lint report to
    stderr and returns 2.
    """
    # The kill switch is checked first, so a disabled hook never reads stdin.
    if os.environ.get(KILL_SWITCH_ENV_VAR):
        return 0

    hook_input = HookInput.parse()
    if hook_input is None:
        return 0

    target = hook_input.file_path
    # Only Python test files are linted; `or` short-circuits, so is_test_file
    # is consulted only for .py files.
    if target.suffix != ".py" or not is_test_file(target):
        return 0

    source, diff_ranges = get_source_and_diff_ranges(hook_input)
    errors = lint(target, source, diff_ranges)
    if not errors:
        return 0

    bullets = "\n".join(f"  - {error}" for error in errors)
    report = (
        "Lint errors found:\n"
        f"{bullets}\n\n"
        f"To disable this hook, set {KILL_SWITCH_ENV_VAR}=1"
    )
    # Exit code 2 = blocking error. stderr is fed back to Claude.
    # See: https://code.claude.com/docs/en/hooks#hook-output
    sys.stderr.write(f"{report}\n")
    return 2
157  
158  
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())