  1  """
  2  Triton Quick Start Example.
  3  
  4  A simple introduction to Triton GPU programming.
  5  This example demonstrates:
  6  1. Writing a Triton kernel
  7  2. Launching the kernel
  8  3. Verifying correctness
  9  """
 10  
 11  from __future__ import annotations
 12  
 13  import torch
 14  import triton
 15  import triton.language as tl
 16  
 17  
 18  # =============================================================================
 19  # Example 1: Vector Addition (Hello World of GPU Programming)
 20  # =============================================================================
 21  
 22  @triton.jit
 23  def vector_add_kernel(
 24      x_ptr,  # Pointer to first input vector
 25      y_ptr,  # Pointer to second input vector
 26      output_ptr,  # Pointer to output vector
 27      N: tl.constexpr,  # Number of elements
 28      BLOCK_SIZE: tl.constexpr,  # Block size (compile-time constant)
 29  ):
 30      """
 31      Add two vectors element-wise: output[i] = x[i] + y[i]
 32  
 33      This is the "Hello World" of GPU programming.
 34  
 35      Key concepts:
 36      - tl.program_id(): Get the ID of this program (block)
 37      - tl.arange(): Generate a range of indices
 38      - tl.load()/tl.store(): Read/write memory
 39      - mask=: Handle boundary conditions
 40      """
 41      # Step 1: Get the program ID (which block are we?)
 42      pid = tl.program_id(0)  # 0 because we have a 1D grid
 43  
 44      # Step 2: Calculate which elements this block processes
 45      block_start = pid * BLOCK_SIZE
 46      offsets = block_start + tl.arange(0, BLOCK_SIZE)
 47  
 48      # Step 3: Create a mask for boundary handling
 49      # (in case N is not a multiple of BLOCK_SIZE)
 50      mask = offsets < N
 51  
 52      # Step 4: Load input elements
 53      x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
 54      y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
 55  
 56      # Step 5: Compute the result
 57      output = x + y
 58  
 59      # Step 6: Store the result
 60      tl.store(output_ptr + offsets, output, mask=mask)
 61  
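# Worked example of the index math above, assuming BLOCK_SIZE = 4 and N = 10:
# the program with pid = 2 computes block_start = 8, offsets = [8, 9, 10, 11],
# and mask = [True, True, False, False], so the last two lanes are neither
# loaded nor stored.
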

def vector_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Add two vectors using Triton.

    Args:
        x: First input vector
        y: Second input vector

    Returns:
        Sum of x and y
    """
    # Validate inputs
    assert x.shape == y.shape, "input shapes must match"
    N = x.numel()

    # Allocate output
    output = torch.empty_like(x)

    # Choose block size
    BLOCK_SIZE = 1024

    # Calculate grid size (number of blocks needed)
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    # Launch the kernel
    # The kernel[grid](...) syntax passes the launch-grid dimensions
    vector_add_kernel[grid](
        x, y, output,
        N=N,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output

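
# A minimal sketch of the more idiomatic launch form: the grid may also be a
# callable that receives the kernel's meta-parameters, so BLOCK_SIZE is written
# only once. The name vector_add_lambda_grid is illustrative, not part of the
# original example.
def vector_add_lambda_grid(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.shape == y.shape, "input shapes must match"
    N = x.numel()
    output = torch.empty_like(x)
    # Triton invokes this callable with the resolved meta-parameters to size
    # the grid, so changing BLOCK_SIZE at the call site updates both.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    vector_add_kernel[grid](x, y, output, N=N, BLOCK_SIZE=1024)
    return output
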

# =============================================================================
# Example 2: Vector Scaling (Scalar Multiplication)
# =============================================================================

@triton.jit
def vector_scale_kernel(
    x_ptr,
    output_ptr,
    alpha,  # Scalar multiplier
    N,  # Number of elements (runtime argument)
    BLOCK_SIZE: tl.constexpr,
):
    """Scale a vector: output[i] = alpha * x[i]"""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    output = alpha * x
    tl.store(output_ptr + offsets, output, mask=mask)


def vector_scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    """Scale a vector by a scalar."""
    N = x.numel()
    output = torch.empty_like(x)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    vector_scale_kernel[grid](x, output, alpha, N, BLOCK_SIZE=BLOCK_SIZE)
    return output

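
# A hedged sketch of Triton's autotuner: rather than hard-coding BLOCK_SIZE,
# @triton.autotune benchmarks the candidate configs on first launch and caches
# the fastest one per value of the key. The config list below is illustrative,
# not a tuned set, and vector_scale_autotuned_kernel is not part of the
# original example.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 4096}),
    ],
    key=["N"],  # re-tune whenever N takes a new value
)
@triton.jit
def vector_scale_autotuned_kernel(
    x_ptr,
    output_ptr,
    alpha,
    N,
    BLOCK_SIZE: tl.constexpr,  # supplied by the autotuner, not by the caller
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(output_ptr + offsets, alpha * x, mask=mask)


def vector_scale_autotuned(x: torch.Tensor, alpha: float) -> torch.Tensor:
    N = x.numel()
    output = torch.empty_like(x)
    # BLOCK_SIZE is omitted at the call site: the autotuner injects the winner.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    vector_scale_autotuned_kernel[grid](x, output, alpha, N)
    return output
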

# =============================================================================
# Example 3: Element-wise Operations
# =============================================================================

@triton.jit
def elementwise_kernel(
    x_ptr,
    output_ptr,
    N,  # Number of elements (runtime argument)
    BLOCK_SIZE: tl.constexpr,
    OP: tl.constexpr,  # Operation type (0=square, 1=sqrt, 2=exp, 3=relu)
):
    """
    Apply various element-wise operations.

    Demonstrates using tl.constexpr for compile-time branching: each OP value
    compiles to its own specialized kernel with the branch folded away.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    # Apply operation based on OP (resolved at compile time)
    if OP == 0:  # Square
        output = x * x
    elif OP == 1:  # Square root
        output = tl.sqrt(tl.abs(x))  # abs to handle negative inputs
    elif OP == 2:  # Exponential
        output = tl.exp(x)
    elif OP == 3:  # ReLU
        output = tl.where(x > 0, x, 0.0)
    else:
        output = x

    tl.store(output_ptr + offsets, output, mask=mask)

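
# A small convenience sketch: _OPS and apply_op are illustrative additions
# (not part of the original example) that map an operation name to the OP
# constexpr, so each name dispatches to its own specialized kernel variant.
_OPS = {"square": 0, "sqrt": 1, "exp": 2, "relu": 3}


def apply_op(x: torch.Tensor, op: str, block_size: int = 1024) -> torch.Tensor:
    """Apply a named element-wise operation via the shared kernel."""
    output = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), block_size),)
    elementwise_kernel[grid](x, output, x.numel(), block_size, OP=_OPS[op])
    return output
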

def square(x: torch.Tensor) -> torch.Tensor:
    """Square each element."""
    output = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    elementwise_kernel[grid](x, output, x.numel(), 1024, OP=0)
    return output


def relu(x: torch.Tensor) -> torch.Tensor:
    """Apply ReLU activation."""
    output = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    elementwise_kernel[grid](x, output, x.numel(), 1024, OP=3)
    return output

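
# A hedged micro-benchmark sketch using triton.testing.do_bench (returns times
# in milliseconds; warmup/rep defaults vary across Triton versions, so treat
# the numbers as indicative). benchmark_square is an illustrative helper, not
# part of the original example, and assumes a CUDA device is present.
def benchmark_square(n: int = 1 << 20) -> None:
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    t_triton = triton.testing.do_bench(lambda: square(x))
    t_torch = triton.testing.do_bench(lambda: x * x)
    print(f"square ({n} elements): triton {t_triton:.4f} ms, torch {t_torch:.4f} ms")
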

# =============================================================================
# Main: Run Examples
# =============================================================================

def main():
    """Run all quick start examples."""
    print("=" * 60)
    print("Triton Quick Start Examples")
    print("=" * 60)

    # Check for CUDA: Triton kernels are launched on the GPU, so there is
    # nothing to run without one.
    if not torch.cuda.is_available():
        print("\nCUDA not available. Triton kernels require a GPU; skipping all examples.")
        return
    device = "cuda"

    print(f"\nUsing device: {device}")

    # -------------------------------------------------------------------------
    # Example 1: Vector Addition
    # -------------------------------------------------------------------------
    print("\n" + "-" * 40)
    print("Example 1: Vector Addition")
    print("-" * 40)

    N = 10000
    x = torch.randn(N, device=device, dtype=torch.float32)
    y = torch.randn(N, device=device, dtype=torch.float32)

    # Triton result
    result_triton = vector_add(x, y)

    # PyTorch reference
    result_torch = x + y

    # Verify
    max_diff = (result_triton - result_torch).abs().max().item()
    print(f"Vector size: {N}")
    print(f"Max difference: {max_diff:.6f}")
    print(f"Correct: {max_diff < 1e-5}")

    # -------------------------------------------------------------------------
    # Example 2: Vector Scaling
    # -------------------------------------------------------------------------
    print("\n" + "-" * 40)
    print("Example 2: Vector Scaling")
    print("-" * 40)

    alpha = 2.5
    result_triton = vector_scale(x, alpha)
    result_torch = alpha * x

    max_diff = (result_triton - result_torch).abs().max().item()
    print(f"Scale factor: {alpha}")
    print(f"Max difference: {max_diff:.6f}")
    print(f"Correct: {max_diff < 1e-5}")

    # -------------------------------------------------------------------------
    # Example 3: Element-wise Operations
    # -------------------------------------------------------------------------
    print("\n" + "-" * 40)
    print("Example 3: Element-wise Operations")
    print("-" * 40)

    x = torch.randn(1000, device=device, dtype=torch.float32)

    # Square
    result_triton = square(x)
    result_torch = x * x
    print(f"Square - Correct: {torch.allclose(result_triton, result_torch)}")

    # ReLU
    result_triton = relu(x)
    result_torch = torch.nn.functional.relu(x)
    print(f"ReLU - Correct: {torch.allclose(result_triton, result_torch)}")
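
    # Optional: run the micro-benchmark sketch defined above (CUDA-only; main
    # returns early when no GPU is available, so this call is safe here).
    benchmark_square()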

    # -------------------------------------------------------------------------
    # Summary
    # -------------------------------------------------------------------------
    print("\n" + "=" * 60)
    print("Quick Start Complete!")
    print("=" * 60)
    print("""
Key Takeaways:
1. @triton.jit marks a function as a Triton kernel
2. tl.program_id() gets the block index
3. tl.load()/tl.store() handle memory operations
4. mask= handles boundary conditions
5. BLOCK_SIZE as tl.constexpr enables compiler optimization
6. Runtime sizes (like N) are passed as ordinary arguments

Next Steps:
- See gemm.py for matrix multiplication
- See flash_attention.py for advanced kernels
- See pytorch_integration.py for neural network integration
""")


if __name__ == "__main__":
    main()