benchmark_comparison.py
1 """ 2 Benchmark Comparison: Triton GEMM vs cuBLAS vs PyTorch. 3 4 This script compares the performance of Triton GEMM implementation 5 against reference implementations. 6 7 Requirements: 5.1, 5.2 8 """ 9 10 from __future__ import annotations 11 12 import argparse 13 import time 14 from dataclasses import dataclass 15 from typing import Any 16 17 import numpy as np 18 import torch 19 20 # Triton imports (with graceful fallback) 21 try: 22 import triton 23 24 TRITON_AVAILABLE = True 25 except ImportError: 26 TRITON_AVAILABLE = False 27 28 from common.benchmark.timer import TimingResult, benchmark_function 29 30 31 @dataclass 32 class BenchmarkResult: 33 """Result from a single benchmark run.""" 34 35 name: str 36 size: tuple[int, int, int] 37 latency_ms: float 38 std_ms: float 39 tflops: float 40 notes: str = "" 41 42 43 def get_tflops(m: int, n: int, k: int, latency_ms: float) -> float: 44 """Calculate TFLOPS for GEMM operation.""" 45 # GEMM requires 2 * M * N * K FLOPs (multiply + add) 46 flops = 2 * m * n * k 47 tflops = flops / (latency_ms * 1e-3) / 1e12 48 return tflops 49 50 51 def benchmark_torch_matmul(a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int) -> TimingResult: 52 """Benchmark PyTorch matmul.""" 53 54 def run(): 55 torch.matmul(a, b) 56 if a.is_cuda: 57 torch.cuda.synchronize() 58 59 return benchmark_function(run, warmup_iters=warmup, bench_iters=iters, sync_cuda=True) 60 61 62 def benchmark_triton_matmul( 63 a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int, use_autotuned: bool = True 64 ) -> TimingResult: 65 """Benchmark Triton GEMM kernel.""" 66 from triton_kernels.gemm import matmul, matmul_autotuned 67 68 func = matmul_autotuned if use_autotuned else matmul 69 70 def run(): 71 func(a, b) 72 torch.cuda.synchronize() 73 74 return benchmark_function(run, warmup_iters=warmup, bench_iters=iters, sync_cuda=True) 75 76 77 def benchmark_cublas(a: torch.Tensor, b: torch.Tensor, warmup: int, iters: int) -> TimingResult: 78 """Benchmark cuBLAS via PyTorch (same as torch.matmul for CUDA tensors).""" 79 return benchmark_torch_matmul(a, b, warmup, iters) 80 81 82 def run_benchmarks( 83 sizes: list[tuple[int, int, int]], 84 warmup: int = 10, 85 iters: int = 100, 86 dtype: torch.dtype = torch.float16, 87 device: str = "cuda", 88 ) -> list[BenchmarkResult]: 89 """ 90 Run benchmarks across different matrix sizes. 91 92 Args: 93 sizes: List of (M, N, K) tuples 94 warmup: Number of warmup iterations 95 iters: Number of benchmark iterations 96 dtype: Data type for matrices 97 device: Device to run on 98 99 Returns: 100 List of benchmark results 101 """ 102 results = [] 103 104 for m, n, k in sizes: 105 print(f"\nBenchmarking M={m}, N={n}, K={k}...") 106 107 # Create random matrices 108 a = torch.randn(m, k, dtype=dtype, device=device) 109 b = torch.randn(k, n, dtype=dtype, device=device) 110 111 # PyTorch/cuBLAS baseline 112 print(" Running PyTorch/cuBLAS...") 113 torch_result = benchmark_torch_matmul(a, b, warmup, iters) 114 results.append( 115 BenchmarkResult( 116 name="PyTorch/cuBLAS", 117 size=(m, n, k), 118 latency_ms=torch_result.mean_ms, 119 std_ms=torch_result.std_ms, 120 tflops=get_tflops(m, n, k, torch_result.mean_ms), 121 ) 122 ) 123 124 # Triton 125 if TRITON_AVAILABLE: 126 print(" Running Triton GEMM...") 127 triton_result = benchmark_triton_matmul(a, b, warmup, iters) 128 results.append( 129 BenchmarkResult( 130 name="Triton GEMM", 131 size=(m, n, k), 132 latency_ms=triton_result.mean_ms, 133 std_ms=triton_result.std_ms, 134 tflops=get_tflops(m, n, k, triton_result.mean_ms), 135 ) 136 ) 137 138 # Clear GPU memory 139 del a, b 140 if device == "cuda": 141 torch.cuda.empty_cache() 142 143 return results 144 145 146 def print_results_table(results: list[BenchmarkResult]) -> None: 147 """Print results in a formatted table.""" 148 print("\n" + "=" * 90) 149 print("GEMM Benchmark Results") 150 print("=" * 90) 151 print( 152 f"{'Size (M, N, K)':<20} | {'Implementation':<15} | {'Latency (ms)':<12} | " 153 f"{'TFLOPS':<10} | {'vs cuBLAS':<10}" 154 ) 155 print("-" * 90) 156 157 # Group by size 158 sizes_seen = {} 159 for r in results: 160 size_str = str(r.size) 161 if size_str not in sizes_seen: 162 sizes_seen[size_str] = r.latency_ms # First is baseline 163 164 for r in results: 165 size_str = str(r.size) 166 baseline = sizes_seen.get(size_str, r.latency_ms) 167 speedup = baseline / r.latency_ms if r.name != "PyTorch/cuBLAS" else 1.0 168 169 print( 170 f"{str(r.size):<20} | {r.name:<15} | {r.latency_ms:>10.3f} ± {r.std_ms:.3f} | " 171 f"{r.tflops:>8.1f} | {speedup:>8.1%}" 172 ) 173 174 print("=" * 90) 175 176 177 def print_markdown_table(results: list[BenchmarkResult]) -> None: 178 """Print results as markdown table for README inclusion.""" 179 print("\n```markdown") 180 print("| Size (M, N, K) | Implementation | Latency (ms) | TFLOPS | Efficiency |") 181 print("|----------------|----------------|--------------|--------|------------|") 182 183 # Group by size 184 sizes_seen = {} 185 for r in results: 186 size_str = str(r.size) 187 if size_str not in sizes_seen: 188 sizes_seen[size_str] = r.latency_ms 189 190 for r in results: 191 baseline = sizes_seen.get(str(r.size), r.latency_ms) 192 efficiency = r.tflops / 312 * 100 if r.tflops > 0 else 0 # A100 peak ~312 TFLOPS FP16 193 194 print( 195 f"| {r.size} | {r.name} | {r.latency_ms:.3f} | {r.tflops:.1f} | {efficiency:.1f}% |" 196 ) 197 198 print("```") 199 200 201 def main() -> int: 202 """Main entry point.""" 203 parser = argparse.ArgumentParser(description="GEMM Benchmark Comparison") 204 parser.add_argument("--sizes", type=str, default="128,1024,4096", help="Matrix sizes to test (comma-separated)") 205 parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") 206 parser.add_argument("--iters", type=int, default=100, help="Benchmark iterations") 207 parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32"]) 208 parser.add_argument("--device", type=str, default="cuda") 209 parser.add_argument("--markdown", action="store_true", help="Output as markdown table") 210 args = parser.parse_args() 211 212 # Check CUDA availability 213 if args.device == "cuda" and not torch.cuda.is_available(): 214 print("CUDA not available, falling back to CPU") 215 args.device = "cpu" 216 217 # Parse sizes 218 size_values = [int(s) for s in args.sizes.split(",")] 219 # Generate square matrices for each size 220 sizes = [(s, s, s) for s in size_values] 221 222 # Get dtype 223 dtype = torch.float16 if args.dtype == "float16" else torch.float32 224 225 print("=" * 60) 226 print("Triton GEMM Benchmark") 227 print("=" * 60) 228 print(f"Device: {args.device}") 229 print(f"Dtype: {args.dtype}") 230 print(f"Sizes: {sizes}") 231 print(f"Warmup: {args.warmup}, Iterations: {args.iters}") 232 233 # Run benchmarks 234 results = run_benchmarks(sizes, args.warmup, args.iters, dtype, args.device) 235 236 # Print results 237 print_results_table(results) 238 239 if args.markdown: 240 print_markdown_table(results) 241 242 return 0 243 244 245 if __name__ == "__main__": 246 exit(main())