  1  """
  2  TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule
  3  
  4  This script demonstrates:
  5  1. Writing TensorIR programs using TVMScript
  6  2. Applying schedule primitives: split, reorder, bind, vectorize
  7  3. Comparing manual schedule vs baseline performance
  8  
  9  Requirements: 2.3, 2.4
 10  """
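
# Example invocations:
#   python 3_tensorir_manual_schedule.py --target llvm --size 1024
#   python 3_tensorir_manual_schedule.py --target cuda --size 1024 --bench 50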

import argparse
import os
import sys

# Add project root to path before importing common modules
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

import numpy as np

try:
    import tvm
    from tvm import tir
    from tvm.script import tir as tir_script

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False
    print("Warning: TVM not installed. This script requires TVM.")

from common.benchmark.timer import TimingResult, benchmark_function

if TVM_AVAILABLE:

    def create_matmul_module(m: int, n: int, k: int) -> "tvm.IRModule":
        """Build an unscheduled TensorIR module computing C = A @ B."""

        @tvm.script.ir_module
        class MatMul:
            @tir_script.prim_func
            def main(
                A: tir_script.Buffer((m, k), "float32"),
                B: tir_script.Buffer((k, n), "float32"),
                C: tir_script.Buffer((m, n), "float32"),
            ):
                tir_script.func_attr({"global_symbol": "main", "tir.noalias": True})
                for i, j, k_idx in tir_script.grid(m, n, k):
                    with tir_script.block("matmul"):
                        # "SSR": vi and vj are spatial axes, vk is a reduction axis.
                        vi, vj, vk = tir_script.axis.remap("SSR", [i, j, k_idx])
                        with tir_script.init():
                            C[vi, vj] = tir_script.float32(0)
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        return MatMul
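
# For a quick look at the unscheduled TensorIR, the module can be printed, e.g.
#   create_matmul_module(4, 4, 4).show()
# (assuming IRModule.show is available; mod.script() also returns the source).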


def _require_positive_multiple(name: str, value: int, multiple: int) -> None:
    """Raise ValueError unless ``value`` is a positive multiple of ``multiple``."""
    if value <= 0 or value % multiple != 0:
        raise ValueError(f"{name} must be a positive multiple of {multiple}, got {value}")


def apply_schedule_primitives(
    mod: "tvm.IRModule",
    block_size: int = 32,
    k_tile: int = 8,
) -> "tvm.tir.Schedule":
    """Apply CPU schedule primitives (split, reorder, vectorize) to MatMul."""
    # A multiple of 1 is any positive integer, so these are positivity checks.
    _require_positive_multiple("block_size", block_size, 1)
    _require_positive_multiple("k_tile", k_tile, 1)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    # Tile the spatial loops and the reduction loop; `None` lets TVM infer the
    # outer extent from the loop bound and the inner factor.
    i_outer, i_inner = sch.split(i, factors=[None, block_size])
    j_outer, j_inner = sch.split(j, factors=[None, block_size])
    k_outer, k_inner = sch.split(k, factors=[None, k_tile])

    # Iterate over tiles first, then within each tile, to improve cache reuse.
    # Keep the spatial j_inner innermost so it can be vectorized; vectorizing
    # the reduction axis k_inner is not allowed by the TIR schedule.
    sch.reorder(i_outer, j_outer, k_outer, i_inner, k_inner, j_inner)
    sch.vectorize(j_inner)

    print("Applied schedule primitives:")
    print(f"  - split(i, [None, {block_size}])")
    print(f"  - split(j, [None, {block_size}])")
    print(f"  - split(k, [None, {k_tile}])")
    print("  - reorder(i_outer, j_outer, k_outer, i_inner, k_inner, j_inner)")
    print("  - vectorize(j_inner)")

    return sch
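
# Roughly, for size 1024 with block_size=32 and k_tile=8, the scheduled nest is:
#   for i_outer in range(32):            # 1024 / 32
#     for j_outer in range(32):
#       for k_outer in range(128):       # 1024 / 8
#         for i_inner in range(32):
#           for k_inner in range(8):
#             for j_inner in range(32):  # vectorized
#               C[...] += A[...] * B[...]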


def apply_gpu_schedule(
    mod: "tvm.IRModule", block_m: int = 128, block_n: int = 128, thread_tile: int = 32
) -> "tvm.tir.Schedule":
    """Apply a GPU-oriented schedule (tiling plus thread binding) for MatMul."""
    _require_positive_multiple("block_m", block_m, thread_tile)
    _require_positive_multiple("block_n", block_n, thread_tile)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    # Two-level tiling: a (block_m, block_n) tile per thread block, and a
    # (thread_tile, thread_tile) sub-tile computed serially by each thread.
    i_outer, i_inner = sch.split(i, factors=[None, block_m])
    j_outer, j_inner = sch.split(j, factors=[None, block_n])
    i_inner_outer, i_inner_inner = sch.split(i_inner, factors=[None, thread_tile])
    j_inner_outer, j_inner_inner = sch.split(j_inner, factors=[None, thread_tile])

    sch.reorder(i_outer, j_outer, i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner, k)
    # Map the tile loops onto the CUDA grid and the per-tile loops onto threads.
    sch.bind(i_outer, "blockIdx.x")
    sch.bind(j_outer, "blockIdx.y")
    sch.bind(i_inner_outer, "threadIdx.y")
    sch.bind(j_inner_outer, "threadIdx.x")

    print("Applied GPU schedule:")
    print(f"  - Block size: ({block_m}, {block_n})")
    print(f"  - Thread tile: ({thread_tile}, {thread_tile})")
    print("  - Bound to blockIdx.x/y, threadIdx.x/y")

    return sch
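
# With the defaults on a 1024 x 1024 problem this yields an 8 x 8 grid of
# thread blocks (1024 / 128 per dimension), each with 4 x 4 threads
# (128 / 32), and every thread serially accumulating a 32 x 32 output tile.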


def benchmark_schedule(
    sch: "tvm.tir.Schedule",
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> TimingResult:
    """Benchmark a TensorIR schedule."""
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    # Use the target kind ("llvm", "cuda", ...) to pick the device; str(target)
    # includes target attributes and is not a valid device string.
    dev = tvm.device(target_obj.kind.name, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)
    func = lib["main"]

    def run():
        func(a_tvm, b_tvm, c_tvm)

    return benchmark_function(
        run,
        warmup_iters=warmup_iters,
        bench_iters=bench_iters,
        sync_cuda=(target_obj.kind.name == "cuda"),
    )
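
# Alternative timing sketch: TVM's built-in evaluator measures on-device
# latency and handles CUDA synchronization itself, e.g.
#   lib.time_evaluator("main", dev, number=20)(a_tvm, b_tvm, c_tvm).mean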


def verify_correctness(
    sch: "tvm.tir.Schedule",
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    rtol: float = 1e-5,
    atol: float = 1e-5,
) -> bool:
    """Verify the scheduled computation against a numpy reference."""
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(target_obj.kind.name, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)

    func = lib["main"]
    func(a_tvm, b_tvm, c_tvm)

    c_ref = np.matmul(a_np, b_np)
    # A small atol absorbs float32 rounding differences caused by the
    # reordered accumulation in the scheduled kernel.
    np.testing.assert_allclose(c_tvm.numpy(), c_ref, rtol=rtol, atol=atol)
    return True


def _validate_problem_size(size: int, target: str) -> None:
    # 8 matches the default k_tile; CUDA additionally asks for a multiple of
    # the default thread_tile (32) so the tiles stay aligned.
    _require_positive_multiple("size", size, 8)
    if target == "cuda":
        _require_positive_multiple("size", size, 32)
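
# Example: _validate_problem_size(1024, "cuda") passes, while
# _validate_problem_size(100, "llvm") raises (100 is not a multiple of 8).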


def main() -> None:
    parser = argparse.ArgumentParser(description="TensorIR Manual Schedule")
    parser.add_argument("--target", type=str, default="llvm", help="Target (llvm or cuda)")
    parser.add_argument("--size", type=int, default=1024, help="Matrix size (M=N=K)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=20, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return

    try:
        _validate_problem_size(args.size, args.target)
    except ValueError as exc:
        print(f"Invalid configuration: {exc}")
        return

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule")
    print("=" * 60)

    m = n = k = args.size
    print(f"\nMatrix size: {m} x {n} x {k}")
    print(f"Target: {args.target}")

    # An unscheduled TensorIR module has no thread bindings, so it cannot be
    # built for CUDA: benchmark the baseline on CPU targets only.
    baseline_result = None
    print("\n" + "-" * 40)
    print("Baseline (no optimization):")
    if args.target == "cuda":
        print("  Skipped: the unscheduled baseline cannot be built for CUDA")
    else:
        baseline_sch = tir.Schedule(create_matmul_module(m, n, k))
        baseline_result = benchmark_schedule(
            baseline_sch, args.target, m, n, k, args.warmup, args.bench
        )
        print(f"  Mean latency: {baseline_result.mean_ms:.3f} ms")

    print("\n" + "-" * 40)
    print("Applying schedule primitives...")
    if args.target == "cuda":
        optimized_sch = apply_gpu_schedule(create_matmul_module(m, n, k))
    else:
        optimized_sch = apply_schedule_primitives(create_matmul_module(m, n, k))

    print("\n" + "-" * 40)
    print("Verifying correctness...")
    try:
        verify_correctness(optimized_sch, args.target, m, n, k)
        print("  ✓ Results match numpy reference")
    except AssertionError as exc:
        print(f"  ✗ Verification failed: {exc}")
        return

    print("\n" + "-" * 40)
    print("Optimized schedule:")
    optimized_result = benchmark_schedule(
        optimized_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {optimized_result.mean_ms:.3f} ms")

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  Optimized: {optimized_result.mean_ms:.3f} ms")
    if baseline_result is not None:
        print(f"  Baseline:  {baseline_result.mean_ms:.3f} ms")
        print(f"  Speedup:   {baseline_result.mean_ms / optimized_result.mean_ms:.2f}x")
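
    # Throughput note: a square matmul performs 2*m*n*k floating-point ops, so
    # GFLOP/s ~= 2*m*n*k / (mean_ms * 1e6); e.g. 1024^3 at 10 ms is ~215 GFLOP/s.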

    print("\n" + "-" * 40)
    print("Optimized TensorIR:")
    print(optimized_sch.mod.script())


if __name__ == "__main__":
    main()