/ docs-website / scripts / test_python_snippets.py
test_python_snippets.py
  1  #!/usr/bin/env python3
  2  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  3  #
  4  # SPDX-License-Identifier: Apache-2.0
  5  
  6  """
  7  Background tester for Python code snippets embedded in Docusaurus Markdown/MDX files.
  8  
  9  Features:
 10  - Recursively scans specified directories for .md and .mdx files
 11  - Extracts triple-backtick fenced blocks labeled with "python" or "py"
 12  - Skips blocks preceded by a "<!-- test-ignore -->" marker (blank lines and other markers may sit in between)
 13  - Supports markers above a block:
 14    - "<!-- test-run -->" to force running even if heuristically considered a concept
 15    - "<!-- test-concept -->" to force skipping as illustrative
 16    - "<!-- test-require-files: path1 path2 -->" to require files to exist (skip if missing)
 17  - Skips blocks containing unsafe patterns unless --allow-unsafe is passed
 18  - Executes each snippet in isolation via a temporary file using a Python subprocess
 19  - Times out long-running snippets
 20  - Emits GitHub Actions annotations for failures with file and line details
 21  - Summarizes results and sets a non-zero exit code on failures
 22  
 23  Usage:
 24    # Scan default trees
 25    python scripts/test_python_snippets.py --paths docs versioned_docs --timeout-seconds 30
 26  
 27    # Run a single file (positional target)
 28    python scripts/test_python_snippets.py docs/concepts/pipelines.mdx
 29  
 30    # Run multiple specific files (positional targets)
 31    python scripts/test_python_snippets.py docs/overview/intro.mdx docs/concepts/components.mdx
 32  
 33    # Force-run a snippet without imports via marker above the block
 34    <!-- test-run -->
 35    ```python
 36    print("hello world")
 37    ```
 38  
 39    # Mark an illustrative snippet to skip
 40    <!-- test-concept -->
 41    ```python
 42    @dataclass
 43    class Foo:
 44        ...
 45    ```
 46  
 47    # Require fixtures; snippet will be skipped if files are missing
 48    <!-- test-require-files: assets/dog.jpg data/example.json -->
 49    ```python
 50    from haystack.dataclasses import ByteStream
 51    image = ByteStream.from_file_path("assets/dog.jpg")
 52    ```
 53  """
 54  
 55  from __future__ import annotations
 56  
 57  import argparse
 58  import os
 59  import re
 60  import subprocess
 61  import sys
 62  import tempfile
 63  import textwrap
 64  from collections.abc import Iterable
 65  from dataclasses import dataclass
 66  from enum import Enum
 67  from pathlib import Path
 68  
 69  FENCE_START_RE = re.compile(r"^\s*```(?P<lang>[^\n\r]*)\s*$")
 70  FENCE_END_RE = re.compile(r"^\s*```\s*$")
 71  TEST_IGNORE_MARK = "<!-- test-ignore -->"
 72  TEST_CONCEPT_MARK = "<!-- test-concept -->"
 73  TEST_RUN_MARK = "<!-- test-run -->"
 74  TEST_REQUIRE_FILES_PREFIX = "<!-- test-require-files:"
 75  
 76  
 77  UNSAFE_PATTERNS = [
 78      # Basic patterns to avoid obviously unsafe operations in CI examples
 79      re.compile(r"\bos\.system\s*\("),
 80      re.compile(r"\bsubprocess\."),
 81      re.compile(r"\bshutil\.rmtree\s*\("),
 82      re.compile(r"\bPopen\s*\("),
 83      re.compile(r"rm\s+-rf\b"),
 84  ]
 85  
 86  
 87  SUPPORTED_EXTENSIONS = {".md", ".mdx"}
 88  
 89  
 90  @dataclass
 91  class Snippet:
 92      file_path: str
 93      """Absolute file path of the Markdown/MDX file."""
 94  
 95      relative_path: str
 96      """Path relative to the repository root, for nicer output and GH annotations."""
 97  
 98      snippet_index: int
 99      """Monotonic index of snippet within the file (1-based)."""
100  
101      start_line: int
102      """Line number (1-based) where the snippet's first code line appears."""
103  
104      code: str
105      """The code content of the snippet."""
106  
107      skipped_reason: str | None = None
108      forced_run: bool = False
109      forced_concept: bool = False
110      requires_files: list[str] | None = None  # paths relative to repo root
111  
112  
def find_markdown_files(paths: Iterable[str], *, extensions: Iterable[str] | None = None) -> list[str]:
    """Return sorted, de-duplicated absolute paths to Markdown/MDX files under the provided targets.

    Args:
        paths: Files or directories to scan. A file is kept only if its extension matches; a
            directory is walked recursively. Entries that do not exist are ignored.
        extensions: Extensions (including the leading dot) to accept. Defaults to
            ``SUPPORTED_EXTENSIONS``.

    Returns:
        Absolute paths in sorted order. Overlapping targets (e.g. a directory and a file inside
        it) yield each file only once, so no snippet is executed or counted twice.
    """
    wanted = SUPPORTED_EXTENSIONS if extensions is None else set(extensions)

    found: set[str] = set()
    for base in paths:
        if os.path.isfile(base):
            if os.path.splitext(base)[1] in wanted:
                found.add(os.path.abspath(base))
        elif os.path.isdir(base):
            for root, _dirs, filenames in os.walk(base):
                found.update(
                    os.path.abspath(os.path.join(root, name))
                    for name in filenames
                    if os.path.splitext(name)[1] in wanted
                )
    return sorted(found)
131  
132  
def _is_python_language_tag(tag: str) -> bool:
    """Return True if a fence info string (e.g. ``python title="x.py"`` or ``py``) declares Python."""
    words = tag.strip().lower().split()
    # Only the first word of the info string names the language; the rest is metadata.
    return bool(words) and words[0] in {"python", "py"}


def _collect_preceding_markers(lines: list[str], fence_index: int) -> list[str]:
    """Return the single-line HTML comments directly above ``lines[fence_index]``, nearest first.

    Blank lines are skipped; the scan stops at the first line that is neither blank nor a
    ``<!-- ... -->`` comment.
    """
    markers: list[str] = []
    for j in range(fence_index - 1, -1, -1):
        stripped = lines[j].strip()
        if not stripped:
            continue
        if stripped.startswith("<!--") and stripped.endswith("-->"):
            markers.append(stripped)
            continue
        break
    return markers


def _strip_fence_indent(line: str, indent: int) -> str:
    """Remove up to ``indent`` leading spaces from ``line`` (CommonMark fenced-code-block rule)."""
    leading = len(line) - len(line.lstrip(" "))
    return line[min(indent, leading) :]


def extract_python_snippets(file_path: str, repo_root: str) -> list[Snippet]:
    """Extract Python snippets from a Markdown/MDX file.

    Every fenced block whose info string starts with ``python`` or ``py`` becomes a Snippet,
    including blocks that will later be skipped. The HTML comment markers above the fence
    (blank lines allowed in between) set its ignore/run/concept flags and required files.

    When the opening fence is indented (e.g. inside a list item or a JSX component), up to that
    many leading spaces are removed from each code line, as CommonMark does. Otherwise the
    extracted code would carry a stray indent and fail with an IndentationError.

    Args:
        file_path: Absolute path of the Markdown/MDX file (read as UTF-8).
        repo_root: Directory used to compute ``Snippet.relative_path``.

    Returns:
        The snippets in document order, numbered from 1.
    """
    with open(file_path, encoding="utf-8") as f:
        lines = f.read().splitlines()

    relative_path = os.path.relpath(file_path, repo_root)
    snippets: list[Snippet] = []

    i = 0
    while i < len(lines):
        start_match = FENCE_START_RE.match(lines[i])
        if not start_match or not _is_python_language_tag(start_match.group("lang") or ""):
            i += 1
            continue

        fence_line_no = i + 1  # 1-based line number of the opening fence
        fence_indent = len(lines[i]) - len(lines[i].lstrip(" "))
        markers = _collect_preceding_markers(lines, i)

        requires_files: list[str] = []
        for marker in markers:
            if marker.startswith(TEST_REQUIRE_FILES_PREFIX):
                # Markers always end with "-->" (see _collect_preceding_markers); drop it.
                requires_files.extend(marker[len(TEST_REQUIRE_FILES_PREFIX) : -3].split())

        # Consume the body up to the closing fence (or end of file for an unclosed fence).
        block_lines: list[str] = []
        i += 1
        while i < len(lines) and not FENCE_END_RE.match(lines[i]):
            block_lines.append(_strip_fence_indent(lines[i], fence_indent))
            i += 1

        snippets.append(
            Snippet(
                file_path=file_path,
                relative_path=relative_path,
                snippet_index=len(snippets) + 1,
                start_line=fence_line_no + 1,
                code="\n".join(block_lines).rstrip("\n"),
                skipped_reason="test-ignore marker" if TEST_IGNORE_MARK in markers else None,
                forced_run=TEST_RUN_MARK in markers,
                forced_concept=TEST_CONCEPT_MARK in markers,
                requires_files=requires_files or None,
            )
        )

        i += 1  # Skip closing fence

    return snippets
216  
217  
def _should_skip_snippet(snippet: Snippet, repo_root: str, skip_unsafe: bool) -> ExecutionResult | None:
    """Decide whether ``snippet`` should be skipped instead of executed.

    Checks run in this order, and the first one that applies decides:
      1. an explicit skip reason on the snippet (e.g. from the test-ignore marker);
      2. the concept marker, unless the run marker overrides it;
      3. required files missing under ``repo_root``;
      4. the "no imports" heuristic, unless the run marker forces execution;
      5. when ``skip_unsafe`` is set, the unsafe-pattern scan.

    Returns:
        A SKIPPED ExecutionResult carrying the reason, or None when the snippet should run.
    """

    def skip(reason: str) -> ExecutionResult:
        return ExecutionResult(snippet=snippet, status=ExecutionStatus.SKIPPED, reason=reason)

    if snippet.skipped_reason:
        return skip(snippet.skipped_reason)

    if snippet.forced_concept and not snippet.forced_run:
        return skip("concept marker")

    missing = [
        rel for rel in (snippet.requires_files or []) if not os.path.exists(os.path.join(repo_root, rel))
    ]
    if missing:
        return skip(f"missing required files: {', '.join(missing)}")

    if not (snippet.forced_run or is_heuristically_runnable(snippet.code)):
        return skip("heuristic: no imports (concept)")

    unsafe = contains_unsafe_pattern(snippet.code) if skip_unsafe else None
    if unsafe:
        return skip(f"unsafe pattern: {unsafe}")
    return None
246  
247  
def contains_unsafe_pattern(code: str) -> str | None:
    """Return the source text of the first ``UNSAFE_PATTERNS`` regex found in ``code``, or None."""
    return next((regex.pattern for regex in UNSAFE_PATTERNS if regex.search(code)), None)
255  
256  
# A line that starts (after optional indentation) with "import x" or "from x import ...".
IMPORT_RE = re.compile(r"^\s*(?:from\s+\S+\s+import\s+|import\s+\S+)")


def is_heuristically_runnable(code: str) -> bool:
    """Guess whether ``code`` is a runnable example rather than an illustrative fragment.

    A snippet counts as runnable as soon as any of its lines is an import statement.
    """
    for line in code.splitlines():
        if IMPORT_RE.search(line):
            return True
    return False
264  
265  
class ExecutionStatus(Enum):
    """Outcome category of a single snippet run."""

    PASSED = "passed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class ExecutionResult:
    """Outcome of running (or skipping) one snippet, with any captured process output."""

    # The snippet this result describes.
    snippet: Snippet
    # Whether the snippet passed, failed, or was skipped.
    status: ExecutionStatus
    # Exit code of the snippet process; left as None for skips and timeouts.
    return_code: int | None = None
    # Captured standard output, if any.
    stdout: str | None = None
    # Captured standard error, if any.
    stderr: str | None = None
    # Human-readable explanation (skip reason or timeout notice), if any.
    reason: str | None = None
280  
281  
def _output_as_text(data: str | bytes | None) -> str | None:
    """Return captured subprocess output as ``str``, decoding bytes as UTF-8 with replacement."""
    if isinstance(data, bytes):
        return data.decode("utf-8", errors="replace")
    return data


def run_snippet(snippet: Snippet, timeout_seconds: int, cwd: str, skip_unsafe: bool) -> ExecutionResult:
    """Execute a single snippet in a fresh Python subprocess and return the outcome.

    Args:
        snippet: The snippet to run.
        timeout_seconds: Wall-clock limit for the subprocess; exceeding it counts as a failure.
        cwd: Working directory of the subprocess; also used as the repository root when
            checking the snippet's required files.
        skip_unsafe: Skip snippets that match one of the unsafe patterns.

    Returns:
        SKIPPED if the snippet should not run, PASSED on exit code 0, and FAILED on a non-zero
        exit code or a timeout.
    """
    skip_result = _should_skip_snippet(snippet, cwd, skip_unsafe)
    if skip_result is not None:
        return skip_result

    # Write to a temp file (with a stable, informative name in a dedicated temp dir) so that
    # tracebacks show a real file path.
    safe_rel = snippet.relative_path.replace(os.sep, "__")
    temp_dir = os.path.join(tempfile.gettempdir(), "doc_snippet_tests")
    os.makedirs(temp_dir, exist_ok=True)
    temp_path = os.path.join(temp_dir, f"{safe_rel}__snippet_{snippet.snippet_index}.py")

    # Header comments that map the temp file back to its documentation source.
    prelude = textwrap.dedent(
        f"""
        # File: {snippet.relative_path}
        # Snippet: {snippet.snippet_index}
        # Start line in source: {snippet.start_line}
        """
    ).lstrip("\n")
    # Pad with blank lines so the first code line lands on line `start_line` of the temp file.
    # Traceback line numbers then equal Markdown/MDX line numbers; without padding, every line
    # would be shifted by the header's length. A snippet starting within the first few lines of
    # a document cannot fit the full header and keeps a small offset.
    padding = "\n" * max(snippet.start_line - 1 - prelude.count("\n"), 0)

    with open(temp_path, "w", encoding="utf-8") as tf:
        tf.write(prelude)
        tf.write(padding)
        tf.write(snippet.code)
        tf.write("\n")

    try:
        completed = subprocess.run(
            [sys.executable, temp_path],
            check=False,
            cwd=cwd,
            capture_output=True,
            timeout=timeout_seconds,
            text=True,
            env={**os.environ, "PYTHONUNBUFFERED": "1"},
        )
    except subprocess.TimeoutExpired as exc:
        # Partial output captured before a timeout is bytes even though text=True was requested.
        stderr_text = (_output_as_text(exc.stderr) or "") + f"\n[timeout after {timeout_seconds}s]"
        return ExecutionResult(
            snippet=snippet,
            status=ExecutionStatus.FAILED,
            reason=f"timeout after {timeout_seconds}s",
            stdout=_output_as_text(exc.stdout),
            stderr=stderr_text,
        )

    if completed.returncode == 0:
        return ExecutionResult(snippet=snippet, status=ExecutionStatus.PASSED, return_code=0, stdout=completed.stdout)

    return ExecutionResult(
        snippet=snippet,
        status=ExecutionStatus.FAILED,
        return_code=completed.returncode,
        stdout=completed.stdout,
        stderr=completed.stderr,
    )
353  
354  
def _escape_gh_data(text: str) -> str:
    """Escape text for the message part of a GitHub Actions workflow command."""
    return text.replace("%", "%25").replace("\r", "%0D").replace("\n", "%0A")


def _escape_gh_property(text: str) -> str:
    """Escape text for a property value (such as ``file=``) of a GitHub Actions workflow command."""
    # Property values additionally reserve ":" and "," as separators.
    return _escape_gh_data(text).replace(":", "%3A").replace(",", "%2C")


def print_failure_annotation(result: ExecutionResult) -> None:
    """Print a GitHub Actions ``::error`` annotation so failures are clickable in CI logs.

    The annotation points at the snippet's first code line. Its details are the failure reason
    (if any), followed by the captured stderr, or the stdout when stderr is empty.
    """
    snippet = result.snippet
    stderr_text = result.stderr.strip() if result.stderr else ""
    stdout_text = result.stdout.strip() if result.stdout else ""
    details = stderr_text or stdout_text
    if result.reason:
        details = f"{result.reason}\n\n" + details

    message = f"Doc snippet #{snippet.snippet_index} failed"
    # "%0A" is an escaped newline separating the summary from the details inside the annotation.
    sys.stdout.write(
        f"::error file={_escape_gh_property(snippet.relative_path)},line={snippet.start_line}::"
        f"{message} — see details below%0A{_escape_gh_data(details)}\n"
    )
368  
369  
def process_file_snippets(
    file_rel: str, snippets: list[Snippet], repo_root: str, timeout_seconds: int, allow_unsafe: bool, verbose: bool
) -> tuple[list[ExecutionResult], dict[str, int]]:
    """Run every snippet of one Markdown file and report progress as it goes.

    Failures are always reported (GitHub annotation, summary line, captured output); passes and
    skips are only printed in verbose mode.

    Returns:
        The per-snippet results, and a stats dict with "total", "passed", "failed" and
        "skipped" counts for the file.
    """
    print(f"[RUN] {file_rel}" if verbose else f"Running {file_rel} ({len(snippets)} snippet(s))")

    passed = failed = skipped = 0
    results: list[ExecutionResult] = []

    for snippet in snippets:
        result = run_snippet(snippet, timeout_seconds=timeout_seconds, cwd=repo_root, skip_unsafe=not allow_unsafe)
        results.append(result)
        label = f"{snippet.relative_path}#snippet{snippet.snippet_index}"

        if result.status == ExecutionStatus.PASSED:
            passed += 1
            if verbose:
                print(f"[PASS] {label} (line {snippet.start_line})")
            continue

        if result.status == ExecutionStatus.SKIPPED:
            skipped += 1
            if verbose:
                suffix = f" — {result.reason}" if result.reason else ""
                print(f"[SKIP] {label}{suffix}")
            continue

        # Anything that neither passed nor was skipped is a failure.
        failed += 1
        print_failure_annotation(result)
        # Also print a concise human-readable failure line, then any captured output.
        print(
            f"FAILED {snippet.relative_path}:snippet{snippet.snippet_index} "
            f"(line {snippet.start_line}) — rc={result.return_code or 'N/A'}"
        )
        for title, stream in (("stdout", result.stdout), ("stderr", result.stderr)):
            if stream and stream.strip():
                print(f"--- {title} ---\n" + stream)

    stats = {"total": len(snippets), "passed": passed, "failed": failed, "skipped": skipped}
    return results, stats
410  
411  
def main(argv: list[str] | None = None) -> int:
    """CLI entry point for snippet execution.

    Parses ``argv`` (``sys.argv[1:]`` when None) and discovers Markdown/MDX files under the
    positional targets, or under ``--paths`` when no targets are given. It then runs every
    Python snippet found and prints per-file progress plus a summary.

    Returns:
        The process exit code: 1 if any snippet failed, otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Test Python code snippets in Docusaurus docs")
    parser.add_argument(
        "targets",
        nargs="*",
        help="Optional positional list of files or directories to scan. If omitted, --paths is used.",
    )
    parser.add_argument(
        "--paths",
        nargs="+",
        default=["docs", "versioned_docs"],
        help=(
            "Fallback directories or files to scan when no positional targets are provided "
            "(defaults to docs and versioned_docs)"
        ),
    )
    parser.add_argument("--timeout-seconds", type=int, default=600, help="Timeout per snippet execution (seconds)")
    parser.add_argument(
        "--allow-unsafe", action="store_true", help="Allow execution of snippets with potentially unsafe patterns"
    )
    parser.add_argument("--verbose", action="store_true", help="Print verbose logs")

    args = parser.parse_args(argv)
    # Repository root: three directory levels above this file (it lives in <root>/docs-website/scripts/).
    # NOTE(review): relative targets are resolved against this root, not the current working
    # directory, while the usage examples invoke the script as "scripts/..." — confirm how CI calls it.
    repo_root = str(Path(__file__).parent.parent.parent)
    raw_paths = args.targets if args.targets else args.paths
    scan_paths = [os.path.join(repo_root, p) if not os.path.isabs(p) else p for p in raw_paths]

    # find_markdown_files() silently ignores targets that do not exist. Call them out so a
    # mistyped path does not turn into an empty run that still reports success.
    for raw, resolved in zip(raw_paths, scan_paths):
        if not os.path.exists(resolved):
            print(f"Warning: scan target not found: {raw} (looked for {resolved})")

    md_files = find_markdown_files(scan_paths)
    if args.verbose:
        print(f"Repo root: {repo_root}")
        print(f"Scanning targets: {', '.join(raw_paths)}")
    print(f"Discovered {len(md_files)} Markdown files")

    all_snippets: list[Snippet] = []
    for idx, fpath in enumerate(md_files, start=1):
        rel = os.path.relpath(fpath, repo_root)
        if args.verbose:
            print(f"[SCAN {idx}/{len(md_files)}] {rel}")
        snippets = extract_python_snippets(fpath, repo_root)
        all_snippets.extend(snippets)
        if args.verbose:
            print(f"[FOUND] {rel}: {len(snippets)} python snippet(s)")

    if args.verbose:
        print(f"Extracted {len(all_snippets)} Python snippets")
    else:
        print(f"Total Python snippets found: {len(all_snippets)}")

    total = len(all_snippets)

    # Ensure deterministic execution order grouped by file, then line
    all_snippets.sort(key=lambda s: (s.relative_path, s.start_line, s.snippet_index))

    # Group by file
    file_to_snippets: dict[str, list[Snippet]] = {}
    for sn in all_snippets:
        file_to_snippets.setdefault(sn.relative_path, []).append(sn)

    passed = failed = skipped = 0
    file_stats: dict[str, dict[str, int]] = {}
    for file_rel, snippets in file_to_snippets.items():
        _file_results, stats = process_file_snippets(
            file_rel=file_rel,
            snippets=snippets,
            repo_root=repo_root,
            timeout_seconds=args.timeout_seconds,
            allow_unsafe=args.allow_unsafe,
            verbose=args.verbose,
        )
        file_stats[file_rel] = stats
        passed += stats["passed"]
        failed += stats["failed"]
        skipped += stats["skipped"]

    print(f"Summary: total={total}, passed={passed}, failed={failed}, skipped={skipped}")

    # Per-file summary: only files with failures by default, every file in verbose mode.
    print("Files summary:")
    for file_rel in sorted(file_stats.keys()):
        fs = file_stats[file_rel]
        if fs["failed"] > 0 or args.verbose:
            print(
                f" - {file_rel}: total={fs['total']}, passed={fs['passed']}, "
                f"failed={fs['failed']}, skipped={fs['skipped']}"
            )

    return 1 if failed > 0 else 0
511  
512  
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (1 when any snippet failed, else 0).
    raise SystemExit(main())