integrate_pytorch.py
"""
PyTorch Integration Example.

Demonstrates how to use Triton kernels in PyTorch models.
"""

from __future__ import annotations

import time

import torch
import torch.nn as nn

from triton_kernels.pytorch_integration import TritonLinear, TritonLayerNorm, triton_gemm


class SimpleTransformerBlock(nn.Module):
    """
    Simple Transformer block using Triton kernels.

    Demonstrates integration of Triton GEMM and LayerNorm
    into a real neural network architecture.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Triton-based layers
        self.attention_qkv = TritonLinear(hidden_size, 3 * hidden_size)
        self.attention_out = TritonLinear(hidden_size, hidden_size)
        self.mlp_up = TritonLinear(hidden_size, intermediate_size)
        self.mlp_down = TritonLinear(intermediate_size, hidden_size)
        self.layer_norm1 = TritonLayerNorm(hidden_size)
        self.layer_norm2 = TritonLayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through transformer block."""
        # Self-attention (simplified)
        residual = x
        x = self.layer_norm1(x)
        qkv = self.attention_qkv(x)
        # ... attention computation ...
        x = self.attention_out(qkv[:, :, : self.hidden_size])
        x = x + residual

        # MLP
        residual = x
        x = self.layer_norm2(x)
        x = self.mlp_up(x)
        x = torch.nn.functional.gelu(x)
        x = self.mlp_down(x)
        x = x + residual

        return x


def benchmark_triton_vs_torch():
    """Compare Triton GEMM vs PyTorch matmul performance."""
    print("=" * 60)
    print("Triton GEMM vs PyTorch matmul Benchmark")
    print("=" * 60)

    sizes = [(1024, 1024, 1024), (2048, 2048, 2048), (4096, 4096, 4096)]

    for m, n, k in sizes:
        print(f"\nSize: M={m}, N={n}, K={k}")

        a = torch.randn(m, k, device="cuda", dtype=torch.float16)
        b = torch.randn(k, n, device="cuda", dtype=torch.float16)

        # Warmup
        for _ in range(10):
            triton_gemm(a, b)
            torch.matmul(a, b)
        torch.cuda.synchronize()

        # Benchmark Triton
        start = time.perf_counter()
        for _ in range(100):
            c_triton = triton_gemm(a, b)
        torch.cuda.synchronize()
        triton_time = (time.perf_counter() - start) / 100 * 1000

        # Benchmark PyTorch
        start = time.perf_counter()
        for _ in range(100):
            c_torch = torch.matmul(a, b)
        torch.cuda.synchronize()
        torch_time = (time.perf_counter() - start) / 100 * 1000

        print(f"  Triton:  {triton_time:.3f} ms")
        print(f"  PyTorch: {torch_time:.3f} ms")
        print(f"  Ratio:   {triton_time / torch_time:.2f}x")


def demo_autograd():
    """Demonstrate autograd support."""
    print("\n" + "=" * 60)
    print("Autograd Support Demo")
    print("=" * 60)

    # Create tensors with gradients
    a = torch.randn(64, 128, requires_grad=True, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 256, requires_grad=True, device="cuda", dtype=torch.float16)

    # Forward pass
    c = triton_gemm(a, b)
    print(f"\nForward: C shape = {c.shape}")

    # Backward pass
    loss = c.sum()
    loss.backward()

    print(f"Backward: A.grad shape = {a.grad.shape}")
    print(f"Backward: B.grad shape = {b.grad.shape}")


def demo_linear_layer():
    """Demonstrate TritonLinear layer."""
    print("\n" + "=" * 60)
    print("TritonLinear Layer Demo")
    print("=" * 60)

    layer = TritonLinear(768, 1024).cuda().half()
    x = torch.randn(32, 768, device="cuda", dtype=torch.float16)

    # Forward
    y = layer(x)
    print(f"\nInput shape: {x.shape}")
    print(f"Output shape: {y.shape}")

    # Backward
    loss = y.sum()
    loss.backward()
    print(f"Weight grad shape: {layer.weight.grad.shape}")


def demo_transformer_block():
    """Demonstrate Transformer block with Triton kernels."""
    print("\n" + "=" * 60)
    print("Transformer Block Demo")
    print("=" * 60)

    block = SimpleTransformerBlock(hidden_size=768, intermediate_size=3072).cuda().half()
    x = torch.randn(32, 128, 768, device="cuda", dtype=torch.float16)

    # Forward
    y = block(x)
    print(f"\nInput shape: {x.shape}")
    print(f"Output shape: {y.shape}")

    # Count parameters
    total_params = sum(p.numel() for p in block.parameters())
    print(f"Total parameters: {total_params:,}")


def demo_torch_compile():
    """Demonstrate torch.compile integration."""
    print("\n" + "=" * 60)
    print("torch.compile Integration Demo")
    print("=" * 60)

    try:
        # Define a simple model
        def model(x, weight):
            return triton_gemm(x, weight)

        # Compile the model
        compiled_model = torch.compile(model, mode="reduce-overhead")

        a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
        b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

        # Run compiled model
        y = compiled_model(a, b)
        print(f"\nCompiled model output shape: {y.shape}")
        print("torch.compile integration successful!")

    except Exception as e:
        print(f"\ntorch.compile not fully supported yet: {e}")
        print("This is expected for Triton kernels in some PyTorch versions.")


def main():
    """Run all demos."""
    print("=" * 60)
    print("Triton PyTorch Integration Examples")
    print("=" * 60)

    # Check CUDA
    if not torch.cuda.is_available():
        print("CUDA not available. Exiting.")
        return

    # Run demos
    demo_autograd()
    demo_linear_layer()
    demo_transformer_block()
    demo_torch_compile()
    benchmark_triton_vs_torch()

    print("\n" + "=" * 60)
    print("All demos completed!")
    print("=" * 60)


if __name__ == "__main__":
    main()
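

# ---------------------------------------------------------------------------
# Optional reference sketch (not called from main()). The benchmark above
# measures latency with time.perf_counter() plus torch.cuda.synchronize();
# a common alternative is to time on the device itself with torch.cuda.Event,
# which excludes Python launch overhead. This is a minimal sketch that assumes
# only the triton_gemm import above; the helper name is illustrative.
# ---------------------------------------------------------------------------
def time_gemm_with_cuda_events(m: int = 1024, n: int = 1024, k: int = 1024, iters: int = 100) -> float:
    """Return average triton_gemm latency in milliseconds, measured with CUDA events."""
    a = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b = torch.randn(k, n, device="cuda", dtype=torch.float16)

    # Warmup so autotuning/compilation is not included in the measurement.
    for _ in range(10):
        triton_gemm(a, b)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    start_evt.record()
    for _ in range(iters):
        triton_gemm(a, b)
    end_evt.record()
    torch.cuda.synchronize()

    # elapsed_time() reports the milliseconds between the two recorded events.
    return start_evt.elapsed_time(end_evt) / iters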