  1  """
  2  Simple Tile Operations Examples.
  3  
  4  This module demonstrates basic tile operations in the cuTile
  5  programming model.
  6  
  7  Requirements: 5.1, 5.2
  8  """

import os
import sys

import numpy as np

# Allow `from tile_gemm import ...` below when this example is run directly
# from the examples/ directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from tile_gemm import TileArray, TileShape, tile_kernel

# ============================================================================
# Element-wise Operations
# ============================================================================


@tile_kernel
def tile_add(a: TileArray, b: TileArray) -> TileArray:
    """
    Element-wise addition of two tile arrays.

    In cuTile, this would compile to a single kernel that:
    - Automatically partitions the arrays into tiles
    - Maps tiles to thread blocks
    - Handles boundary conditions
    """
    return a + b


@tile_kernel
def tile_mul(a: TileArray, b: TileArray) -> TileArray:
    """Element-wise multiplication."""
    return a * b


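# For illustration only: a plain-numpy sketch of the manual tiling loop that the
# @tile_kernel abstraction is meant to hide. It partitions both inputs into
# tile_shape-sized blocks and clamps the last row/column of tiles at the array
# boundary. The helper name and the (rows, cols) tile convention are
# assumptions for this sketch, not part of any cuTile API.
def _manual_tiled_add(a: np.ndarray, b: np.ndarray, tile_shape=(32, 32)) -> np.ndarray:
    """Add two 2-D arrays tile by tile, handling ragged boundary tiles."""
    out = np.empty_like(a)
    tile_rows, tile_cols = tile_shape
    for i in range(0, a.shape[0], tile_rows):
        for j in range(0, a.shape[1], tile_cols):
            # Clamp the tile extent at the array boundary.
            i_end = min(i + tile_rows, a.shape[0])
            j_end = min(j + tile_cols, a.shape[1])
            out[i:i_end, j:j_end] = a[i:i_end, j:j_end] + b[i:i_end, j:j_end]
    return out

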
@tile_kernel
def tile_relu(x: TileArray) -> TileArray:
    """
    ReLU activation: max(0, x)

    cuTile would optimize this to use efficient comparison
    and selection operations.
    """
    result = np.maximum(x.data, 0)
    return TileArray(result, x.tile_shape)


@tile_kernel
def tile_softmax(x: TileArray, axis: int = -1) -> TileArray:
    """
    Softmax along the specified axis.

    cuTile would handle the reduction operations efficiently,
    using shared memory for intermediate results.
    """
    # Numerically stable softmax: subtract the max before exponentiating.
    x_max = np.max(x.data, axis=axis, keepdims=True)
    exp_x = np.exp(x.data - x_max)
    sum_exp = np.sum(exp_x, axis=axis, keepdims=True)
    result = exp_x / sum_exp
    return TileArray(result, x.tile_shape)


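# For illustration only: a streaming ("online") softmax over chunks of the
# reduction axis, the kind of tiled reduction a cuTile-style compiler could emit
# so that only a running max and a running sum need to live in fast (shared)
# memory. Plain numpy, assuming a 2-D input reduced along the last axis; the
# chunk size and helper name are illustrative, not part of any cuTile API.
def _online_softmax_last_axis(x: np.ndarray, chunk: int = 128) -> np.ndarray:
    """Numerically stable softmax over axis -1, computed chunk by chunk."""
    running_max = np.full(x.shape[0], -np.inf, dtype=x.dtype)
    running_sum = np.zeros(x.shape[0], dtype=x.dtype)
    # Pass 1: fold each chunk into the running max and the rescaled running sum.
    for j in range(0, x.shape[1], chunk):
        block = x[:, j:j + chunk]
        new_max = np.maximum(running_max, block.max(axis=1))
        running_sum = running_sum * np.exp(running_max - new_max) + np.exp(
            block - new_max[:, None]
        ).sum(axis=1)
        running_max = new_max
    # Pass 2: normalize with the final max and sum.
    return np.exp(x - running_max[:, None]) / running_sum[:, None]

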
# ============================================================================
# Reduction Operations
# ============================================================================


@tile_kernel
def tile_sum(x: TileArray, axis: int | None = None) -> TileArray:
    """
    Sum reduction.

    cuTile would use efficient parallel reduction algorithms,
    automatically choosing between warp-level and block-level
    reductions based on the data size.
    """
    result = np.sum(x.data, axis=axis)
    if axis is None:
        return TileArray(np.array([result]), TileShape(1, 1))
    return TileArray(result, x.tile_shape)


@tile_kernel
def tile_mean(x: TileArray, axis: int | None = None) -> TileArray:
    """Mean reduction."""
    result = np.mean(x.data, axis=axis)
    if axis is None:
        return TileArray(np.array([result]), TileShape(1, 1))
    return TileArray(result, x.tile_shape)


@tile_kernel
def tile_max(x: TileArray, axis: int | None = None) -> TileArray:
    """Max reduction."""
    result = np.max(x.data, axis=axis)
    if axis is None:
        return TileArray(np.array([result]), TileShape(1, 1))
    return TileArray(result, x.tile_shape)


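# For illustration only: a pairwise (tree) reduction over fixed-size chunks,
# mirroring the warp- and block-level reductions mentioned in the tile_sum
# docstring. Each level halves the number of partial sums, which is what the
# hardware does in parallel. Plain numpy/Python; the helper name and chunk size
# are illustrative, not a cuTile API.
def _tree_sum(x: np.ndarray, chunk: int = 256):
    """Sum a 1-D array by reducing fixed-size chunks, then combining pairwise."""
    # Per-chunk partial sums (what each thread block would produce).
    partials = [x[i:i + chunk].sum() for i in range(0, x.size, chunk)]
    # Pairwise combination tree (what a follow-up reduction kernel would do).
    while len(partials) > 1:
        partials = [
            partials[i] + partials[i + 1] if i + 1 < len(partials) else partials[i]
            for i in range(0, len(partials), 2)
        ]
    return partials[0]

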
# ============================================================================
# Fused Operations
# ============================================================================


@tile_kernel
def fused_bias_relu(x: TileArray, bias: TileArray) -> TileArray:
    """
    Fused bias addition and ReLU.

    In cuTile, the compiler would automatically fuse these operations
    into a single kernel, avoiding intermediate memory writes.
    """
    result = np.maximum(x.data + bias.data, 0)
    return TileArray(result, x.tile_shape)


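# For illustration only: the unfused version of the same computation, written as
# two separate steps with an explicit intermediate array. This is the extra
# round trip through memory that kernel fusion avoids; the helper name is
# illustrative and not part of any cuTile API.
def _unfused_bias_relu(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Bias-add and ReLU as two passes with a materialized intermediate."""
    biased = x + bias             # Step 1: write the intermediate result.
    return np.maximum(biased, 0)  # Step 2: read it back and apply ReLU.

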
@tile_kernel
def fused_layer_norm(
    x: TileArray, gamma: TileArray, beta: TileArray, eps: float = 1e-5
) -> TileArray:
    """
    Fused Layer Normalization.

    cuTile would optimize this to:
    1. Compute mean and variance in a single pass
    2. Fuse normalization with scale and shift
    """
    mean = np.mean(x.data, axis=-1, keepdims=True)
    var = np.var(x.data, axis=-1, keepdims=True)
    x_norm = (x.data - mean) / np.sqrt(var + eps)
    result = gamma.data * x_norm + beta.data
    return TileArray(result, x.tile_shape)


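# For illustration only: the single-pass moment computation the docstring refers
# to. One sweep accumulates sum(x) and sum(x**2) per row, from which mean and
# variance follow as E[x^2] - E[x]^2. (A production kernel would typically use a
# Welford-style update for better numerical behavior.) Plain numpy over the last
# axis; the helper name is illustrative, not a cuTile API.
def _single_pass_mean_var(x: np.ndarray):
    """Return per-row (mean, variance) computed from one pass of accumulators."""
    n = x.shape[-1]
    s = x.sum(axis=-1, keepdims=True)          # running sum
    sq = (x * x).sum(axis=-1, keepdims=True)   # running sum of squares
    mean = s / n
    var = sq / n - mean * mean
    return mean, var

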
# ============================================================================
# Example Usage
# ============================================================================


def demonstrate_tile_ops() -> None:
    """Demonstrate various tile operations."""
    print("=" * 60)
    print("cuTile Simple Operations Demonstration")
    print("=" * 60)

    # Create test data
    np.random.seed(42)
    shape = (256, 512)

    x = TileArray(np.random.randn(*shape).astype(np.float32))
    y = TileArray(np.random.randn(*shape).astype(np.float32))

    print(f"\nInput shape: {shape}")

    # Element-wise operations
    print("\n" + "-" * 40)
    print("Element-wise Operations")

    z_add = tile_add(x, y)
    print(f"  tile_add: {z_add.shape}")

    z_mul = tile_mul(x, y)
    print(f"  tile_mul: {z_mul.shape}")

    z_relu = tile_relu(x)
    print(f"  tile_relu: {z_relu.shape}")
    print(f"    Min value (should be >= 0): {z_relu.data.min():.4f}")

    # Reduction operations
    print("\n" + "-" * 40)
    print("Reduction Operations")

    z_sum = tile_sum(x)
    print(f"  tile_sum (all): {z_sum.data[0]:.4f}")

    z_sum_axis = tile_sum(x, axis=1)
    print(f"  tile_sum (axis=1): shape={z_sum_axis.shape}")

    z_mean = tile_mean(x)
    print(f"  tile_mean: {z_mean.data[0]:.4f}")

    z_max = tile_max(x)
    print(f"  tile_max: {z_max.data[0]:.4f}")

    # Softmax
    print("\n" + "-" * 40)
    print("Softmax")

    z_softmax = tile_softmax(x, axis=-1)
    print(f"  tile_softmax: shape={z_softmax.shape}")
    print(f"    Sum along axis=-1 (should be 1.0): {z_softmax.data[0].sum():.6f}")

    # Fused operations
    print("\n" + "-" * 40)
    print("Fused Operations")

    bias = TileArray(np.random.randn(shape[1]).astype(np.float32))
    z_fused = fused_bias_relu(x, bias)
    print(f"  fused_bias_relu: shape={z_fused.shape}")

    gamma = TileArray(np.ones(shape[1], dtype=np.float32))
    beta = TileArray(np.zeros(shape[1], dtype=np.float32))
    z_ln = fused_layer_norm(x, gamma, beta)
    print(f"  fused_layer_norm: shape={z_ln.shape}")
    print(f"    Mean (should be ~0): {z_ln.data.mean():.6f}")
    print(f"    Std (should be ~1): {z_ln.data.std():.6f}")

    print("\n" + "=" * 60)
    print("Demonstration complete")
    print("=" * 60)


if __name__ == "__main__":
    demonstrate_tile_ops()