quick_start.py
1 """ 2 Triton Quick Start Example. 3 4 A simple introduction to Triton GPU programming. 5 This example demonstrates: 6 1. Writing a Triton kernel 7 2. Launching the kernel 8 3. Verifying correctness 9 """ 10 11 from __future__ import annotations 12 13 import torch 14 import triton 15 import triton.language as tl 16 17 18 # ============================================================================= 19 # Example 1: Vector Addition (Hello World of GPU Programming) 20 # ============================================================================= 21 22 @triton.jit 23 def vector_add_kernel( 24 x_ptr, # Pointer to first input vector 25 y_ptr, # Pointer to second input vector 26 output_ptr, # Pointer to output vector 27 N: tl.constexpr, # Number of elements 28 BLOCK_SIZE: tl.constexpr, # Block size (compile-time constant) 29 ): 30 """ 31 Add two vectors element-wise: output[i] = x[i] + y[i] 32 33 This is the "Hello World" of GPU programming. 34 35 Key concepts: 36 - tl.program_id(): Get the ID of this program (block) 37 - tl.arange(): Generate a range of indices 38 - tl.load()/tl.store(): Read/write memory 39 - mask=: Handle boundary conditions 40 """ 41 # Step 1: Get the program ID (which block are we?) 42 pid = tl.program_id(0) # 0 because we have a 1D grid 43 44 # Step 2: Calculate which elements this block processes 45 block_start = pid * BLOCK_SIZE 46 offsets = block_start + tl.arange(0, BLOCK_SIZE) 47 48 # Step 3: Create a mask for boundary handling 49 # (in case N is not a multiple of BLOCK_SIZE) 50 mask = offsets < N 51 52 # Step 4: Load input elements 53 x = tl.load(x_ptr + offsets, mask=mask, other=0.0) 54 y = tl.load(y_ptr + offsets, mask=mask, other=0.0) 55 56 # Step 5: Compute the result 57 output = x + y 58 59 # Step 6: Store the result 60 tl.store(output_ptr + offsets, output, mask=mask) 61 62 63 def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 64 """ 65 Add two vectors using Triton. 66 67 Args: 68 x: First input vector 69 y: Second input vector 70 71 Returns: 72 Sum of x and y 73 """ 74 # Validate inputs 75 assert x.shape == y.shape 76 N = x.numel() 77 78 # Allocate output 79 output = torch.empty_like(x) 80 81 # Choose block size 82 BLOCK_SIZE = 1024 83 84 # Calculate grid size (number of blocks needed) 85 grid = (triton.cdiv(N, BLOCK_SIZE),) 86 87 # Launch the kernel 88 # The [] syntax is how we pass the grid dimensions 89 vector_add_kernel[grid]( 90 x, y, output, 91 N=N, 92 BLOCK_SIZE=BLOCK_SIZE, 93 ) 94 95 return output 96 97 98 # ============================================================================= 99 # Example 2: Vector Scaling (Scalar Multiplication) 100 # ============================================================================= 101 102 @triton.jit 103 def vector_scale_kernel( 104 x_ptr, 105 output_ptr, 106 alpha, # Scalar multiplier 107 N: tl.constexpr, 108 BLOCK_SIZE: tl.constexpr, 109 ): 110 """Scale a vector: output[i] = alpha * x[i]""" 111 pid = tl.program_id(0) 112 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 113 mask = offsets < N 114 115 x = tl.load(x_ptr + offsets, mask=mask, other=0.0) 116 output = alpha * x 117 tl.store(output_ptr + offsets, output, mask=mask) 118 119 120 def vector_scale(x: torch.Tensor, alpha: float) -> torch.Tensor: 121 """Scale a vector by a scalar.""" 122 N = x.numel() 123 output = torch.empty_like(x) 124 BLOCK_SIZE = 1024 125 grid = (triton.cdiv(N, BLOCK_SIZE),) 126 127 vector_scale_kernel[grid](x, output, alpha, N, BLOCK_SIZE) 128 return output 129 130 131 # ============================================================================= 132 # Example 3: Element-wise Operations 133 # ============================================================================= 134 135 @triton.jit 136 def elementwise_kernel( 137 x_ptr, 138 output_ptr, 139 N: tl.constexpr, 140 BLOCK_SIZE: tl.constexpr, 141 OP: tl.constexpr, # Operation type (0=square, 1=sqrt, 2=exp, 3=relu) 142 ): 143 """ 144 Apply various element-wise operations. 145 146 Demonstrates using tl.constexpr for compile-time branching. 147 """ 148 pid = tl.program_id(0) 149 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 150 mask = offsets < N 151 152 x = tl.load(x_ptr + offsets, mask=mask, other=0.0) 153 154 # Apply operation based on OP 155 if OP == 0: # Square 156 output = x * x 157 elif OP == 1: # Square root 158 output = tl.sqrt(tl.abs(x)) # abs to handle negative numbers 159 elif OP == 2: # Exponential 160 output = tl.exp(x) 161 elif OP == 3: # ReLU 162 output = tl.where(x > 0, x, 0.0) 163 else: 164 output = x 165 166 tl.store(output_ptr + offsets, output, mask=mask) 167 168 169 def square(x: torch.Tensor) -> torch.Tensor: 170 """Square each element.""" 171 output = torch.empty_like(x) 172 grid = (triton.cdiv(x.numel(), 1024),) 173 elementwise_kernel[grid](x, output, x.numel(), 1024, OP=0) 174 return output 175 176 177 def relu(x: torch.Tensor) -> torch.Tensor: 178 """Apply ReLU activation.""" 179 output = torch.empty_like(x) 180 grid = (triton.cdiv(x.numel(), 1024),) 181 elementwise_kernel[grid](x, output, x.numel(), 1024, OP=3) 182 return output 183 184 185 # ============================================================================= 186 # Main: Run Examples 187 # ============================================================================= 188 189 def main(): 190 """Run all quick start examples.""" 191 print("=" * 60) 192 print("Triton Quick Start Examples") 193 print("=" * 60) 194 195 # Check for CUDA 196 if not torch.cuda.is_available(): 197 print("\nCUDA not available. Some examples will be skipped.") 198 device = "cpu" 199 else: 200 device = "cuda" 201 202 print(f"\nUsing device: {device}") 203 204 # ------------------------------------------------------------------------- 205 # Example 1: Vector Addition 206 # ------------------------------------------------------------------------- 207 print("\n" + "-" * 40) 208 print("Example 1: Vector Addition") 209 print("-" * 40) 210 211 N = 10000 212 x = torch.randn(N, device=device, dtype=torch.float32) 213 y = torch.randn(N, device=device, dtype=torch.float32) 214 215 # Triton result 216 result_triton = vector_add(x, y) 217 218 # PyTorch reference 219 result_torch = x + y 220 221 # Verify 222 max_diff = (result_triton - result_torch).abs().max().item() 223 print(f"Vector size: {N}") 224 print(f"Max difference: {max_diff:.6f}") 225 print(f"Correct: {max_diff < 1e-5}") 226 227 # ------------------------------------------------------------------------- 228 # Example 2: Vector Scaling 229 # ------------------------------------------------------------------------- 230 print("\n" + "-" * 40) 231 print("Example 2: Vector Scaling") 232 print("-" * 40) 233 234 alpha = 2.5 235 result_triton = vector_scale(x, alpha) 236 result_torch = alpha * x 237 238 max_diff = (result_triton - result_torch).abs().max().item() 239 print(f"Scale factor: {alpha}") 240 print(f"Max difference: {max_diff:.6f}") 241 print(f"Correct: {max_diff < 1e-5}") 242 243 # ------------------------------------------------------------------------- 244 # Example 3: Element-wise Operations 245 # ------------------------------------------------------------------------- 246 print("\n" + "-" * 40) 247 print("Example 3: Element-wise Operations") 248 print("-" * 40) 249 250 x = torch.randn(1000, device=device, dtype=torch.float32) 251 252 # Square 253 result_triton = square(x) 254 result_torch = x * x 255 print(f"Square - Correct: {torch.allclose(result_triton, result_torch)}") 256 257 # ReLU 258 result_triton = relu(x) 259 result_torch = torch.nn.functional.relu(x) 260 print(f"ReLU - Correct: {torch.allclose(result_triton, result_torch)}") 261 262 # ------------------------------------------------------------------------- 263 # Summary 264 # ------------------------------------------------------------------------- 265 print("\n" + "=" * 60) 266 print("Quick Start Complete!") 267 print("=" * 60) 268 print(""" 269 Key Takeaways: 270 1. @triton.jit marks a function as a Triton kernel 271 2. tl.program_id() gets the block index 272 3. tl.load()/tl.store() handle memory operations 273 4. mask= handles boundary conditions 274 5. BLOCK_SIZE as tl.constexpr enables compiler optimization 275 276 Next Steps: 277 - See gemm.py for matrix multiplication 278 - See flash_attention.py for advanced kernels 279 - See pytorch_integration.py for neural network integration 280 """) 281 282 283 if __name__ == "__main__": 284 main()