/ scripts / protocol_enforcement.py
protocol_enforcement.py
  1  #!/usr/bin/env python3
  2  """
  3  Protocol Enforcement - Belt and Suspenders for Critical Patterns
  4  
  5  This module provides infrastructure enforcement for patterns that were
  6  previously documentation-only. Follows the principle:
  7  "Infrastructure over suggestion" - proto-028
  8  
  9  Critical patterns enforced:
 10  1. test-before-ship (proto-029): Verify artifacts before committing
 11  2. validation-loop (proto-027): Re-run checks after fixes
 12  3. mandatory-phoenix-extraction (proto-005): Block without extraction
 13  
 14  Each enforcement has:
 15  - BELT: Primary check mechanism
 16  - SUSPENDERS: Secondary/redundant check
 17  - ABORT: Blocking gate when critical
 18  
 19  Usage:
 20      from scripts.protocol_enforcement import (
 21          test_before_ship,
 22          validation_loop,
 23          phoenix_extraction_gate,
 24      )
 25  """
 26  
 27  import subprocess
 28  import sys
 29  import os
 30  import json
 31  from pathlib import Path
 32  from datetime import datetime, timedelta
 33  from typing import Dict, List, Any, Optional, Tuple
 34  from dataclasses import dataclass, field
 35  
 36  # Find repo root
 37  REPO_ROOT = Path(__file__).parent.parent
 38  
 39  # =============================================================================
 40  # TEST-BEFORE-SHIP ENFORCEMENT (proto-029)
 41  # "Run the artifact before committing. Verify it works, not just compiles."
 42  # =============================================================================
 43  
 44  @dataclass
 45  class TestResult:
 46      """Result of a test-before-ship check."""
 47      file_path: str
 48      test_level: int  # 0-6 verification ladder
 49      test_passed: bool
 50      test_output: str
 51      error: Optional[str] = None
 52  
 53      @property
 54      def level_name(self) -> str:
 55          levels = {
 56              0: "Nothing",
 57              1: "Syntax",
 58              2: "Imports",
 59              3: "--help",
 60              4: "Basic run",
 61              5: "Edge cases",
 62              6: "Integration"
 63          }
 64          return levels.get(self.test_level, "Unknown")
 65  
 66  
 67  def test_python_syntax(file_path: Path) -> Tuple[bool, str]:
 68      """Level 1: Check Python syntax."""
 69      try:
 70          result = subprocess.run(
 71              [sys.executable, "-m", "py_compile", str(file_path)],
 72              capture_output=True,
 73              text=True,
 74              timeout=30
 75          )
 76          if result.returncode == 0:
 77              return True, "Syntax OK"
 78          return False, result.stderr
 79      except Exception as e:
 80          return False, str(e)
 81  
 82  
 83  def test_python_imports(file_path: Path) -> Tuple[bool, str]:
 84      """Level 2: Check imports resolve."""
 85      try:
 86          # Try to import the module
 87          result = subprocess.run(
 88              [sys.executable, "-c", f"import importlib.util; spec = importlib.util.spec_from_file_location('m', '{file_path}'); m = importlib.util.module_from_spec(spec)"],
 89              capture_output=True,
 90              text=True,
 91              timeout=30,
 92              cwd=str(REPO_ROOT)
 93          )
 94          if result.returncode == 0:
 95              return True, "Imports OK"
 96          return False, result.stderr
 97      except Exception as e:
 98          return False, str(e)
 99  
100  
def test_python_help(file_path: Path) -> Tuple[bool, str]:
    """Level 3: check that the script starts when invoked with ``--help``.

    Runs ``python <file> --help`` from the repository root with stdin closed.
    The check passes when the process exits 0, or when it prints a usage line
    on stdout (some CLIs exit non-zero after printing help). A script whose
    parser rejects ``--help`` as an unrecognized argument is also accepted as
    "not a CLI". A script that ignores its arguments simply runs here.

    Returns:
        ``(passed, message)``.
    """
    try:
        result = subprocess.run(
            # Absolute path: cwd is REPO_ROOT, not the caller's directory.
            [sys.executable, os.path.abspath(file_path), "--help"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=str(REPO_ROOT),
            # A script waiting on stdin would otherwise hang a git hook.
            stdin=subprocess.DEVNULL,
        )
    except subprocess.TimeoutExpired:
        return False, "Timeout on --help"
    except Exception as e:
        return False, str(e)

    # --help often returns 0, but even returncode 2 with help text is OK
    if result.returncode == 0 or "usage" in result.stdout.lower():
        return True, "CLI responds to --help"
    # If no --help, that's OK - not all scripts are CLIs
    if "unrecognized arguments" in result.stderr:
        return True, "Not a CLI (no --help), syntax OK"
    return False, result.stderr or f"--help exited with code {result.returncode}"


# Named test_* but not a test: stop pytest from collecting it when a test
# module imports it.
test_python_help.__test__ = False
122  
123  
def _enforce_ship_result(result: TestResult, abort_on_fail: bool) -> TestResult:
    """Raise for a failed *result* when *abort_on_fail* is set; else return it."""
    if abort_on_fail and not result.test_passed:
        raise RuntimeError(f"test-before-ship FAILED: {result.error}")
    return result


def test_before_ship(
    file_path: Path,
    min_level: int = 2,  # Minimum: imports must work
    abort_on_fail: bool = False
) -> TestResult:
    """
    BELT: Test a file before shipping/committing.

    Verification ladder:
        0. Nothing (just committed)
        1. Syntax (no parse errors)
        2. Imports (dependencies resolve) - MINIMUM
        3. --help (CLI at least starts)
        4. Basic run (happy path)
        5. Edge cases
        6. Integration

    Only rungs 1-3 are automated here. Non-Python files are not put on the
    ladder and always pass, reported at level 1.

    Args:
        file_path: Path (or str) of the file to test.
        min_level: Minimum rung that must be reached to pass (default: 2).
        abort_on_fail: If True, raise instead of returning a failed result.

    Returns:
        TestResult. ``test_level`` is the highest rung reached, except when
        rung 1 or 2 fails, in which case it is the rung that failed.
        ``test_passed`` means the rungs actually reached are >= *min_level*.
        A missing file never passes.

    Raises:
        RuntimeError: if *abort_on_fail* is set and the result did not pass.
            This covers a missing file, a syntax or import failure, and
            stopping below *min_level*.
    """
    file_path = Path(file_path)
    path_str = str(file_path)

    if not file_path.exists():
        return _enforce_ship_result(
            TestResult(
                file_path=path_str,
                test_level=0,
                test_passed=False,
                test_output="File does not exist",
                error="FileNotFoundError"
            ),
            abort_on_fail,
        )

    # Only test Python files
    if file_path.suffix != ".py":
        return TestResult(
            file_path=path_str,
            test_level=1,
            test_passed=True,
            test_output="Non-Python file, skipping deep tests"
        )

    outputs: List[str] = []

    # Level 1: Syntax. Nothing verified yet, so only min_level <= 0 tolerates it.
    passed, output = test_python_syntax(file_path)
    outputs.append(f"L1 Syntax: {output}")
    if not passed:
        return _enforce_ship_result(
            TestResult(
                file_path=path_str,
                test_level=1,
                test_passed=0 >= min_level,
                test_output="\n".join(outputs),
                error="Syntax error"
            ),
            abort_on_fail,
        )

    # Level 2: Imports. Only rung 1 has been verified at this point.
    passed, output = test_python_imports(file_path)
    outputs.append(f"L2 Imports: {output}")
    if not passed:
        return _enforce_ship_result(
            TestResult(
                file_path=path_str,
                test_level=2,
                test_passed=1 >= min_level,
                test_output="\n".join(outputs),
                error="Import error"
            ),
            abort_on_fail,
        )
    achieved_level = 2

    # Level 3: --help (optional rung; failing it just stops the climb)
    passed, output = test_python_help(file_path)
    outputs.append(f"L3 --help: {output}")
    if passed:
        achieved_level = 3

    # Levels 4-6 require specific test infrastructure, skip for now

    ok = achieved_level >= min_level
    return _enforce_ship_result(
        TestResult(
            file_path=path_str,
            test_level=achieved_level,
            test_passed=ok,
            test_output="\n".join(outputs),
            error=None if ok else f"Reached level {achieved_level}, required {min_level}"
        ),
        abort_on_fail,
    )


# Named test_* but not a test: stop pytest from collecting it when a test
# module imports it (as the module's Usage example does).
test_before_ship.__test__ = False
218  
219  
def test_staged_files(min_level: int = 2) -> List[TestResult]:
    """
    SUSPENDERS: Test all staged Python files before commit.

    Staged files are those Added, Copied or Modified in the index
    (``git diff --cached --diff-filter=ACM``). Each ``.py`` file is run
    through ``test_before_ship``. Note that the checks run against the
    working-tree copy, which can differ from the staged content if the file
    was edited after ``git add``.

    Returns:
        One TestResult per staged .py file. The list is empty when git is
        unavailable or fails (e.g. not a repository), or when no Python file
        is staged.
    """
    try:
        proc = subprocess.run(
            # -z: NUL-separated, verbatim paths. Without it git C-quotes paths
            # with unusual (e.g. non-ASCII) characters, yielding names that do
            # not exist on disk.
            ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM", "-z"],
            capture_output=True,
            text=True,
            timeout=30,
            cwd=str(REPO_ROOT)
        )
    except Exception:
        # Best-effort: git missing, timeout, undecodable output.
        return []
    if proc.returncode != 0:
        return []

    staged_files = [name for name in proc.stdout.split("\0") if name]
    return [
        test_before_ship(REPO_ROOT / name, min_level)
        for name in staged_files
        if name.endswith(".py")
    ]


# Named test_* but not a test: stop pytest from collecting it when a test
# module imports it.
test_staged_files.__test__ = False
244  
245  
246  # =============================================================================
247  # VALIDATION-LOOP ENFORCEMENT (proto-027)
248  # "After any fix, re-run the check to verify improvement. Report the delta."
249  # =============================================================================
250  
@dataclass
class ValidationResult:
    """Result of a validation loop check (one before/after measurement)."""
    check_name: str  # label used in reports and error messages
    before_value: Any  # baseline passed in (None when none was given)
    after_value: Any  # current measurement (None when the check raised)
    delta: Any  # number/dict delta, "before -> after" text, "Unable to compute", or None
    improved: bool  # True when the value did not get worse
    validation_passed: bool  # overall verdict (see validation_loop)
    details: Dict[str, Any] = field(default_factory=dict)  # e.g. {"error": ...}


def validation_loop(
    check_function: callable,
    check_name: str,
    baseline: Any = None,
    improvement_required: bool = True,
    abort_on_regression: bool = False,
    *,
    lower_is_better: bool = True,
) -> ValidationResult:
    """
    BELT: Run a validation loop to verify a fix worked.

    Pattern: Measure (before) -> Fix -> Measure (after) -> Report delta

    Supported measurements:
      * numbers: ``delta = after - before``
      * dicts of numbers: per-key delta (a missing key counts as 0). The sum
        of the per-key deltas decides improvement.
      * anything else comparable: ``delta`` is the text ``"before -> after"``
    Values that cannot be subtracted or compared give
    ``delta="Unable to compute"`` and ``improved=False``. No change counts as
    improved.

    Args:
        check_function: Zero-argument callable returning the measurement.
        check_name: Name of this check for reporting.
        baseline: Previous value. If None, the current value is returned as
            the new baseline and the result passes.
        improvement_required: If True, pass only when the value did not get
            worse. If False, pass whenever the check itself ran.
        abort_on_regression: If True, raise when the result does not pass.
        lower_is_better: Direction of "better". True (default) suits counts
            of violations/errors; use False for scores such as coverage.

    Returns:
        ValidationResult with before/after/delta. If *check_function* raises,
        the result fails and the error text is in ``details["error"]``.

    Raises:
        RuntimeError: on a failing result when *abort_on_regression* is set.
    """
    # Get current value
    try:
        current_value = check_function()
    except Exception as e:
        return ValidationResult(
            check_name=check_name,
            before_value=baseline,
            after_value=None,
            delta=None,
            improved=False,
            validation_passed=False,
            details={"error": str(e)}
        )

    if baseline is None:
        # No baseline, just return current as baseline for next time
        return ValidationResult(
            check_name=check_name,
            before_value=None,
            after_value=current_value,
            delta=None,
            improved=True,  # No baseline = assumed OK
            validation_passed=True,
            details={"note": "No baseline provided, establishing baseline"}
        )

    def _not_worse(after: Any, before: Any) -> bool:
        """Compare in the configured direction (ties count as not worse)."""
        return after <= before if lower_is_better else after >= before

    # Calculate delta
    try:
        if isinstance(current_value, (int, float)) and isinstance(baseline, (int, float)):
            delta = current_value - baseline
            improved = _not_worse(delta, 0)
        elif isinstance(current_value, dict) and isinstance(baseline, dict):
            # Baseline keys first, then keys new in the current value, so the
            # reported delta has a stable order (a set union does not).
            keys = dict.fromkeys([*baseline, *current_value])
            delta = {k: current_value.get(k, 0) - baseline.get(k, 0) for k in keys}
            improved = _not_worse(sum(delta.values()), 0)
        else:
            delta = f"{baseline} -> {current_value}"
            # Every object has __le__/__ge__; incomparable types raise
            # TypeError, which is handled below.
            improved = _not_worse(current_value, baseline)
    except Exception:
        delta = "Unable to compute"
        improved = False

    validation_passed = improved if improvement_required else True

    result = ValidationResult(
        check_name=check_name,
        before_value=baseline,
        after_value=current_value,
        delta=delta,
        improved=improved,
        validation_passed=validation_passed
    )

    if abort_on_regression and not validation_passed:
        raise RuntimeError(
            f"validation-loop REGRESSION: {check_name} got worse "
            f"({baseline} -> {current_value}, delta={delta})"
        )

    return result
345  
346  
class ValidationTracker:
    """
    SUSPENDERS: Track validation state across multiple checks.

    Keeps one baseline per check name in a JSON file. ``check_and_update``
    runs ``validation_loop`` against the stored baseline and adopts the new
    measurement when the loop passes, or on a check's first run.
    """

    def __init__(self, tracker_path: Optional[Path] = None):
        """Create a tracker backed by *tracker_path* (str or Path).

        Defaults to ``sessions/validation-tracker.json`` under the repository
        root. Existing baselines are loaded immediately.
        """
        self.tracker_path = (
            Path(tracker_path) if tracker_path
            else REPO_ROOT / "sessions" / "validation-tracker.json"
        )
        self.baselines: Dict[str, Any] = {}
        self._load()

    def _load(self) -> None:
        """Load baselines from disk.

        A missing file leaves the current baselines untouched. An unreadable
        or corrupt file, or one whose JSON is not an object, resets them to
        empty.
        """
        if not self.tracker_path.exists():
            return
        try:
            data = json.loads(self.tracker_path.read_text())
        except (OSError, ValueError):
            data = None
        # A valid but non-object document (e.g. a list) would break .get().
        self.baselines = data if isinstance(data, dict) else {}

    def _save(self) -> None:
        """Write baselines to disk via a temp file and a rename."""
        self.tracker_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.tracker_path.with_name(self.tracker_path.name + ".tmp")
        tmp_path.write_text(json.dumps(self.baselines, indent=2, default=str))
        # Rename over the target (atomic on POSIX), so an interrupted write
        # cannot leave a truncated tracker file behind.
        tmp_path.replace(self.tracker_path)

    def check_and_update(
        self,
        check_name: str,
        check_function: callable,
        improvement_required: bool = True
    ) -> "ValidationResult":
        """Run the validation loop for *check_name* with a persistent baseline."""
        baseline = self.baselines.get(check_name)
        result = validation_loop(
            check_function=check_function,
            check_name=check_name,
            baseline=baseline,
            improvement_required=improvement_required
        )

        # Update baseline if it passed or on the first run, but never persist
        # the None that a check which raised (or returned nothing) produces.
        if result.after_value is not None and (result.validation_passed or baseline is None):
            self.baselines[check_name] = result.after_value
            self._save()

        return result
393  
394  
395  # =============================================================================
396  # MANDATORY-PHOENIX-EXTRACTION ENFORCEMENT (proto-005)
397  # "Never let context decay gradually. Extract SHARPLY before compression."
398  # =============================================================================
399  
# Phoenix staleness bands, in minutes since sessions/LIVE-COMPRESSION.md was
# last modified. The gate reports OK below the first value, WARNING from it,
# CRITICAL from the second, and BLOCKED from the third.
PHOENIX_WARNING_MINUTES = 30        # half an hour: start warning
PHOENIX_CRITICAL_MINUTES = 2 * 30   # one hour: stale
PHOENIX_BLOCKING_MINUTES = 2 * 60   # two hours: must extract before proceeding
404  
405  
@dataclass
class PhoenixGateResult:
    """Proceed/block decision produced by the phoenix extraction gate."""

    can_proceed: bool  # may the gated operation go ahead?
    staleness_minutes: float  # age of LIVE-COMPRESSION.md (inf when unavailable)
    status: str  # "OK", "WARNING", "CRITICAL" or "BLOCKED"
    message: str  # human-readable explanation of the status
    last_updated: Optional[datetime] = None  # file mtime, when it could be read
414  
415  
def get_phoenix_staleness(path: Optional[Path] = None) -> Tuple[float, Optional[datetime]]:
    """Get staleness of the phoenix state file in minutes.

    Args:
        path: File to inspect. Defaults to ``sessions/LIVE-COMPRESSION.md``
            under the repository root.

    Returns:
        ``(minutes_since_last_modification, mtime)``, where *mtime* is a naive
        local-time datetime. The staleness is never negative: a modification
        time in the future counts as 0 minutes. When the file is missing or
        cannot be stat'ed, returns ``(float('inf'), None)``.
    """
    import time

    target = Path(path) if path is not None else REPO_ROOT / "sessions" / "LIVE-COMPRESSION.md"

    try:
        # stat() directly (no exists() pre-check) avoids a check/use race.
        mtime_ts = target.stat().st_mtime
        last_updated = datetime.fromtimestamp(mtime_ts)
    except Exception:
        # Missing/unreadable file or unrepresentable mtime: treat as
        # infinitely stale so callers err on the side of extracting.
        return float('inf'), None

    # Measure age from epoch seconds. Subtracting two naive local datetimes
    # is off by an hour across a DST change.
    staleness = max(0.0, (time.time() - mtime_ts) / 60)
    return staleness, last_updated
429  
430  
def phoenix_extraction_gate(
    operation: str = "unknown",
    abort_on_critical: bool = False,
    abort_on_blocked: bool = True
) -> PhoenixGateResult:
    """
    BELT: Gate that blocks operations if phoenix state is too stale.

    Enforces mandatory-phoenix-extraction by requiring recent phoenix
    state before allowing context-intensive operations.

    Staleness bands (minutes):
        below PHOENIX_WARNING_MINUTES    -> OK
        below PHOENIX_CRITICAL_MINUTES   -> WARNING
        below PHOENIX_BLOCKING_MINUTES   -> CRITICAL
        otherwise, or no state file      -> BLOCKED

    Args:
        operation: Name of the operation being gated (used in messages).
        abort_on_critical: If True, abort from the CRITICAL band upward,
            which includes BLOCKED.
        abort_on_blocked: If True, abort in the BLOCKED band.

    Returns:
        PhoenixGateResult. Aborting raises, so a returned result always has
        ``can_proceed=True``. Callers that disable aborting must inspect
        ``status`` to detect CRITICAL/BLOCKED.

    Raises:
        RuntimeError: when the staleness falls in a band configured to abort.
    """
    staleness, last_updated = get_phoenix_staleness()

    if staleness < PHOENIX_WARNING_MINUTES:
        return PhoenixGateResult(
            can_proceed=True,
            staleness_minutes=staleness,
            status="OK",
            message=f"Phoenix state is fresh ({staleness:.0f}m old)",
            last_updated=last_updated
        )

    if staleness < PHOENIX_CRITICAL_MINUTES:
        return PhoenixGateResult(
            can_proceed=True,
            staleness_minutes=staleness,
            status="WARNING",
            message=f"Phoenix state getting stale ({staleness:.0f}m). Consider updating.",
            last_updated=last_updated
        )

    if staleness < PHOENIX_BLOCKING_MINUTES:
        status = "CRITICAL"
        message = f"Phoenix state is stale ({staleness:.0f}m). Update before {operation}."
        should_abort = abort_on_critical
    else:
        status = "BLOCKED"
        if last_updated is None:
            # Staleness is inf here; "(infm)" would be a useless message.
            message = (
                "No phoenix state found (LIVE-COMPRESSION.md missing or unreadable). "
                f"MUST extract before {operation}."
            )
        else:
            message = f"Phoenix state too old ({staleness:.0f}m). MUST extract before {operation}."
        # BLOCKED is strictly worse than CRITICAL: aborting at the critical
        # threshold must also abort here.
        should_abort = abort_on_blocked or abort_on_critical

    if should_abort:
        raise RuntimeError(f"phoenix-extraction-gate BLOCKED: {message}")

    return PhoenixGateResult(
        can_proceed=True,
        staleness_minutes=staleness,
        status=status,
        message=message,
        last_updated=last_updated
    )
490  
491  
def phoenix_extraction_reminder(interval_minutes: int = 30) -> Optional[str]:
    """
    SUSPENDERS: Periodic reminder to update phoenix state.

    Args:
        interval_minutes: Staleness (minutes) from which a WARNING reminder is
            issued. The default of 30 matches the gate's warning band.
            CRITICAL reminders always start at PHOENIX_CRITICAL_MINUTES.

    Returns:
        A reminder message if phoenix needs updating, otherwise None.
    """
    staleness, last_updated = get_phoenix_staleness()

    if staleness >= PHOENIX_CRITICAL_MINUTES:
        if last_updated is None:
            # No file at all: staleness is inf, so give a concrete instruction.
            return "CRITICAL: No phoenix state found. Create LIVE-COMPRESSION.md NOW."
        return f"CRITICAL: Phoenix state is {staleness:.0f}m old. Update LIVE-COMPRESSION.md NOW."
    if staleness >= interval_minutes:
        return f"WARNING: Phoenix state is {staleness:.0f}m old. Consider updating LIVE-COMPRESSION.md."

    return None
506  
507  
508  # =============================================================================
509  # UNIFIED PREFLIGHT CHECK
510  # =============================================================================
511  
@dataclass
class PreflightResult:
    """Aggregate outcome of ``run_preflight``.

    ``all_passed`` is the overall verdict. Each per-check field holds the raw
    result of that check, or None when the check was skipped. ``messages``
    holds one line per failure.
    """

    all_passed: bool
    test_before_ship: Optional[List[TestResult]] = None  # one entry per staged .py file
    validation_loop: Optional[ValidationResult] = None  # not filled in by run_preflight
    phoenix_gate: Optional[PhoenixGateResult] = None  # phoenix gate decision
    messages: List[str] = field(default_factory=list)  # fresh list per instance
520  
521  
def run_preflight(
    check_staged_files: bool = True,
    check_phoenix: bool = True,
    operation: str = "operation",
    min_level: int = 2
) -> PreflightResult:
    """
    Run unified preflight check combining all enforcement mechanisms.

    Belt + Suspenders for all critical protocols in one call.

    Args:
        check_staged_files: Run test-before-ship on every staged .py file.
        check_phoenix: Run the phoenix extraction gate (non-raising).
        operation: Operation name used in phoenix gate messages.
        min_level: Minimum verification rung for staged files (default: 2).

    Returns:
        PreflightResult. ``all_passed`` is False when any staged file fails
        or when the phoenix state is BLOCKED. Each failure adds a line to
        ``messages``.
    """
    result = PreflightResult(all_passed=True)

    # Test-before-ship: Check staged files
    if check_staged_files:
        test_results = test_staged_files(min_level=min_level)
        result.test_before_ship = test_results

        for r in test_results:
            if not r.test_passed:
                result.all_passed = False
                result.messages.append(
                    f"test-before-ship FAILED: {r.file_path} at level {r.test_level}"
                )

    # Phoenix gate
    if check_phoenix:
        # With abort_on_blocked=False the gate does not raise. It then reports
        # can_proceed=True even for a BLOCKED state, so the status must be
        # checked explicitly or preflight would never fail on phoenix.
        phoenix = phoenix_extraction_gate(operation=operation, abort_on_blocked=False)
        result.phoenix_gate = phoenix

        if not phoenix.can_proceed or phoenix.status == "BLOCKED":
            result.all_passed = False
            result.messages.append(f"phoenix-gate {phoenix.status}: {phoenix.message}")

    return result
557  
558  
559  # =============================================================================
560  # CLI
561  # =============================================================================
562  
def main():
    """Run protocol enforcement checks from the command line.

    With ``--preflight`` or no mode flag, every preflight check runs.
    Otherwise the first requested mode among ``--test-staged``,
    ``--test-file`` and ``--phoenix`` runs, and each mode exits the process.

    Exit codes: 0 = passed, 1 = a check failed, 2 = phoenix state BLOCKED
    (``--phoenix`` mode).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Protocol Enforcement - Belt and Suspenders",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    %(prog)s --test-staged          Test all staged Python files
    %(prog)s --phoenix              Check phoenix extraction gate
    %(prog)s --preflight            Run all preflight checks
    %(prog)s --test-file FILE       Test a specific file
        """
    )

    parser.add_argument('--test-staged', action='store_true',
                       help='Test all staged Python files')
    parser.add_argument('--test-file', type=Path,
                       help='Test a specific file')
    parser.add_argument('--phoenix', action='store_true',
                       help='Check phoenix extraction gate')
    parser.add_argument('--preflight', action='store_true',
                       help='Run all preflight checks')
    parser.add_argument('--min-level', type=int, default=2,
                       help='Minimum test level required (default: 2)')

    args = parser.parse_args()

    if args.preflight or (not args.test_staged and not args.test_file and not args.phoenix):
        # Default: run preflight
        print("=" * 60)
        print("PROTOCOL ENFORCEMENT PREFLIGHT")
        print("=" * 60)

        result = run_preflight()

        if result.test_before_ship:
            print(f"\nTest-Before-Ship: {len(result.test_before_ship)} files checked")
            for tr in result.test_before_ship:
                status = "PASS" if tr.test_passed else "FAIL"
                print(f"  [{status}] {tr.file_path} - Level {tr.test_level} ({tr.level_name})")

        if result.phoenix_gate:
            print(f"\nPhoenix Gate: {result.phoenix_gate.status}")
            print(f"  {result.phoenix_gate.message}")

        print("\n" + "-" * 60)
        if result.all_passed:
            print("ALL CHECKS PASSED")
            sys.exit(0)
        print("CHECKS FAILED:")
        for msg in result.messages:
            print(f"  - {msg}")
        sys.exit(1)

    if args.test_staged:
        print("Testing staged Python files...")
        results = test_staged_files(min_level=args.min_level)

        if not results:
            print("No staged Python files found")
            sys.exit(0)

        all_passed = True
        for r in results:
            status = "PASS" if r.test_passed else "FAIL"
            print(f"[{status}] {r.file_path} - Level {r.test_level} ({r.level_name})")
            if not r.test_passed:
                all_passed = False
                print(f"       {r.test_output}")

        sys.exit(0 if all_passed else 1)

    if args.test_file:
        print(f"Testing {args.test_file}...")
        result = test_before_ship(args.test_file, min_level=args.min_level)

        status = "PASS" if result.test_passed else "FAIL"
        print(f"[{status}] Level {result.test_level} ({result.level_name})")
        print(result.test_output)

        sys.exit(0 if result.test_passed else 1)

    if args.phoenix:
        print("Checking phoenix extraction gate...")
        # Non-raising call: with the default abort_on_blocked=True a BLOCKED
        # state raised a RuntimeError (traceback, exit 1), so the documented
        # exit code 2 below was unreachable.
        result = phoenix_extraction_gate(operation="manual check", abort_on_blocked=False)

        print(f"Status: {result.status}")
        print(f"Staleness: {result.staleness_minutes:.0f} minutes")
        print(f"Message: {result.message}")

        blocked = result.status == "BLOCKED" or not result.can_proceed
        sys.exit(2 if blocked else 0)
657  
658  
if __name__ == "__main__":
    # Script entry point. main() always ends via sys.exit; the outer
    # sys.exit(None) would only be reached on a normal return, giving status 0.
    sys.exit(main())