simple_tile_ops.py
1 """ 2 Simple Tile Operations Examples. 3 4 This module demonstrates basic tile operations in the cuTile 5 programming model. 6 7 Requirements: 5.1, 5.2 8 """ 9 10 import os 11 import sys 12 13 import numpy as np 14 15 sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 16 17 from tile_gemm import TileArray, TileShape, tile_kernel 18 19 # ============================================================================ 20 # Element-wise Operations 21 # ============================================================================ 22 23 24 @tile_kernel 25 def tile_add(a: TileArray, b: TileArray) -> TileArray: 26 """ 27 Element-wise addition of two tile arrays. 28 29 In cuTile, this would compile to a single kernel that: 30 - Automatically partitions the arrays into tiles 31 - Maps tiles to thread blocks 32 - Handles boundary conditions 33 """ 34 return a + b 35 36 37 @tile_kernel 38 def tile_mul(a: TileArray, b: TileArray) -> TileArray: 39 """Element-wise multiplication.""" 40 return a * b 41 42 43 @tile_kernel 44 def tile_relu(x: TileArray) -> TileArray: 45 """ 46 ReLU activation: max(0, x) 47 48 cuTile would optimize this to use efficient comparison 49 and selection operations. 50 """ 51 result = np.maximum(x.data, 0) 52 return TileArray(result, x.tile_shape) 53 54 55 @tile_kernel 56 def tile_softmax(x: TileArray, axis: int = -1) -> TileArray: 57 """ 58 Softmax along specified axis. 59 60 cuTile would handle the reduction operations efficiently, 61 using shared memory for intermediate results. 62 """ 63 # Numerically stable softmax 64 x_max = np.max(x.data, axis=axis, keepdims=True) 65 exp_x = np.exp(x.data - x_max) 66 sum_exp = np.sum(exp_x, axis=axis, keepdims=True) 67 result = exp_x / sum_exp 68 return TileArray(result, x.tile_shape) 69 70 71 # ============================================================================ 72 # Reduction Operations 73 # ============================================================================ 74 75 76 @tile_kernel 77 def tile_sum(x: TileArray, axis: int = None) -> TileArray: 78 """ 79 Sum reduction. 80 81 cuTile would use efficient parallel reduction algorithms, 82 automatically choosing between warp-level and block-level 83 reductions based on the data size. 84 """ 85 result = np.sum(x.data, axis=axis) 86 if axis is None: 87 return TileArray(np.array([result]), TileShape(1, 1)) 88 return TileArray(result, x.tile_shape) 89 90 91 @tile_kernel 92 def tile_mean(x: TileArray, axis: int = None) -> TileArray: 93 """Mean reduction.""" 94 result = np.mean(x.data, axis=axis) 95 if axis is None: 96 return TileArray(np.array([result]), TileShape(1, 1)) 97 return TileArray(result, x.tile_shape) 98 99 100 @tile_kernel 101 def tile_max(x: TileArray, axis: int = None) -> TileArray: 102 """Max reduction.""" 103 result = np.max(x.data, axis=axis) 104 if axis is None: 105 return TileArray(np.array([result]), TileShape(1, 1)) 106 return TileArray(result, x.tile_shape) 107 108 109 # ============================================================================ 110 # Fused Operations 111 # ============================================================================ 112 113 114 @tile_kernel 115 def fused_bias_relu(x: TileArray, bias: TileArray) -> TileArray: 116 """ 117 Fused bias addition and ReLU. 118 119 In cuTile, the compiler would automatically fuse these operations 120 into a single kernel, avoiding intermediate memory writes. 
121 """ 122 result = np.maximum(x.data + bias.data, 0) 123 return TileArray(result, x.tile_shape) 124 125 126 @tile_kernel 127 def fused_layer_norm( 128 x: TileArray, gamma: TileArray, beta: TileArray, eps: float = 1e-5 129 ) -> TileArray: 130 """ 131 Fused Layer Normalization. 132 133 cuTile would optimize this to: 134 1. Compute mean and variance in a single pass 135 2. Fuse normalization with scale and shift 136 """ 137 mean = np.mean(x.data, axis=-1, keepdims=True) 138 var = np.var(x.data, axis=-1, keepdims=True) 139 x_norm = (x.data - mean) / np.sqrt(var + eps) 140 result = gamma.data * x_norm + beta.data 141 return TileArray(result, x.tile_shape) 142 143 144 # ============================================================================ 145 # Example Usage 146 # ============================================================================ 147 148 149 def demonstrate_tile_ops() -> None: 150 """Demonstrate various tile operations.""" 151 print("=" * 60) 152 print("cuTile Simple Operations Demonstration") 153 print("=" * 60) 154 155 # Create test data 156 np.random.seed(42) 157 shape = (256, 512) 158 159 x = TileArray(np.random.randn(*shape).astype(np.float32)) 160 y = TileArray(np.random.randn(*shape).astype(np.float32)) 161 162 print(f"\nInput shape: {shape}") 163 164 # Element-wise operations 165 print("\n" + "-" * 40) 166 print("Element-wise Operations") 167 168 z_add = tile_add(x, y) 169 print(f" tile_add: {z_add.shape}") 170 171 z_mul = tile_mul(x, y) 172 print(f" tile_mul: {z_mul.shape}") 173 174 z_relu = tile_relu(x) 175 print(f" tile_relu: {z_relu.shape}") 176 print(f" Min value (should be >= 0): {z_relu.data.min():.4f}") 177 178 # Reduction operations 179 print("\n" + "-" * 40) 180 print("Reduction Operations") 181 182 z_sum = tile_sum(x) 183 print(f" tile_sum (all): {z_sum.data[0]:.4f}") 184 185 z_sum_axis = tile_sum(x, axis=1) 186 print(f" tile_sum (axis=1): shape={z_sum_axis.shape}") 187 188 z_mean = tile_mean(x) 189 print(f" tile_mean: {z_mean.data[0]:.4f}") 190 191 z_max = tile_max(x) 192 print(f" tile_max: {z_max.data[0]:.4f}") 193 194 # Softmax 195 print("\n" + "-" * 40) 196 print("Softmax") 197 198 z_softmax = tile_softmax(x, axis=-1) 199 print(f" tile_softmax: shape={z_softmax.shape}") 200 print(f" Sum along axis=-1 (should be 1.0): {z_softmax.data[0].sum():.6f}") 201 202 # Fused operations 203 print("\n" + "-" * 40) 204 print("Fused Operations") 205 206 bias = TileArray(np.random.randn(shape[1]).astype(np.float32)) 207 z_fused = fused_bias_relu(x, bias) 208 print(f" fused_bias_relu: shape={z_fused.shape}") 209 210 gamma = TileArray(np.ones(shape[1]).astype(np.float32)) 211 beta = TileArray(np.zeros(shape[1]).astype(np.float32)) 212 z_ln = fused_layer_norm(x, gamma, beta) 213 print(f" fused_layer_norm: shape={z_ln.shape}") 214 print(f" Mean (should be ~0): {z_ln.data.mean():.6f}") 215 print(f" Std (should be ~1): {z_ln.data.std():.6f}") 216 217 print("\n" + "=" * 60) 218 print("Demonstration complete") 219 print("=" * 60) 220 221 222 if __name__ == "__main__": 223 demonstrate_tile_ops()