# Triton GPU Programming Basics

### What is Triton?

[Triton](https://triton-lang.org/) is a language and compiler for parallel programming on GPUs. It provides a Python-based DSL (Domain-Specific Language) that compiles to efficient GPU kernels while abstracting away low-level details such as thread management and shared-memory allocation.

### Why Triton?

| Aspect | Traditional CUDA | Triton |
|--------|-----------------|--------|
| Language | C++/CUDA | Python DSL |
| Thread Management | Manual | Automatic |
| Memory Management | Manual | Automatic |
| Tensor Cores | wmma/mma intrinsics | tl.dot() |
| Learning Curve | Steep | Moderate |
| Performance | Highest | High |

### Programming Model

Triton operates at the **block level**, not the thread level. You write programs that operate on blocks of data, and the compiler handles:

1. **Thread mapping** - how threads map to data elements
2. **Shared memory** - automatic allocation and management
3. **Memory coalescing** - optimized memory access patterns
4. **Tensor Core usage** - automatic detection and use

```
┌─────────────────────────────────────────────────────────────────┐
│                    Triton Programming Model                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Your Code (Block Level)                                       │
│   ┌────────────────────────────────────────────────────────┐   │
│   │ @triton.jit                                            │   │
│   │ def kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr): │   │
│   │     pid = tl.program_id(0)                             │   │
│   │     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) │   │
│   │     x = tl.load(x_ptr + offs)                          │   │
│   │     y = x * 2  # Your computation                      │   │
│   │     tl.store(y_ptr + offs, y)                          │   │
│   └────────────────────────────────────────────────────────┘   │
│                              │                                  │
│                              ▼                                  │
│   Compiler Handles (Automatic)                                  │
│   ┌────────────────────────────────────────────────────────┐   │
│   │ • Thread indexing (tid, blockIdx, blockDim)            │   │
│   │ • Shared memory allocation                             │   │
│   │ • Memory coalescing                                    │   │
│   │ • Register allocation                                  │   │
│   │ • Instruction scheduling                               │   │
│   └────────────────────────────────────────────────────────┘   │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

### Key Concepts

#### 1. Program ID

Each kernel instance (a "program") has a unique ID along each axis of the launch grid:

```python
pid = tl.program_id(0)    # 1D grid

pid_x = tl.program_id(0)  # 2D grid, axis 0
pid_y = tl.program_id(1)  # 2D grid, axis 1
```

#### 2. Block Offsets

Calculate which data elements this block processes:

```python
BLOCK_SIZE = 128
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
```

#### 3. Load and Store

Memory operations take a mask for boundary handling; `other` supplies the value read for masked-off elements:

```python
# Load with mask (out-of-bounds elements read as 0.0)
x = tl.load(ptr + offs, mask=offs < N, other=0.0)

# Store with mask (out-of-bounds elements are not written)
tl.store(ptr + offs, y, mask=offs < N)
```
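Masking matters most when a load feeds a reduction: with `other=0.0`, out-of-bounds lanes contribute nothing to the result. The block-sum kernel below is a minimal sketch of that pattern; the `sum_kernel` and `block_sums` names are illustrative (not Triton API), and `tl.sum` is covered in the next subsection.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def sum_kernel(x_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # other=0.0 makes out-of-bounds lanes contribute nothing to the sum
    x = tl.load(x_ptr + offs, mask=offs < N, other=0.0)
    partial = tl.sum(x, axis=0)       # reduce the block to a scalar
    tl.store(out_ptr + pid, partial)  # one partial sum per program

def block_sums(x):
    N = x.shape[0]
    BLOCK_SIZE = 128
    num_blocks = triton.cdiv(N, BLOCK_SIZE)
    out = torch.empty(num_blocks, device=x.device, dtype=x.dtype)
    sum_kernel[(num_blocks,)](x, out, N, BLOCK_SIZE)
    return out.sum()  # combine the per-block partials on the host
```

The two-stage structure (per-block partial sums, then a host-side combine) is one simple way to reduce across blocks, not the only one.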
#### 4. Compute

Standard Python operators work element-wise on blocks, and `tl` provides block-level primitives:

```python
y = x * 2 + 1     # Element-wise
y = tl.sum(x)     # Reduction
y = tl.dot(a, b)  # Matrix multiply (uses Tensor Cores)
```

### Example: Vector Addition

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Add two vectors."""
    # Get program ID
    pid = tl.program_id(0)

    # Calculate offsets
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Load inputs
    x = tl.load(x_ptr + offs, mask=offs < N)
    y = tl.load(y_ptr + offs, mask=offs < N)

    # Compute
    output = x + y

    # Store result
    tl.store(output_ptr + offs, output, mask=offs < N)

# Launch kernel
def add(x, y):
    N = x.shape[0]
    output = torch.empty_like(x)
    BLOCK_SIZE = 128
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](x, y, output, N, BLOCK_SIZE)
    return output
```

### Example: Matrix Multiplication

```python
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Block-level GEMM for row-major A (M x K) and B (K x N).

    Loads are unmasked, so M, N, K must be multiples of the block sizes.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Block offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Initialize accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)

        # Load tiles
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])

        # Matrix multiply (uses Tensor Cores)
        acc += tl.dot(a, b)

    # Store result
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)
```
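The GEMM kernel above ships without a host-side launcher. Below is a minimal sketch of one, under the kernel's own assumptions (row-major contiguous tensors, dimensions divisible by the block sizes); the `matmul` wrapper name and the chosen block sizes are illustrative, and the final lines show a quick correctness check plus timing with `triton.testing.do_bench`.

```python
import torch
import triton
from triton.testing import do_bench

def matmul(a, b):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32
    # The kernel loads without masks, so dimensions must divide evenly
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    # 2D grid: one program per (BLOCK_M x BLOCK_N) output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](a, b, c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K)
    return c

# Correctness check against PyTorch (loose tolerance: tl.dot may use TF32)
a = torch.randn(256, 256, device='cuda')
b = torch.randn(256, 256, device='cuda')
torch.testing.assert_close(matmul(a, b), a @ b, rtol=1e-2, atol=1e-1)

# Timing with Triton's benchmark helper
print(f"{do_bench(lambda: matmul(a, b)):.3f} ms")
```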
程序 ID 231 232 每个内核实例(程序)有唯一 ID: 233 234 ```python 235 pid = tl.program_id(0) # 1D 网格 236 pid_x = tl.program_id(0) # 2D 网格 237 pid_y = tl.program_id(1) 238 ``` 239 240 #### 2. 块偏移 241 242 计算此块处理的数据元素: 243 244 ```python 245 BLOCK_SIZE = 128 246 offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 247 ``` 248 249 #### 3. 加载和存储 250 251 带边界处理掩码的内存操作: 252 253 ```python 254 # 带掩码加载 255 x = tl.load(ptr + offs, mask=offs < N, other=0.0) 256 257 # 带掩码存储 258 tl.store(ptr + offs, y, mask=offs < N) 259 ``` 260 261 #### 4. 计算 262 263 标准 Python 操作作用于块: 264 265 ```python 266 y = x * 2 + 1 # 逐元素 267 y = tl.sum(x) # 归约 268 y = tl.dot(a, b) # 矩阵乘法(使用 Tensor Cores) 269 ``` 270 271 ### 最佳实践 272 273 1. **使用 constexpr 声明块大小** - 允许编译器优化 274 2. **掩码边界访问** - 防止越界内存访问 275 3. **选择好的块大小** - 2 的幂次,32 的倍数 276 4. **使用 tl.dot 进行矩阵乘法** - 自动使用 Tensor Cores 277 5. **性能分析和自动调优** - 不同大小可能需要不同配置 278 279 ### 参考资料 280 281 - [Triton 官方文档](https://triton-lang.org/main/index.html) 282 - [Triton 教程](https://triton-lang.org/main/getting-started/tutorials/index.html) 283 - [Triton GitHub](https://github.com/openai/triton)