05_Triton_GPU_Kernels/examples/benchmark_comparison.py
"""
Benchmark Comparison: Triton GEMM vs cuBLAS vs PyTorch.

This script compares the performance of the Triton GEMM implementation
against reference implementations.

Requirements: 5.1, 5.2
"""
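
# Typical invocations (flags defined in main() below):
#   python benchmark_comparison.py --sizes 1024,2048,4096 --markdown
#   python benchmark_comparison.py --device cpu --dtype float32 --iters 20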

from __future__ import annotations

import argparse
from dataclasses import dataclass

import torch

# Triton imports (with graceful fallback)
try:
    import triton  # noqa: F401 -- imported only to probe availability

    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False

# Repo-local timing helpers; TimingResult exposes mean_ms and std_ms (used below).
from common.benchmark.timer import TimingResult, benchmark_function


@dataclass
class BenchmarkResult:
    """Result from a single benchmark run."""

    name: str
    size: tuple[int, int, int]  # (M, N, K)
    latency_ms: float
    std_ms: float
    tflops: float
    notes: str = ""


def get_tflops(m: int, n: int, k: int, latency_ms: float) -> float:
    """Calculate achieved TFLOPS for a GEMM of the given shape."""
    # GEMM requires 2 * M * N * K FLOPs (each of the M*N outputs needs K multiply-add pairs)
    flops = 2 * m * n * k
    tflops = flops / (latency_ms * 1e-3) / 1e12
    return tflops
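
# Worked example (numbers illustrative, not measured): a 4096 x 4096 x 4096 GEMM
# finishing in 5 ms performs 2 * 4096**3 ≈ 1.37e11 FLOPs, so
# get_tflops(4096, 4096, 4096, 5.0) ≈ 1.37e11 / 5e-3 / 1e12 ≈ 27.5 TFLOPS.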


def benchmark_torch_matmul(a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int) -> TimingResult:
    """Benchmark PyTorch matmul."""

    def run():
        torch.matmul(a, b)
        if a.is_cuda:
            torch.cuda.synchronize()

    return benchmark_function(run, warmup_iters=warmup, bench_iters=iters, sync_cuda=True)


def benchmark_triton_matmul(
    a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int, use_autotuned: bool = True
) -> TimingResult:
    """Benchmark Triton GEMM kernel."""
    from triton_kernels.gemm import matmul, matmul_autotuned

    func = matmul_autotuned if use_autotuned else matmul

    def run():
        func(a, b)
        torch.cuda.synchronize()

    return benchmark_function(run, warmup_iters=warmup, bench_iters=iters, sync_cuda=True)
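

# Benchmarks alone don't prove the kernel computes the right answer. Below is a
# minimal, illustrative correctness check (a hypothetical helper, not part of the
# original harness); it reuses the `triton_kernels.gemm.matmul` entry point imported
# above, with tolerances loose enough for FP16 accumulation differences.
def verify_triton_matmul(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if the Triton GEMM matches torch.matmul within tolerance."""
    from triton_kernels.gemm import matmul

    expected = torch.matmul(a, b)
    actual = matmul(a, b)
    # rtol/atol are assumptions sized for FP16; tighten them for FP32 inputs.
    return torch.allclose(actual, expected, rtol=1e-2, atol=1e-2)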


def benchmark_cublas(a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int) -> TimingResult:
    """Benchmark cuBLAS via PyTorch (torch.matmul dispatches to cuBLAS for CUDA tensors)."""
    return benchmark_torch_matmul(a, b, warmup, iters)


def run_benchmarks(
    sizes: list[tuple[int, int, int]],
    warmup: int = 10,
    iters: int = 100,
    dtype: torch.dtype = torch.float16,
    device: str = "cuda",
) -> list[BenchmarkResult]:
    """
    Run benchmarks across different matrix sizes.

    Args:
        sizes: List of (M, N, K) tuples
        warmup: Number of warmup iterations
        iters: Number of benchmark iterations
        dtype: Data type for matrices
        device: Device to run on

    Returns:
        List of benchmark results
    """
    results = []

    for m, n, k in sizes:
        print(f"\nBenchmarking M={m}, N={n}, K={k}...")

        # Create random matrices
        a = torch.randn(m, k, dtype=dtype, device=device)
        b = torch.randn(k, n, dtype=dtype, device=device)

        # PyTorch/cuBLAS baseline
        print("  Running PyTorch/cuBLAS...")
        torch_result = benchmark_torch_matmul(a, b, warmup, iters)
        results.append(
            BenchmarkResult(
                name="PyTorch/cuBLAS",
                size=(m, n, k),
                latency_ms=torch_result.mean_ms,
                std_ms=torch_result.std_ms,
                tflops=get_tflops(m, n, k, torch_result.mean_ms),
            )
        )

        # Triton (skipped on CPU: the kernel launches and synchronizes on CUDA)
        if TRITON_AVAILABLE and device == "cuda":
            print("  Running Triton GEMM...")
            triton_result = benchmark_triton_matmul(a, b, warmup, iters)
            results.append(
                BenchmarkResult(
                    name="Triton GEMM",
                    size=(m, n, k),
                    latency_ms=triton_result.mean_ms,
                    std_ms=triton_result.std_ms,
                    tflops=get_tflops(m, n, k, triton_result.mean_ms),
                )
            )

        # Clear GPU memory
        del a, b
        if device == "cuda":
            torch.cuda.empty_cache()

    return results
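
# Example usage (hypothetical call, not part of the CLI below):
#   results = run_benchmarks([(1024, 1024, 1024)], warmup=5, iters=50, dtype=torch.float16)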


def print_results_table(results: list[BenchmarkResult]) -> None:
    """Print results in a formatted table."""
    print("\n" + "=" * 90)
    print("GEMM Benchmark Results")
    print("=" * 90)
    print(
        f"{'Size (M, N, K)':<20} | {'Implementation':<15} | {'Latency (ms)':<18} | "
        f"{'TFLOPS':<8} | {'vs cuBLAS':<9}"
    )
    print("-" * 90)

    # Group by size: the first result per size (PyTorch/cuBLAS) is the baseline
    sizes_seen = {}
    for r in results:
        size_str = str(r.size)
        if size_str not in sizes_seen:
            sizes_seen[size_str] = r.latency_ms

    for r in results:
        size_str = str(r.size)
        baseline = sizes_seen.get(size_str, r.latency_ms)
        speedup = baseline / r.latency_ms if r.name != "PyTorch/cuBLAS" else 1.0

        print(
            f"{str(r.size):<20} | {r.name:<15} | {r.latency_ms:>10.3f} ± {r.std_ms:5.3f} | "
            f"{r.tflops:>8.1f} | {speedup:>9.1%}"
        )

    print("=" * 90)


def print_markdown_table(results: list[BenchmarkResult]) -> None:
    """Print results as a markdown table for README inclusion."""
    print("\n```markdown")
    print("| Size (M, N, K) | Implementation | Latency (ms) | TFLOPS | Efficiency |")
    print("|----------------|----------------|--------------|--------|------------|")

    for r in results:
        # Efficiency against the A100 dense FP16 Tensor Core peak (~312 TFLOPS);
        # this denominator is only meaningful for FP16 runs on an A100.
        efficiency = r.tflops / 312 * 100 if r.tflops > 0 else 0

        print(
            f"| {r.size} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.1f} | {efficiency:.1f}% |"
        )

    print("```")
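
# The emitted block renders roughly like this (numbers illustrative only):
#   | Size (M, N, K) | Implementation | Latency (ms) | TFLOPS | Efficiency |
#   | (4096, 4096, 4096) | Triton GEMM | 1.234 | 111.4 | 35.7% |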


def main() -> int:
    """Main entry point."""
    parser = argparse.ArgumentParser(description="GEMM Benchmark Comparison")
    parser.add_argument(
        "--sizes", type=str, default="128,1024,4096",
        help="Comma-separated sizes; each value S runs a square S x S x S GEMM",
    )
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations")
    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32"])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--markdown", action="store_true", help="Output as markdown table")
    args = parser.parse_args()

    # Check CUDA availability
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"
        # Note: float16 matmul on CPU works in PyTorch but can be very slow.

    # Parse sizes: generate a square (S, S, S) problem for each value
    size_values = [int(s) for s in args.sizes.split(",")]
    sizes = [(s, s, s) for s in size_values]

    # Get dtype
    dtype = torch.float16 if args.dtype == "float16" else torch.float32

    print("=" * 60)
    print("Triton GEMM Benchmark")
    print("=" * 60)
    print(f"Device: {args.device}")
    print(f"Dtype: {args.dtype}")
    print(f"Sizes: {sizes}")
    print(f"Warmup: {args.warmup}, Iterations: {args.iters}")

    # Run benchmarks
    results = run_benchmarks(sizes, args.warmup, args.iters, dtype, args.device)

    # Print results
    print_results_table(results)

    if args.markdown:
        print_markdown_table(results)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())