# Triton GPU Programming Basics

### What is Triton?

[Triton](https://triton-lang.org/) is a language and compiler for parallel programming on GPUs. It provides a Python-based DSL (domain-specific language) that compiles to efficient GPU kernels while abstracting away low-level details like thread management and shared memory allocation.

### Why Triton?

| Aspect | Traditional CUDA | Triton |
|--------|-----------------|--------|
| Language | C++/CUDA | Python DSL |
| Thread Management | Manual | Automatic |
| Memory Management | Manual | Automatic |
| Tensor Cores | wmma/mma intrinsics | tl.dot() |
| Learning Curve | Steep | Moderate |
| Performance | Highest | High |

### Programming Model

Triton operates at the **block level**, not the thread level. You write programs that operate on blocks of data, and the compiler handles:

1. **Thread mapping** - How threads map to data elements
2. **Shared memory** - Automatic allocation and management
3. **Memory coalescing** - Optimized memory access patterns
4. **Tensor Core usage** - Automatic detection and usage

```
┌─────────────────────────────────────────────────────────────────┐
│                    Triton Programming Model                     │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Your Code (Block Level)                                       │
│   ┌─────────────────────────────────────────────────────────┐  │
│   │  @triton.jit                                             │  │
│   │  def kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr): │  │
│   │      pid = tl.program_id(0)                              │  │
│   │      offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) │  │
│   │      x = tl.load(x_ptr + offs)                           │  │
│   │      y = x * 2  # Your computation                       │  │
│   │      tl.store(y_ptr + offs, y)                           │  │
│   └─────────────────────────────────────────────────────────┘  │
│                              │                                  │
│                              ▼                                  │
│   Compiler Handles (Automatic)                                  │
│   ┌─────────────────────────────────────────────────────────┐  │
│   │  • Thread indexing (tid, blockIdx, blockDim)             │  │
│   │  • Shared memory allocation                              │  │
│   │  • Memory coalescing                                     │  │
│   │  • Register allocation                                   │  │
│   │  • Instruction scheduling                                │  │
│   └─────────────────────────────────────────────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

### Key Concepts

#### 1. Program ID

Each kernel instance (program) has a unique ID:

```python
pid = tl.program_id(0)    # 1D grid

pid_x = tl.program_id(0)  # 2D grid, first axis
pid_y = tl.program_id(1)  # 2D grid, second axis
```
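The grid passed at launch determines how many program instances run; for example, `grid = (8,)` launches eight instances, and `tl.program_id(0)` returns a distinct value from 0 to 7 in each one.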
#### 2. Block Offsets

Calculate which data elements this block processes:

```python
BLOCK_SIZE = 128
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
```
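For example, with `pid = 2` and `BLOCK_SIZE = 128`, `offs` is the vector `[256, 257, ..., 383]`, so that program instance handles elements 256 through 383.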
#### 3. Load and Store

Memory operations with masking for boundary handling:

```python
# Load with mask
x = tl.load(ptr + offs, mask=offs < N, other=0.0)

# Store with mask
tl.store(ptr + offs, y, mask=offs < N)
```
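The mask matters for the last block whenever `N` is not a multiple of `BLOCK_SIZE`: lanes with `offs >= N` are skipped, and masked-off loads return `other` (here `0.0`) instead of reading out of bounds.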
#### 4. Compute

Standard Python operations work on blocks:

```python
y = x * 2 + 1  # Element-wise
y = tl.sum(x)  # Reduction
y = tl.dot(a, b)  # Matrix multiply (uses Tensor Cores)
```
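Note that `tl.dot` operates on two-dimensional blocks and generally requires each block dimension to be at least 16 to take the Tensor Core path.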
### Example: Vector Addition

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Add two vectors."""
    # Get program ID
    pid = tl.program_id(0)

    # Calculate offsets
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Load inputs (masked to stay in bounds)
    x = tl.load(x_ptr + offs, mask=offs < N)
    y = tl.load(y_ptr + offs, mask=offs < N)

    # Compute
    output = x + y

    # Store result
    tl.store(output_ptr + offs, output, mask=offs < N)

# Launch kernel
def add(x, y):
    N = x.shape[0]
    output = torch.empty_like(x)
    BLOCK_SIZE = 128
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](x, y, output, N, BLOCK_SIZE=BLOCK_SIZE)
    return output
```
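A quick sanity check of the wrapper, assuming PyTorch with a CUDA-capable GPU:

```python
x = torch.randn(10_000, device='cuda')
y = torch.randn(10_000, device='cuda')
out = add(x, y)
assert torch.allclose(out, x + y)  # matches the PyTorch reference
```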
### Example: Matrix Multiplication

```python
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    """Block-level GEMM: C = A @ B for row-major A (M, K) and B (K, N)."""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Block offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Initialize accumulator in float32 for accuracy
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Loop over K in BLOCK_K-sized tiles
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)

        # Load tiles, masking out-of-bounds elements to 0.0
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :],
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :],
                    mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)

        # Matrix multiply (uses Tensor Cores)
        acc += tl.dot(a, b)

    # Store result, masking the boundary
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc,
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```
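A minimal launcher sketch for this kernel; the block shapes below are illustrative defaults rather than tuned values, and the inputs are assumed to be contiguous and row-major:

```python
import torch

def matmul(a, b, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "inner dimensions must match"
    # Accumulation happens in float32, so the output is float32 as well
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    # One program instance per (BLOCK_M, BLOCK_N) output tile
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](a, b, c, M, N, K,
                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c
```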
### Autotuning

Triton supports automatic tuning of block sizes:

```python
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}, num_warps=4),
        triton.Config({'BLOCK_SIZE': 128}, num_warps=8),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=16),
    ],
    key=['N'],  # Cache the best config per distinct value of N
)
@triton.jit
def kernel(x_ptr, y_ptr, N, BLOCK_SIZE: tl.constexpr):
    ...
```

The decorated kernel must declare the key arguments (here `N`) and the tuned meta-parameters (here `BLOCK_SIZE`) in its signature; the pointer arguments above are illustrative.
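Because the autotuner chooses `BLOCK_SIZE`, the launch grid cannot hardcode it. Triton accepts a callable grid that receives the selected meta-parameters (a sketch, using the illustrative signature above):

```python
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
kernel[grid](x, y, N)  # BLOCK_SIZE is supplied by the autotuner, not the caller
```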
### Best Practices

1. **Use constexpr for block sizes** - Allows compiler optimization
2. **Mask boundary accesses** - Prevents out-of-bounds memory access
3. **Choose good block sizes** - Powers of 2, multiples of 32
4. **Use tl.dot for matmul** - Automatic Tensor Core usage
5. **Profile and autotune** - Different sizes may need different configs
### References

- [Triton Official Documentation](https://triton-lang.org/main/index.html)
- [Triton Tutorials](https://triton-lang.org/main/getting-started/tutorials/index.html)
- [Triton GitHub](https://github.com/openai/triton)