"""TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule.

This script demonstrates:
1. Writing TensorIR programs using TVMScript
2. Applying schedule primitives: split, reorder, bind, vectorize
3. Comparing manual schedule vs baseline performance

Requirements: 2.3, 2.4
"""

from __future__ import annotations

import argparse
import sys
from typing import TYPE_CHECKING

import numpy as np

from common.benchmark.timer import TimingResult, benchmark_function
from common.utils.path_utils import ensure_project_root_in_path

# Ensure project root is in path for imports
ensure_project_root_in_path()

if TYPE_CHECKING:
    import tvm

# TVM imports — the script degrades gracefully (prints a message in main())
# when TVM is not installed.
try:
    import tvm
    from tvm import tir
    from tvm.script import tir as T

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False

if TVM_AVAILABLE:

    def create_matmul_module(m: int, n: int, k: int) -> tvm.IRModule:
        """Build a TVMScript IRModule computing C = A @ B.

        Args:
            m: Rows of A / C.
            n: Columns of B / C.
            k: Reduction dimension (columns of A, rows of B).

        Returns:
            An IRModule with a single "main" PrimFunc implementing MatMul.
        """

        @tvm.script.ir_module
        class MatMul:
            @T.prim_func
            def main(
                A: T.Buffer((m, k), "float32"),
                B: T.Buffer((k, n), "float32"),
                C: T.Buffer((m, n), "float32"),
            ):
                T.func_attr({"global_symbol": "main", "tir.noalias": True})
                for i, j, k_idx in T.grid(m, n, k):
                    with T.block("matmul"):
                        # "SSR": i and j are spatial axes, k_idx is the reduction axis.
                        vi, vj, vk = T.axis.remap("SSR", [i, j, k_idx])
                        with T.init():
                            C[vi, vj] = T.float32(0)
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        return MatMul


def _require_positive_multiple(name: str, value: int, multiple: int) -> None:
    """Raise ValueError unless ``value`` is a positive multiple of ``multiple``."""
    if value <= 0 or value % multiple != 0:
        raise ValueError(f"{name} must be a positive multiple of {multiple}, got {value}")


def apply_schedule_primitives(
    mod: tvm.IRModule,
    block_size: int = 32,
    k_tile: int = 8,
) -> tvm.tir.Schedule:
    """Apply schedule primitives to optimize MatMul.

    Tiles i/j by ``block_size`` and k by ``k_tile``, reorders outer loops
    before inner ones for cache locality, and vectorizes the innermost k loop.

    Args:
        mod: IRModule containing a "matmul" block with loops (i, j, k).
        block_size: Tile size for the i and j axes.
        k_tile: Tile (and vectorization) size for the k axis.

    Returns:
        The transformed schedule (the input module is not mutated).

    Raises:
        ValueError: If a tiling parameter is not a positive integer.
    """
    _require_positive_multiple("block_size", block_size, 1)
    _require_positive_multiple("k_tile", k_tile, 1)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    # factors=[None, x]: outer extent is inferred, inner extent fixed at x.
    i_outer, i_inner = sch.split(i, factors=[None, block_size])
    j_outer, j_inner = sch.split(j, factors=[None, block_size])
    k_outer, k_inner = sch.split(k, factors=[None, k_tile])

    sch.reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)
    sch.vectorize(k_inner)

    print("Applied schedule primitives:")
    print(f"  - split(i, [None, {block_size}])")
    print(f"  - split(j, [None, {block_size}])")
    print(f"  - split(k, [None, {k_tile}])")
    print("  - reorder(i_outer, j_outer, k_outer, i_inner, j_inner, k_inner)")
    print("  - vectorize(k_inner)")

    return sch


def apply_gpu_schedule(
    mod: tvm.IRModule, block_m: int = 128, block_n: int = 128, thread_tile: int = 32
) -> tvm.tir.Schedule:
    """Apply GPU-optimized schedule for MatMul.

    Tiles the output into (block_m, block_n) thread blocks, subdivides each
    into (thread_tile, thread_tile) per-thread tiles, and binds the resulting
    loops to CUDA block/thread indices.

    Args:
        mod: IRModule containing a "matmul" block with loops (i, j, k).
        block_m: Thread-block tile along i; must be a multiple of thread_tile.
        block_n: Thread-block tile along j; must be a multiple of thread_tile.
        thread_tile: Per-thread tile edge length.

    Returns:
        The transformed schedule.

    Raises:
        ValueError: If block_m or block_n is not a positive multiple of thread_tile.
    """
    _require_positive_multiple("block_m", block_m, thread_tile)
    _require_positive_multiple("block_n", block_n, thread_tile)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    i_outer, i_inner = sch.split(i, factors=[None, block_m])
    j_outer, j_inner = sch.split(j, factors=[None, block_n])
    i_inner_outer, i_inner_inner = sch.split(i_inner, factors=[None, thread_tile])
    j_inner_outer, j_inner_inner = sch.split(j_inner, factors=[None, thread_tile])

    sch.reorder(i_outer, j_outer, i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner, k)
    sch.bind(i_outer, "blockIdx.x")
    sch.bind(j_outer, "blockIdx.y")
    sch.bind(i_inner_outer, "threadIdx.y")
    sch.bind(j_inner_outer, "threadIdx.x")

    print("Applied GPU schedule:")
    print(f"  - Block size: ({block_m}, {block_n})")
    print(f"  - Thread tile: ({thread_tile}, {thread_tile})")
    print("  - Bound to blockIdx.x/y, threadIdx.x/y")

    return sch


def benchmark_schedule(
    sch: tvm.tir.Schedule,
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> TimingResult:
    """Benchmark a TensorIR schedule.

    Builds the scheduled module for ``target``, allocates random inputs on the
    corresponding device, and times repeated executions of "main".

    Args:
        sch: Schedule whose module will be built and timed.
        target: TVM target string (e.g. "llvm" or "cuda").
        m, n, k: MatMul problem sizes (must match the scheduled module).
        warmup_iters: Untimed warmup iterations.
        bench_iters: Timed iterations.

    Returns:
        Timing statistics from benchmark_function.
    """
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(str(target_obj), 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)
    func = lib["main"]

    def run():
        func(a_tvm, b_tvm, c_tvm)

    return benchmark_function(
        run,
        warmup_iters=warmup_iters,
        bench_iters=bench_iters,
        # Compare the target kind, not str(target): the full target string may
        # carry attributes (e.g. "cuda -arch=sm_80") and would never equal "cuda".
        sync_cuda=(target_obj.kind.name == "cuda"),
    )


def verify_correctness(
    sch: tvm.tir.Schedule,
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    rtol: float = 1e-5,
) -> bool:
    """Verify that the scheduled computation is correct.

    Builds the module, runs it on random inputs, and compares against
    ``numpy.matmul``.

    Returns:
        True when the result matches the numpy reference.

    Raises:
        AssertionError: If results differ beyond ``rtol``.
    """
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(str(target_obj), 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)

    func = lib["main"]
    func(a_tvm, b_tvm, c_tvm)

    c_ref = np.matmul(a_np, b_np)
    np.testing.assert_allclose(c_tvm.numpy(), c_ref, rtol=rtol)
    return True


def _validate_problem_size(size: int, target: str) -> None:
    """Validate the matrix size for the chosen target.

    CPU schedules need a multiple of 8 (k_tile default); CUDA additionally
    needs a multiple of 32 (thread_tile default).
    """
    _require_positive_multiple("size", size, 8)
    if target == "cuda":
        _require_positive_multiple("size", size, 32)


def main() -> int:
    """CLI entry point.

    Returns:
        Process exit code: 0 on success, nonzero on failure (so that
        ``sys.exit(main())`` reports errors to the shell).
    """
    parser = argparse.ArgumentParser(description="TensorIR Manual Schedule")
    parser.add_argument("--target", type=str, default="llvm", help="Target (llvm or cuda)")
    parser.add_argument("--size", type=int, default=1024, help="Matrix size (M=N=K)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=20, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return 1

    try:
        _validate_problem_size(args.size, args.target)
    except ValueError as exc:
        print(f"Invalid configuration: {exc}")
        return 1

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule")
    print("=" * 60)

    m = n = k = args.size
    print(f"\nMatrix size: {m} x {n} x {k}")
    print(f"Target: {args.target}")

    baseline_mod = create_matmul_module(m, n, k)

    print("\n" + "-" * 40)
    print("Baseline (no optimization):")
    baseline_sch = tir.Schedule(baseline_mod)
    baseline_result = benchmark_schedule(
        baseline_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {baseline_result.mean_ms:.3f} ms")

    print("\n" + "-" * 40)
    print("Applying schedule primitives...")
    optimized_sch = apply_schedule_primitives(create_matmul_module(m, n, k))

    print("\n" + "-" * 40)
    print("Verifying correctness...")
    try:
        verify_correctness(optimized_sch, args.target, m, n, k)
        print("  ✓ Results match numpy reference")
    except AssertionError as exc:
        print(f"  ✗ Verification failed: {exc}")
        return 1

    print("\n" + "-" * 40)
    print("Optimized schedule:")
    optimized_result = benchmark_schedule(
        optimized_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {optimized_result.mean_ms:.3f} ms")

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  Baseline:  {baseline_result.mean_ms:.3f} ms")
    print(f"  Optimized: {optimized_result.mean_ms:.3f} ms")
    print(f"  Speedup:   {baseline_result.mean_ms / optimized_result.mean_ms:.2f}x")

    print("\n" + "-" * 40)
    print("Optimized TensorIR:")
    print(optimized_sch.mod.script())

    return 0


if __name__ == "__main__":
    sys.exit(main())