test_python_snippets.py
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Background tester for Python code snippets embedded in Docusaurus Markdown/MDX files.

Features:
- Recursively scans specified directories for .md and .mdx files
- Extracts triple-backtick fenced blocks labeled with "python" or "py" (indented fences, e.g. inside list
  items, are dedented before execution)
- Skips blocks preceded by a "<!-- test-ignore -->" marker (blank lines and other HTML-comment markers may
  sit between the marker and the fence)
- Supports markers above a block:
  - "<!-- test-run -->" to force running even if heuristically considered a concept
  - "<!-- test-concept -->" to force skipping as illustrative
  - "<!-- test-require-files: path1 path2 -->" to require files to exist (skip if missing)
- Optionally skips blocks containing unsafe patterns
- Executes each snippet in isolation via a temporary file using a Python subprocess
- Times out long-running snippets
- Emits GitHub Actions annotations for failures with file and line details
- Summarizes results and sets a non-zero exit code on failures

Usage:
    # Scan default trees
    python scripts/test_python_snippets.py --paths docs versioned_docs --timeout-seconds 30

    # Run a single file (positional target)
    python scripts/test_python_snippets.py docs/concepts/pipelines.mdx

    # Run multiple specific files (positional targets)
    python scripts/test_python_snippets.py docs/overview/intro.mdx docs/concepts/components.mdx

    # Force-run a snippet without imports via marker above the block
    <!-- test-run -->
    ```python
    print("hello world")
    ```

    # Mark an illustrative snippet to skip
    <!-- test-concept -->
    ```python
    @dataclass
    class Foo:
        ...
    ```

    # Require fixtures; snippet will be skipped if files are missing
    <!-- test-require-files: assets/dog.jpg data/example.json -->
    ```python
    from haystack.dataclasses import ByteStream
    image = ByteStream.from_file_path("assets/dog.jpg")
    ```
"""

from __future__ import annotations

import argparse
import os
import re
import subprocess
import sys
import tempfile
import textwrap
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from pathlib import Path

FENCE_START_RE = re.compile(r"^\s*```(?P<lang>[^\n\r]*)\s*$")
FENCE_END_RE = re.compile(r"^\s*```\s*$")
TEST_IGNORE_MARK = "<!-- test-ignore -->"
TEST_CONCEPT_MARK = "<!-- test-concept -->"
TEST_RUN_MARK = "<!-- test-run -->"
TEST_REQUIRE_FILES_PREFIX = "<!-- test-require-files:"


UNSAFE_PATTERNS = [
    # Basic patterns to avoid obviously unsafe operations in CI examples
    re.compile(r"\bos\.system\s*\("),
    re.compile(r"\bsubprocess\."),
    re.compile(r"\bshutil\.rmtree\s*\("),
    re.compile(r"\bPopen\s*\("),
    re.compile(r"rm\s+-rf\b"),
]


SUPPORTED_EXTENSIONS = {".md", ".mdx"}

# A line that starts (after optional indentation) with an import statement.
IMPORT_RE = re.compile(r"^\s*(?:from\s+\S+\s+import\s+|import\s+\S+)")


@dataclass
class Snippet:
    """A single fenced Python code block extracted from a Markdown/MDX file, plus its test markers."""

    file_path: str
    """Absolute file path of the Markdown/MDX file."""

    relative_path: str
    """Path relative to the repository root, for nicer output and GH annotations."""

    snippet_index: int
    """Monotonic index of snippet within the file (1-based)."""

    start_line: int
    """Line number (1-based) where the snippet's first code line appears."""

    code: str
    """The code content of the snippet (dedented if the fence was indented)."""

    skipped_reason: str | None = None
    """Reason the snippet must always be skipped (set by a test-ignore marker), if any."""

    forced_run: bool = False
    """True when a test-run marker forces execution regardless of the import heuristic or a concept marker."""

    forced_concept: bool = False
    """True when a test-concept marker flags the snippet as illustrative."""

    requires_files: list[str] | None = None
    """Paths (relative to the repo root) that must exist for the snippet to run; None when there are none."""


class ExecutionStatus(Enum):
    """Outcome category of a single snippet run."""

    PASSED = "passed"
    FAILED = "failed"
    SKIPPED = "skipped"


@dataclass
class ExecutionResult:
    """Outcome of running (or skipping) one snippet, with any captured subprocess output."""

    snippet: Snippet
    status: ExecutionStatus
    return_code: int | None = None
    stdout: str | None = None
    stderr: str | None = None
    reason: str | None = None


def find_markdown_files(paths: Iterable[str]) -> list[str]:
    """
    Return sorted, de-duplicated absolute paths to Markdown/MDX files under the provided targets.

    Each target may be a file (kept if its extension is supported) or a directory (walked recursively).
    Targets that do not exist are ignored. Overlapping targets (e.g. ``docs`` and ``docs/intro.mdx``) would
    otherwise list the same file twice and cause its snippets to run twice, so results are collected in a set.
    """
    found: set[str] = set()
    for base in paths:
        if os.path.isfile(base):
            if os.path.splitext(base)[1] in SUPPORTED_EXTENSIONS:
                found.add(os.path.abspath(base))
        elif os.path.isdir(base):
            for root, _dirs, filenames in os.walk(base):
                found.update(
                    os.path.abspath(os.path.join(root, name))
                    for name in filenames
                    if os.path.splitext(name)[1] in SUPPORTED_EXTENSIONS
                )
    return sorted(found)


def _is_python_language_tag(tag: str) -> bool:
    """Return True if a fence info string (e.g. ``python``, ``py``, ``python title="x.py"``) denotes Python."""
    words = tag.strip().lower().split()
    return bool(words) and words[0] in {"python", "py"}


def _collect_markers_above(lines: list[str], fence_index: int) -> list[str]:
    """
    Return the single-line HTML-comment markers directly above ``lines[fence_index]``, nearest first.

    Walks upwards, skipping blank lines and collecting ``<!-- ... -->`` lines, and stops at the first line that is
    neither.
    """
    markers: list[str] = []
    for j in range(fence_index - 1, -1, -1):
        prev = lines[j].strip()
        if not prev:
            continue
        if not (prev.startswith("<!--") and prev.endswith("-->")):
            break
        markers.append(prev)
    return markers


def extract_python_snippets(file_path: str, repo_root: str) -> list[Snippet]:
    """
    Extract fenced Python snippets, together with their test markers, from a Markdown/MDX file.

    :param file_path: Absolute path of the Markdown/MDX file to parse.
    :param repo_root: Repository root, used to compute ``Snippet.relative_path``.
    :returns: One ``Snippet`` per ``python``/``py`` fenced block, in file order. Whether a snippet actually runs is
        decided later by ``run_snippet``.
    """
    with open(file_path, encoding="utf-8") as f:
        lines = f.read().splitlines()

    relative_path = os.path.relpath(file_path, repo_root)
    snippets: list[Snippet] = []

    i = 0
    while i < len(lines):
        start_match = FENCE_START_RE.match(lines[i])
        if not start_match or not _is_python_language_tag(start_match.group("lang") or ""):
            i += 1
            continue

        fence_line_no = i + 1  # 1-based line number of the opening fence
        markers = _collect_markers_above(lines, i)

        # Gather the body up to (not including) the closing fence; an unterminated block runs to end of file.
        block_lines: list[str] = []
        i += 1
        while i < len(lines) and not FENCE_END_RE.match(lines[i]):
            block_lines.append(lines[i])
            i += 1
        i += 1  # Skip closing fence

        requires_files: list[str] = []
        for marker in markers:
            if marker.startswith(TEST_REQUIRE_FILES_PREFIX) and marker.endswith("-->"):
                requires_files.extend(marker[len(TEST_REQUIRE_FILES_PREFIX) : -3].split())

        snippets.append(
            Snippet(
                file_path=file_path,
                relative_path=relative_path,
                snippet_index=len(snippets) + 1,
                start_line=fence_line_no + 1,
                # Fences may be indented (e.g. inside list items or MDX components); the body carries that
                # indentation, which would make it fail with IndentationError when run as a module.
                code=textwrap.dedent("\n".join(block_lines)).rstrip("\n"),
                skipped_reason="test-ignore marker" if TEST_IGNORE_MARK in markers else None,
                forced_run=TEST_RUN_MARK in markers,
                forced_concept=TEST_CONCEPT_MARK in markers,
                requires_files=requires_files or None,
            )
        )

    return snippets


def _should_skip_snippet(snippet: Snippet, repo_root: str, skip_unsafe: bool) -> ExecutionResult | None:
    """
    Return a SKIPPED ``ExecutionResult`` if the snippet must not run, or ``None`` if it is runnable.

    Checks, in order: ignore marker; concept marker (unless a run marker overrides it); missing required files;
    the import heuristic (unless a run marker overrides it); unsafe patterns (only when ``skip_unsafe``).
    """

    def skipped(reason: str) -> ExecutionResult:
        return ExecutionResult(snippet=snippet, status=ExecutionStatus.SKIPPED, reason=reason)

    if snippet.skipped_reason:
        return skipped(snippet.skipped_reason)

    if snippet.forced_concept and not snippet.forced_run:
        return skipped("concept marker")

    missing = [p for p in snippet.requires_files or [] if not os.path.exists(os.path.join(repo_root, p))]
    if missing:
        return skipped(f"missing required files: {', '.join(missing)}")

    if not snippet.forced_run and not is_heuristically_runnable(snippet.code):
        return skipped("heuristic: no imports (concept)")

    if skip_unsafe:
        unsafe = contains_unsafe_pattern(snippet.code)
        if unsafe:
            return skipped(f"unsafe pattern: {unsafe}")

    return None


def contains_unsafe_pattern(code: str) -> str | None:
    """Return the source of the first ``UNSAFE_PATTERNS`` regex found in ``code``, or None."""
    return next((pat.pattern for pat in UNSAFE_PATTERNS if pat.search(code)), None)


def is_heuristically_runnable(code: str) -> bool:
    """Heuristic: treat a snippet as runnable if any of its lines starts with an import statement."""
    return any(IMPORT_RE.search(line) for line in code.splitlines())


def _build_prelude(snippet: Snippet) -> str:
    """
    Return the comment header written before a snippet's code in its temporary file.

    The header is padded with blank lines so that, when ``start_line >= 4``, line ``k`` of the snippet lands on
    line ``start_line + k - 1`` of the temp file, i.e. traceback line numbers equal the Markdown source lines.
    Only comments and blank lines are emitted, so a leading ``from __future__`` import in a snippet still works.
    """
    header = [
        f"# File: {snippet.relative_path}",
        f"# Snippet: {snippet.snippet_index}",
        f"# Start line in source: {snippet.start_line}",
    ]
    padding = max(snippet.start_line - 1 - len(header), 0)
    return "\n".join(header + [""] * padding) + "\n"


def _decode_output(data: bytes | str | None) -> str:
    """Return subprocess output as text; ``TimeoutExpired`` may carry bytes even when ``text=True`` was used."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        return data.decode("utf-8", errors="replace")
    return data


def run_snippet(snippet: Snippet, timeout_seconds: int, cwd: str, skip_unsafe: bool) -> ExecutionResult:
    """
    Execute a single snippet in a fresh Python subprocess and return the outcome.

    :param snippet: The snippet to run.
    :param timeout_seconds: Wall-clock limit for the subprocess.
    :param cwd: Working directory of the subprocess; also the root that required files are resolved against.
    :param skip_unsafe: When True, snippets matching ``UNSAFE_PATTERNS`` are skipped instead of executed.
    :returns: A SKIPPED result if the snippet should not run; otherwise PASSED or FAILED with captured output.
        A timeout yields FAILED with ``reason`` set and ``return_code`` left as None.
    """
    skip_result = _should_skip_snippet(snippet, cwd, skip_unsafe)
    if skip_result is not None:
        return skip_result

    # Write to a temp file (rather than `python -c`) so tracebacks carry a file path and source-aligned line
    # numbers. The stable, informative name lets the file be inspected after a failure; it is overwritten by the
    # next run of the same snippet and is not deleted.
    temp_dir = os.path.join(tempfile.gettempdir(), "doc_snippet_tests")
    os.makedirs(temp_dir, exist_ok=True)
    safe_rel = snippet.relative_path.replace(os.sep, "__")
    temp_path = os.path.join(temp_dir, f"{safe_rel}__snippet_{snippet.snippet_index}.py")

    with open(temp_path, "w", encoding="utf-8") as tf:
        tf.write(_build_prelude(snippet))
        tf.write(snippet.code)
        tf.write("\n")

    try:
        completed = subprocess.run(
            [sys.executable, temp_path],
            check=False,
            cwd=cwd,
            capture_output=True,
            timeout=timeout_seconds,
            text=True,
            env={**os.environ, "PYTHONUNBUFFERED": "1"},
        )
    except subprocess.TimeoutExpired as exc:
        return ExecutionResult(
            snippet=snippet,
            status=ExecutionStatus.FAILED,
            reason=f"timeout after {timeout_seconds}s",
            stdout=None if exc.stdout is None else _decode_output(exc.stdout),
            stderr=_decode_output(exc.stderr) + f"\n[timeout after {timeout_seconds}s]",
        )

    if completed.returncode == 0:
        return ExecutionResult(snippet=snippet, status=ExecutionStatus.PASSED, return_code=0, stdout=completed.stdout)

    return ExecutionResult(
        snippet=snippet,
        status=ExecutionStatus.FAILED,
        return_code=completed.returncode,
        stdout=completed.stdout,
        stderr=completed.stderr,
    )


def _escape_annotation_data(text: str) -> str:
    """Escape the message part of a GitHub Actions workflow command (``%``, CR, LF)."""
    return text.replace("%", "%25").replace("\r", "%0D").replace("\n", "%0A")


def _escape_annotation_property(text: str) -> str:
    """Escape a workflow-command property value, which additionally must not contain raw ``:`` or ``,``."""
    return _escape_annotation_data(text).replace(":", "%3A").replace(",", "%2C")


def print_failure_annotation(result: ExecutionResult) -> None:
    """Print a GitHub Actions error annotation so failures are clickable in CI logs."""
    message = f"Doc snippet #{result.snippet.snippet_index} failed"
    stderr_text = result.stderr.strip() if result.stderr else ""
    stdout_text = result.stdout.strip() if result.stdout else ""
    details = stderr_text or stdout_text
    if result.reason:
        details = f"{result.reason}\n\n" + details

    file_prop = _escape_annotation_property(result.snippet.relative_path)
    line = result.snippet.start_line
    sys.stdout.write(
        f"::error file={file_prop},line={line}::{message} — see details below%0A{_escape_annotation_data(details)}\n"
    )


def process_file_snippets(
    file_rel: str, snippets: list[Snippet], repo_root: str, timeout_seconds: int, allow_unsafe: bool, verbose: bool
) -> tuple[list[ExecutionResult], dict[str, int]]:
    """
    Run every snippet of one Markdown file, printing progress, and return the results and per-file counts.

    :returns: ``(results, stats)``; ``stats`` has the keys ``total``, ``passed``, ``failed`` and ``skipped``.
    """
    if verbose:
        print(f"[RUN] {file_rel}")
    else:
        print(f"Running {file_rel} ({len(snippets)} snippet(s))")

    results: list[ExecutionResult] = []
    counts = {ExecutionStatus.PASSED: 0, ExecutionStatus.FAILED: 0, ExecutionStatus.SKIPPED: 0}

    for snippet in snippets:
        result = run_snippet(snippet, timeout_seconds=timeout_seconds, cwd=repo_root, skip_unsafe=not allow_unsafe)
        results.append(result)
        counts[result.status] += 1

        if result.status == ExecutionStatus.PASSED:
            if verbose:
                print(f"[PASS] {snippet.relative_path}#snippet{snippet.snippet_index} (line {snippet.start_line})")
        elif result.status == ExecutionStatus.SKIPPED:
            if verbose:
                reason = f" — {result.reason}" if result.reason else ""
                print(f"[SKIP] {snippet.relative_path}#snippet{snippet.snippet_index}{reason}")
        else:
            print_failure_annotation(result)
            # Also print a concise human-readable failure line
            print(
                f"FAILED {snippet.relative_path}:snippet{snippet.snippet_index} "
                f"(line {snippet.start_line}) — rc={result.return_code or 'N/A'}"
            )
            if result.stdout and result.stdout.strip():
                print("--- stdout ---\n" + result.stdout)
            if result.stderr and result.stderr.strip():
                print("--- stderr ---\n" + result.stderr)

    stats = {
        "total": len(snippets),
        "passed": counts[ExecutionStatus.PASSED],
        "failed": counts[ExecutionStatus.FAILED],
        "skipped": counts[ExecutionStatus.SKIPPED],
    }
    return results, stats


def main(argv: list[str] | None = None) -> int:
    """
    CLI entry point for snippet execution.

    :param argv: Command-line arguments; argparse falls back to ``sys.argv[1:]`` when None.
    :returns: Process exit code: 1 if any snippet failed, 0 otherwise.
    """
    parser = argparse.ArgumentParser(description="Test Python code snippets in Docusaurus docs")
    parser.add_argument(
        "targets",
        nargs="*",
        help="Optional positional list of files or directories to scan. If omitted, --paths is used.",
    )
    parser.add_argument(
        "--paths",
        nargs="+",
        default=["docs", "versioned_docs"],
        help=(
            "Fallback directories or files to scan when no positional targets are provided "
            "(defaults to docs and versioned_docs)"
        ),
    )
    parser.add_argument("--timeout-seconds", type=int, default=600, help="Timeout per snippet execution (seconds)")
    parser.add_argument(
        "--allow-unsafe", action="store_true", help="Allow execution of snippets with potentially unsafe patterns"
    )
    parser.add_argument("--verbose", action="store_true", help="Print verbose logs")

    args = parser.parse_args(argv)

    # Find haystack root: three directory levels above this script. Make __file__ absolute first; a relative
    # __file__ (possible before Python 3.9) would collapse to "." and yield the wrong root.
    repo_root = str(Path(os.path.abspath(__file__)).parent.parent.parent)
    raw_paths = args.targets if args.targets else args.paths
    scan_paths = [p if os.path.isabs(p) else os.path.join(repo_root, p) for p in raw_paths]

    # A mistyped target would otherwise silently contribute zero files and let the run pass.
    for missing in (p for p in scan_paths if not os.path.exists(p)):
        print(f"Warning: target does not exist and will be ignored: {missing}", file=sys.stderr)

    md_files = find_markdown_files(scan_paths)
    if args.verbose:
        print(f"Repo root: {repo_root}")
        print(f"Scanning targets: {', '.join(raw_paths)}")
    print(f"Discovered {len(md_files)} Markdown files")

    all_snippets: list[Snippet] = []
    for idx, fpath in enumerate(md_files, start=1):
        rel = os.path.relpath(fpath, repo_root)
        if args.verbose:
            print(f"[SCAN {idx}/{len(md_files)}] {rel}")
        snippets = extract_python_snippets(fpath, repo_root)
        if snippets:
            all_snippets.extend(snippets)
            if args.verbose:
                print(f"[FOUND] {rel}: {len(snippets)} python snippet(s)")

    if args.verbose:
        print(f"Extracted {len(all_snippets)} Python snippets")
    else:
        print(f"Total Python snippets found: {len(all_snippets)}")

    total = len(all_snippets)
    passed = failed = skipped = 0
    results: list[ExecutionResult] = []

    # Ensure deterministic execution order grouped by file, then line
    all_snippets.sort(key=lambda s: (s.relative_path, s.start_line, s.snippet_index))

    # Group by file (dicts preserve insertion order, so groups follow the sorted order)
    file_to_snippets: dict[str, list[Snippet]] = {}
    for sn in all_snippets:
        file_to_snippets.setdefault(sn.relative_path, []).append(sn)

    file_stats: dict[str, dict[str, int]] = {}
    for file_rel, snippets in file_to_snippets.items():
        file_results, stats = process_file_snippets(
            file_rel=file_rel,
            snippets=snippets,
            repo_root=repo_root,
            timeout_seconds=args.timeout_seconds,
            allow_unsafe=args.allow_unsafe,
            verbose=args.verbose,
        )
        results.extend(file_results)
        file_stats[file_rel] = stats

        passed += stats["passed"]
        failed += stats["failed"]
        skipped += stats["skipped"]

    print(f"Summary: total={total}, passed={passed}, failed={failed}, skipped={skipped}")

    # Per-file summary - show only files with failures by default, all in verbose mode
    print("Files summary:")
    for file_rel in sorted(file_stats):
        fs = file_stats[file_rel]
        if fs["failed"] > 0 or args.verbose:
            print(
                f"  - {file_rel}: total={fs['total']}, passed={fs['passed']}, "
                f"failed={fs['failed']}, skipped={fs['skipped']}"
            )

    return 1 if failed > 0 else 0


if __name__ == "__main__":
    sys.exit(main())