  1  """
  2  PyTorch Integration Example.
  3  
  4  Demonstrates how to use Triton kernels in PyTorch models.
  5  """
  6  
  7  from __future__ import annotations
  8  
  9  import time
 10  
 11  import torch
 12  import torch.nn as nn
 13  
 14  from triton_kernels.pytorch_integration import TritonLinear, TritonLayerNorm, triton_gemm
 15  
 16  
class SimpleTransformerBlock(nn.Module):
    """
    Simple Transformer block using Triton kernels.

    Demonstrates integration of Triton GEMM and LayerNorm
    into a real neural network architecture.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size = hidden_size

        # Triton-based layers
        self.attention_qkv = TritonLinear(hidden_size, 3 * hidden_size)
        self.attention_out = TritonLinear(hidden_size, hidden_size)
        self.mlp_up = TritonLinear(hidden_size, intermediate_size)
        self.mlp_down = TritonLinear(intermediate_size, hidden_size)
        self.layer_norm1 = TritonLayerNorm(hidden_size)
        self.layer_norm2 = TritonLayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through transformer block."""
        # Self-attention (simplified)
        residual = x
        x = self.layer_norm1(x)
        qkv = self.attention_qkv(x)
        # Full attention computation omitted for brevity; as a stand-in, only the
        # query slice of the fused QKV projection is fed to the output projection
        # (see the illustrative scaled_dot_product_attention_sketch helper below).
        x = self.attention_out(qkv[:, :, : self.hidden_size])
        x = x + residual

        # MLP
        residual = x
        x = self.layer_norm2(x)
        x = self.mlp_up(x)
        x = torch.nn.functional.gelu(x)
        x = self.mlp_down(x)
        x = x + residual

        return x
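

def scaled_dot_product_attention_sketch(qkv: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Illustrative sketch of the attention step elided in SimpleTransformerBlock.forward.

    This helper is not used by the block above; the function name, the single-head
    simplification, and the assumption of a fused QKV projection of width
    3 * hidden_size are illustrative choices, shown only to make the omitted step concrete.
    """
    # Split the fused projection into query, key, and value tensors.
    q, k, v = qkv.split(hidden_size, dim=-1)
    # Standard scaled dot-product attention (available in PyTorch >= 2.0).
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)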


def benchmark_triton_vs_torch():
    """Compare Triton GEMM vs PyTorch matmul performance."""
    print("=" * 60)
    print("Triton GEMM vs PyTorch matmul Benchmark")
    print("=" * 60)

    sizes = [(1024, 1024, 1024), (2048, 2048, 2048), (4096, 4096, 4096)]

    for m, n, k in sizes:
        print(f"\nSize: M={m}, N={n}, K={k}")

        a = torch.randn(m, k, device="cuda", dtype=torch.float16)
        b = torch.randn(k, n, device="cuda", dtype=torch.float16)

        # Warmup
        for _ in range(10):
            triton_gemm(a, b)
            torch.matmul(a, b)
        torch.cuda.synchronize()

        # Benchmark Triton
        start = time.perf_counter()
        for _ in range(100):
            c_triton = triton_gemm(a, b)
        torch.cuda.synchronize()
        triton_time = (time.perf_counter() - start) / 100 * 1000

        # Benchmark PyTorch
        start = time.perf_counter()
        for _ in range(100):
            c_torch = torch.matmul(a, b)
        torch.cuda.synchronize()
        torch_time = (time.perf_counter() - start) / 100 * 1000

        print(f"  Triton: {triton_time:.3f}ms")
        print(f"  PyTorch: {torch_time:.3f}ms")
        print(f"  Triton/PyTorch time ratio: {triton_time / torch_time:.2f}x")
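
        # Sanity checks: a GEMM performs 2*M*N*K FLOPs and the times above are average
        # milliseconds per call, so achieved throughput follows directly; also confirm
        # the two results agree numerically.
        triton_tflops = 2 * m * n * k / (triton_time * 1e-3) / 1e12
        torch_tflops = 2 * m * n * k / (torch_time * 1e-3) / 1e12
        print(f"  Throughput: Triton {triton_tflops:.1f} TFLOP/s, "
              f"PyTorch {torch_tflops:.1f} TFLOP/s")
        max_diff = (c_triton - c_torch).abs().max().item()
        print(f"  Max abs difference: {max_diff:.4f}")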


def demo_autograd():
    """Demonstrate autograd support."""
    print("\n" + "=" * 60)
    print("Autograd Support Demo")
    print("=" * 60)

    # Create tensors with gradients
    a = torch.randn(64, 128, requires_grad=True, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 256, requires_grad=True, device="cuda", dtype=torch.float16)

    # Forward pass
    c = triton_gemm(a, b)
    print(f"\nForward: C shape = {c.shape}")

    # Backward pass
    loss = c.sum()
    loss.backward()

    print(f"Backward: A.grad shape = {a.grad.shape}")
    print(f"Backward: B.grad shape = {b.grad.shape}")
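
    # Optional gradient sanity check, assuming triton_gemm's backward follows standard
    # matmul semantics: for loss = C.sum(), dL/dA = 1 @ B^T and dL/dB = A^T @ 1.
    with torch.no_grad():
        expected_a_grad = torch.ones_like(c) @ b.t()
        expected_b_grad = a.t() @ torch.ones_like(c)
    print(f"A.grad max abs error: {(a.grad - expected_a_grad).abs().max().item():.4f}")
    print(f"B.grad max abs error: {(b.grad - expected_b_grad).abs().max().item():.4f}")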


def demo_linear_layer():
    """Demonstrate TritonLinear layer."""
    print("\n" + "=" * 60)
    print("TritonLinear Layer Demo")
    print("=" * 60)

    layer = TritonLinear(768, 1024).cuda().half()
    x = torch.randn(32, 768, device="cuda", dtype=torch.float16)

    # Forward
    y = layer(x)
    print(f"\nInput shape: {x.shape}")
    print(f"Output shape: {y.shape}")

    # Backward
    loss = y.sum()
    loss.backward()
    print(f"Weight grad shape: {layer.weight.grad.shape}")
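
    # Optional comparison against torch.nn.functional.linear. This assumes TritonLinear
    # stores nn.Linear-style parameters (weight of shape [out_features, in_features] and
    # an optional bias); the shape guard skips the check if the layout differs.
    if layer.weight.shape == (1024, 768):
        with torch.no_grad():
            y_ref = torch.nn.functional.linear(x, layer.weight, getattr(layer, "bias", None))
        print(f"Max abs diff vs F.linear: {(y - y_ref).abs().max().item():.4f}")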


def demo_transformer_block():
    """Demonstrate Transformer block with Triton kernels."""
    print("\n" + "=" * 60)
    print("Transformer Block Demo")
    print("=" * 60)

    block = SimpleTransformerBlock(hidden_size=768, intermediate_size=3072).cuda().half()
    x = torch.randn(32, 128, 768, device="cuda", dtype=torch.float16)

    # Forward
    y = block(x)
    print(f"\nInput shape: {x.shape}")
    print(f"Output shape: {y.shape}")

    # Count parameters
    total_params = sum(p.numel() for p in block.parameters())
    print(f"Total parameters: {total_params:,}")
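
    # Optional: check that gradients flow through every Triton-backed layer.
    # Summing in float32 avoids fp16 overflow in the reduction.
    y.float().sum().backward()
    grads_ok = all(p.grad is not None for p in block.parameters())
    print(f"All parameter gradients populated: {grads_ok}")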


def demo_torch_compile():
    """Demonstrate torch.compile integration."""
    print("\n" + "=" * 60)
    print("torch.compile Integration Demo")
    print("=" * 60)

    try:
        # Define a simple model
        def model(x, weight):
            return triton_gemm(x, weight)

        # Compile the model
        compiled_model = torch.compile(model, mode="reduce-overhead")

        a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
        b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

        # Run compiled model
        y = compiled_model(a, b)
        print(f"\nCompiled model output shape: {y.shape}")
        print("torch.compile integration successful!")

    except Exception as e:
        print(f"\ntorch.compile not fully supported yet: {e}")
        print("This is expected for Triton kernels in some PyTorch versions.")
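        # Note: recent PyTorch releases can register user-defined Triton kernels as
        # custom operators (e.g. via torch.library) so torch.compile can handle them
        # without graph breaks; whether that is needed depends on how triton_gemm is wrapped.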


def main():
    """Run all demos."""
    print("=" * 60)
    print("Triton PyTorch Integration Examples")
    print("=" * 60)

    # Check CUDA
    if not torch.cuda.is_available():
        print("CUDA not available. Exiting.")
        return

    # Run demos
    demo_autograd()
    demo_linear_layer()
    demo_transformer_block()
    demo_torch_compile()
    benchmark_triton_vs_torch()

    print("\n" + "=" * 60)
    print("All demos completed!")
    print("=" * 60)


if __name__ == "__main__":
    main()