# 01_TVM_End2End_Optimization/3_tensorir_manual_schedule.py
1 """ 2 TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule 3 4 This script demonstrates: 5 1. Writing TensorIR programs using TVMScript 6 2. Applying schedule primitives: split, reorder, bind, vectorize 7 3. Comparing manual schedule vs baseline performance 8 9 Requirements: 2.3, 2.4 10 """ 11 12 import argparse 13 import os 14 import sys 15 16 # Add project root to path before importing common modules 17 sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 18 19 import numpy as np 20 21 try: 22 import tvm 23 from tvm import tir 24 from tvm.script import tir as tir_script 25 26 TVM_AVAILABLE = True 27 except ImportError: 28 TVM_AVAILABLE = False 29 print("Warning: TVM not installed. This script requires TVM.") 30 31 from common.benchmark.timer import TimingResult, benchmark_function 32 33 if TVM_AVAILABLE: 34 35 def create_matmul_module(m: int, n: int, k: int) -> "tvm.IRModule": 36 @tvm.script.ir_module 37 class MatMul: 38 @tir_script.prim_func 39 def main( 40 A: tir_script.Buffer((m, k), "float32"), 41 B: tir_script.Buffer((k, n), "float32"), 42 C: tir_script.Buffer((m, n), "float32"), 43 ): 44 tir_script.func_attr({"global_symbol": "main", "tir.noalias": True}) 45 for i, j, k_idx in tir_script.grid(m, n, k): 46 with tir_script.block("matmul"): 47 vi, vj, vk = tir_script.axis.remap("SSR", [i, j, k_idx]) 48 with tir_script.init(): 49 C[vi, vj] = tir_script.float32(0) 50 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] 51 52 return MatMul 53 54 55 def _require_positive_multiple(name: str, value: int, multiple: int) -> None: 56 if value <= 0 or value % multiple != 0: 57 raise ValueError(f"{name} must be a positive multiple of {multiple}, got {value}") 58 59 60 def apply_schedule_primitives( 61 mod: "tvm.IRModule", 62 block_size: int = 32, 63 k_tile: int = 8, 64 ) -> "tvm.tir.Schedule": 65 """Apply schedule primitives to optimize MatMul.""" 66 _require_positive_multiple("block_size", block_size, 1) 67 _require_positive_multiple("k_tile", k_tile, 1) 68 69 sch = tir.Schedule(mod) 70 block = sch.get_block("matmul") 71 i, j, k = sch.get_loops(block) 72 73 i_outer, i_inner = sch.split(i, factors=[None, block_size]) 74 j_outer, j_inner = sch.split(j, factors=[None, block_size]) 75 k_outer, k_inner = sch.split(k, factors=[None, k_tile]) 76 77 sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner) 78 sch.vectorize(k_inner) 79 80 print("Applied schedule primitives:") 81 print(f" - split(i, [None, {block_size}])") 82 print(f" - split(j, [None, {block_size}])") 83 print(f" - split(k, [None, {k_tile}])") 84 print(" - reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)") 85 print(" - vectorize(k_inner)") 86 87 return sch 88 89 90 def apply_gpu_schedule( 91 mod: "tvm.IRModule", block_m: int = 128, block_n: int = 128, thread_tile: int = 32 92 ) -> "tvm.tir.Schedule": 93 """Apply GPU-optimized schedule for MatMul.""" 94 _require_positive_multiple("block_m", block_m, thread_tile) 95 _require_positive_multiple("block_n", block_n, thread_tile) 96 97 sch = tir.Schedule(mod) 98 block = sch.get_block("matmul") 99 i, j, k = sch.get_loops(block) 100 101 i_outer, i_inner = sch.split(i, factors=[None, block_m]) 102 j_outer, j_inner = sch.split(j, factors=[None, block_n]) 103 i_inner_outer, i_inner_inner = sch.split(i_inner, factors=[None, thread_tile]) 104 j_inner_outer, j_inner_inner = sch.split(j_inner, factors=[None, thread_tile]) 105 106 sch.reorder(i_outer, j_outer, i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner, k) 107 sch.bind(i_outer, "blockIdx.x") 108 


def apply_gpu_schedule(
    mod: "tvm.IRModule", block_m: int = 128, block_n: int = 128, thread_tile: int = 32
) -> "tvm.tir.Schedule":
    """Apply a GPU-optimized schedule for MatMul."""
    _require_positive_multiple("block_m", block_m, thread_tile)
    _require_positive_multiple("block_n", block_n, thread_tile)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    i_outer, i_inner = sch.split(i, factors=[None, block_m])
    j_outer, j_inner = sch.split(j, factors=[None, block_n])
    i_inner_outer, i_inner_inner = sch.split(i_inner, factors=[None, thread_tile])
    j_inner_outer, j_inner_inner = sch.split(j_inner, factors=[None, thread_tile])

    # Each thread block computes a (block_m, block_n) tile; each thread within
    # it computes a (thread_tile, thread_tile) sub-tile sequentially.
    sch.reorder(i_outer, j_outer, i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner, k)
    sch.bind(i_outer, "blockIdx.x")
    sch.bind(j_outer, "blockIdx.y")
    sch.bind(i_inner_outer, "threadIdx.y")
    sch.bind(j_inner_outer, "threadIdx.x")

    print("Applied GPU schedule:")
    print(f"  - Block size: ({block_m}, {block_n})")
    print(f"  - Thread tile: ({thread_tile}, {thread_tile})")
    print("  - Bound to blockIdx.x/y, threadIdx.x/y")

    return sch


def benchmark_schedule(
    sch: "tvm.tir.Schedule",
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> TimingResult:
    """Benchmark a TensorIR schedule."""
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(target_obj.kind.name, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)
    func = lib["main"]

    def run():
        func(a_tvm, b_tvm, c_tvm)

    # Note: str(Target) includes canonical attributes (e.g. "cuda -keys=..."),
    # so compare the target kind rather than the full string.
    return benchmark_function(
        run,
        warmup_iters=warmup_iters,
        bench_iters=bench_iters,
        sync_cuda=(target_obj.kind.name == "cuda"),
    )


def verify_correctness(
    sch: "tvm.tir.Schedule",
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    rtol: float = 1e-4,
) -> bool:
    """Verify that the scheduled computation matches the numpy reference.

    float32 accumulation over k elements loses precision, so the default
    tolerance is looser than for element-wise ops.
    """
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(target_obj.kind.name, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)

    func = lib["main"]
    func(a_tvm, b_tvm, c_tvm)

    c_ref = np.matmul(a_np, b_np)
    np.testing.assert_allclose(c_tvm.numpy(), c_ref, rtol=rtol)
    return True
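

# Alternative timing path (illustrative sketch): TVM's built-in time_evaluator
# measures the compiled kernel on-device and handles synchronization itself,
# which is often more reliable than host-side timing on GPUs. This helper is
# not called by main() below.
def benchmark_with_time_evaluator(
    sch: "tvm.tir.Schedule",
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
) -> float:
    """Return mean latency in milliseconds measured by lib.time_evaluator."""
    target_obj = tvm.target.Target(target)
    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(sch.mod, target=target_obj)
    dev = tvm.device(target_obj.kind.name, 0)
    a = tvm.nd.array(np.random.randn(m, k).astype("float32"), dev)
    b = tvm.nd.array(np.random.randn(k, n).astype("float32"), dev)
    c = tvm.nd.array(np.zeros((m, n), dtype="float32"), dev)
    evaluator = lib.time_evaluator("main", dev, number=10, repeat=3)
    return evaluator(a, b, c).mean * 1000.0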


def _validate_problem_size(size: int, target: str) -> None:
    _require_positive_multiple("size", size, 8)
    if target == "cuda":
        _require_positive_multiple("size", size, 32)


def main() -> None:
    parser = argparse.ArgumentParser(description="TensorIR Manual Schedule")
    parser.add_argument("--target", type=str, default="llvm", help="Target (llvm or cuda)")
    parser.add_argument("--size", type=int, default=1024, help="Matrix size (M=N=K)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=20, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return

    try:
        _validate_problem_size(args.size, args.target)
    except ValueError as exc:
        print(f"Invalid configuration: {exc}")
        return

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule")
    print("=" * 60)

    m = n = k = args.size
    print(f"\nMatrix size: {m} x {n} x {k}")
    print(f"Target: {args.target}")

    baseline_mod = create_matmul_module(m, n, k)
    baseline_sch = tir.Schedule(baseline_mod)

    print("\n" + "-" * 40)
    if args.target == "cuda":
        # An unscheduled module has no thread bindings and cannot be built for
        # CUDA, so the GPU baseline uses a naive one-element-per-thread binding.
        print("Baseline (naive GPU binding):")
        i, j, _ = baseline_sch.get_loops(baseline_sch.get_block("matmul"))
        j_outer, j_inner = baseline_sch.split(j, factors=[None, 256])
        baseline_sch.bind(i, "blockIdx.y")
        baseline_sch.bind(j_outer, "blockIdx.x")
        baseline_sch.bind(j_inner, "threadIdx.x")
    else:
        print("Baseline (no optimization):")
    baseline_result = benchmark_schedule(
        baseline_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {baseline_result.mean_ms:.3f} ms")

    print("\n" + "-" * 40)
    print("Applying schedule primitives...")
    # Pick the schedule that matches the target: the vectorized CPU schedule
    # cannot be built for CUDA, and the GPU schedule requires thread bindings.
    if args.target == "cuda":
        optimized_sch = apply_gpu_schedule(create_matmul_module(m, n, k))
    else:
        optimized_sch = apply_schedule_primitives(create_matmul_module(m, n, k))

    print("\n" + "-" * 40)
    print("Verifying correctness...")
    try:
        verify_correctness(optimized_sch, args.target, m, n, k)
        print("  ✓ Results match numpy reference")
    except AssertionError as exc:
        print(f"  ✗ Verification failed: {exc}")
        return

    print("\n" + "-" * 40)
    print("Optimized schedule:")
    optimized_result = benchmark_schedule(
        optimized_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {optimized_result.mean_ms:.3f} ms")

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  Baseline:  {baseline_result.mean_ms:.3f} ms")
    print(f"  Optimized: {optimized_result.mean_ms:.3f} ms")
    print(f"  Speedup:   {baseline_result.mean_ms / optimized_result.mean_ms:.2f}x")

    print("\n" + "-" * 40)
    print("Optimized TensorIR:")
    print(optimized_sch.mod.script())


if __name__ == "__main__":
    main()
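

# Example invocations (flags as defined in main() above):
#   python 3_tensorir_manual_schedule.py --target llvm --size 1024
#   python 3_tensorir_manual_schedule.py --target cuda --size 2048 --bench 50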