/ scripts / ruff_format_docs.py
ruff_format_docs.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  """
  6  Pre-commit hook that runs ruff format on Python code blocks in Markdown/MDX files.
  7  
  8  Uses the ruff configuration from pyproject.toml automatically.
  9  """
 10  
 11  import argparse
 12  import logging
 13  import os
 14  import re
 15  import subprocess
 16  import sys
 17  import tempfile
 18  
 19  logger = logging.getLogger(__name__)
 20  logger.setLevel(logging.INFO)
 21  handler = logging.StreamHandler(sys.stderr)
 22  handler.setFormatter(logging.Formatter("%(message)s"))
 23  logger.addHandler(handler)
 24  
 25  # ANSI color codes (disabled when stderr is not a terminal)
 26  _USE_COLOR = hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
 27  
 28  PYTHON_FENCE_RE = re.compile(
 29      r"(?P<before>^```python\s*\n)"
 30      r"(?P<code>.*?)"
 31      r"(?P<after>^```\s*$)",
 32      re.MULTILINE | re.DOTALL,
 33  )
 34  
 35  
 36  def _color(code: str, text: str) -> str:
 37      """Colorize the text"""
 38      if _USE_COLOR:
 39          return f"\033[{code}m{text}\033[0m"
 40      return text
 41  
 42  
 43  def _find_tool(name: str) -> str:
 44      """Find a tool installed in the same virtualenv as the running Python."""
 45      return os.path.join(os.path.dirname(sys.executable), name)
 46  
 47  
 48  def _ruff(code: str, *, line_length: int) -> str:
 49      return subprocess.run(
 50          [
 51              _find_tool("ruff"),
 52              "format",
 53              f"--line-length={line_length}",
 54              "--config",
 55              "format.skip-magic-trailing-comma = false",
 56              "--stdin-filename",
 57              "block.py",
 58              "-",
 59          ],
 60          input=code,
 61          capture_output=True,
 62          text=True,
 63          check=True,
 64      ).stdout
 65  
 66  
 67  def _add_trailing_commas(code: str) -> str:
 68      """Add trailing commas to multi-line expressions using add-trailing-comma."""
 69      with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
 70          f.write(code)
 71          tmpfile = f.name
 72      try:
 73          subprocess.run([_find_tool("add-trailing-comma"), tmpfile], capture_output=True, check=False)
 74          with open(tmpfile) as f:
 75              return f.read()
 76      finally:
 77          os.unlink(tmpfile)
 78  
 79  
 80  def _format_code_block(match: re.Match, *, line_length: int, path: str) -> str:
 81      """Format a single code block"""
 82      code = match.group("code")
 83      try:
 84          # 1. ruff format (may create new multi-line expressions)
 85          # 2. add trailing commas to all multi-line expressions
 86          # 3. ruff format again (respects trailing commas, ensures stable output)
 87          formatted = _ruff(_add_trailing_commas(_ruff(code, line_length=line_length)), line_length=line_length)
 88      except subprocess.CalledProcessError as exc:
 89          snippet = code.strip().splitlines()
 90          preview = "\n".join(snippet[:5])
 91          if len(snippet) > 5:
 92              preview += f"\n... ({len(snippet) - 5} more lines)"
 93          logger.warning(
 94              "%s %s\n%s\n%s %s",
 95              _color("33", "WARNING:"),
 96              _color("1", path),
 97              _color("2", preview),
 98              _color("31", "ruff stderr:"),
 99              exc.stderr.strip(),
100          )
101          return match.group(0)
102      return match.group("before") + formatted + match.group("after")
103  
104  
def main() -> int:
    """Format the Python code blocks of each given Markdown/MDX file in place.

    Command-line arguments:
        --line-length: Line length passed to ruff (default 120).
        -v/--verbose: Also log each file that gets rewritten.
        files: Paths of the documents to process.

    Returns:
        1 if any file was rewritten (so the hook reports a change), else 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--line-length", type=int, default=120)
    parser.add_argument("-v", "--verbose", action="store_true", help="log every file that is rewritten")
    parser.add_argument("files", nargs="*")
    args = parser.parse_args()
    if args.verbose:
        # The "Rewriting:" message below is logged at DEBUG, which the module
        # logger's INFO level would otherwise suppress.
        logger.setLevel(logging.DEBUG)

    ret = 0
    for path in args.files:
        # Read/write the documents as UTF-8 explicitly instead of the locale encoding.
        with open(path, encoding="utf-8") as f:
            original = f.read()
        new = PYTHON_FENCE_RE.sub(
            lambda m, path=path: _format_code_block(m, line_length=args.line_length, path=path),
            original,
        )
        if new != original:
            with open(path, "w", encoding="utf-8") as f:
                f.write(new)
            logger.debug("%s %s", _color("32", "Rewriting:"), _color("1", path))
            ret = 1
    return ret
125  
126  
if __name__ == "__main__":
    # Exit status is main()'s return value: 1 when any file was rewritten.
    sys.exit(main())