tensorir_schedule.py
  1  """
  2  TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule
  3  
  4  This script demonstrates:
  5  1. Writing TensorIR programs using TVMScript
  6  2. Applying schedule primitives: split, reorder, bind, vectorize
  7  3. Comparing manual schedule vs baseline performance
  8  
  9  Requirements: 2.3, 2.4
 10  """

from __future__ import annotations

import argparse
import sys
from typing import TYPE_CHECKING

import numpy as np

from common.utils.path_utils import ensure_project_root_in_path

# Ensure the project root is importable before pulling in the other
# project-local modules below.
ensure_project_root_in_path()

from common.benchmark.timer import TimingResult, benchmark_function

if TYPE_CHECKING:
    import tvm

# TVM imports are optional at module load time; TVM_AVAILABLE gates all usage.
try:
    import tvm
    from tvm import tir
    from tvm.script import tir as T

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False

if TVM_AVAILABLE:

    def create_matmul_module(m: int, n: int, k: int) -> tvm.IRModule:
        """Build an IRModule with a single TVMScript matmul prim_func (C = A @ B)."""

        @tvm.script.ir_module
        class MatMul:
            @T.prim_func
            def main(
                A: T.Buffer((m, k), "float32"),
                B: T.Buffer((k, n), "float32"),
                C: T.Buffer((m, n), "float32"),
            ):
                T.func_attr({"global_symbol": "main", "tir.noalias": True})
                for i, j, k_idx in T.grid(m, n, k):
                    with T.block("matmul"):
                        # vi/vj are spatial axes ("S"), vk is the reduction axis ("R")
                        vi, vj, vk = T.axis.remap("SSR", [i, j, k_idx])
                        with T.init():
                            C[vi, vj] = T.float32(0)
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

        return MatMul

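# Example: inspect the TensorIR generated by create_matmul_module (assumes a
# working TVM install); IRModule.script() pretty-prints the TVMScript source.
#     mod = create_matmul_module(128, 128, 128)
#     print(mod.script())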

def _require_positive_multiple(name: str, value: int, multiple: int) -> None:
    if value <= 0 or value % multiple != 0:
        raise ValueError(f"{name} must be a positive multiple of {multiple}, got {value}")


def apply_schedule_primitives(
    mod: tvm.IRModule,
    block_size: int = 32,
    k_tile: int = 8,
) -> tvm.tir.Schedule:
    """Apply schedule primitives (split, reorder, vectorize) to optimize MatMul."""
    # multiple=1 only enforces positivity here
    _require_positive_multiple("block_size", block_size, 1)
    _require_positive_multiple("k_tile", k_tile, 1)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    i_outer, i_inner = sch.split(i, factors=[None, block_size])
    j_outer, j_inner = sch.split(j, factors=[None, block_size])
    k_outer, k_inner = sch.split(k, factors=[None, k_tile])

    # Keep the spatial loop j_inner innermost: vectorize is only legal on
    # data-parallel loops, not on the reduction axis k.
    sch.reorder(i_outer, j_outer, k_outer, i_inner, k_inner, j_inner)
    sch.vectorize(j_inner)

    print("Applied schedule primitives:")
    print(f"  - split(i, [None, {block_size}])")
    print(f"  - split(j, [None, {block_size}])")
    print(f"  - split(k, [None, {k_tile}])")
    print("  - reorder(i_outer, j_outer, k_outer, i_inner, k_inner, j_inner)")
    print("  - vectorize(j_inner)")

    return sch

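# Sketch of the loop nest produced by apply_schedule_primitives:
#     for i_outer, j_outer, k_outer:        # tile loops
#         for i_inner, k_inner:             # serial loops within a tile
#             for j_inner (vectorized):     # SIMD across one row of the C tile
#                 C[vi, vj] += A[vi, vk] * B[vk, vj]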

def apply_gpu_schedule(
    mod: tvm.IRModule, block_m: int = 128, block_n: int = 128, thread_tile: int = 32
) -> tvm.tir.Schedule:
    """Apply a GPU-optimized schedule for MatMul."""
    _require_positive_multiple("thread_tile", thread_tile, 1)
    _require_positive_multiple("block_m", block_m, thread_tile)
    _require_positive_multiple("block_n", block_n, thread_tile)

    sch = tir.Schedule(mod)
    block = sch.get_block("matmul")
    i, j, k = sch.get_loops(block)

    i_outer, i_inner = sch.split(i, factors=[None, block_m])
    j_outer, j_inner = sch.split(j, factors=[None, block_n])
    i_inner_outer, i_inner_inner = sch.split(i_inner, factors=[None, thread_tile])
    j_inner_outer, j_inner_inner = sch.split(j_inner, factors=[None, thread_tile])

    sch.reorder(i_outer, j_outer, i_inner_outer, j_inner_outer, i_inner_inner, j_inner_inner, k)
    sch.bind(i_outer, "blockIdx.x")
    sch.bind(j_outer, "blockIdx.y")
    sch.bind(i_inner_outer, "threadIdx.y")
    sch.bind(j_inner_outer, "threadIdx.x")

    print("Applied GPU schedule:")
    print(f"  - Block size: ({block_m}, {block_n})")
    print(f"  - Thread tile: ({thread_tile}, {thread_tile})")
    print("  - Bound to blockIdx.x/y, threadIdx.x/y")

    return sch

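# Thread-hierarchy sketch for apply_gpu_schedule: each CUDA block computes one
# block_m x block_n output tile; inside it, a (block_m/thread_tile) x
# (block_n/thread_tile) grid of threads each computes a thread_tile x
# thread_tile sub-tile, walking the full reduction axis k serially.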

def benchmark_schedule(
    sch: tvm.tir.Schedule,
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    warmup_iters: int = 5,
    bench_iters: int = 20,
) -> TimingResult:
    """Benchmark a TensorIR schedule."""
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    # Use the raw target string for device lookup: str(Target) may carry extra
    # attributes (e.g. "-keys=...") beyond the bare device name.
    dev = tvm.device(target, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)
    func = lib["main"]

    def run():
        func(a_tvm, b_tvm, c_tvm)

    return benchmark_function(
        run,
        warmup_iters=warmup_iters,
        bench_iters=bench_iters,
        sync_cuda=(target_obj.kind.name == "cuda"),
    )

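# Alternative timing sketch: TVM's built-in evaluator measures on-device
# latency and handles synchronization itself (names reuse the locals built
# inside benchmark_schedule):
#     evaluator = lib.time_evaluator("main", dev, number=bench_iters)
#     mean_ms = evaluator(a_tvm, b_tvm, c_tvm).mean * 1000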

def verify_correctness(
    sch: tvm.tir.Schedule,
    target: str = "llvm",
    m: int = 1024,
    n: int = 1024,
    k: int = 1024,
    rtol: float = 1e-5,
    atol: float = 1e-4,
) -> bool:
    """Verify the scheduled computation against a numpy reference (raises on mismatch)."""
    mod = sch.mod
    target_obj = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=3):
        lib = tvm.build(mod, target=target_obj)

    dev = tvm.device(target, 0)
    a_np = np.random.randn(m, k).astype("float32")
    b_np = np.random.randn(k, n).astype("float32")
    c_np = np.zeros((m, n), dtype="float32")

    a_tvm = tvm.nd.array(a_np, dev)
    b_tvm = tvm.nd.array(b_np, dev)
    c_tvm = tvm.nd.array(c_np, dev)

    func = lib["main"]
    func(a_tvm, b_tvm, c_tvm)

    # With signed (randn) inputs, near-zero outputs arise from cancellation, so
    # a pure relative tolerance is too strict; allow a small absolute slack.
    c_ref = np.matmul(a_np, b_np)
    np.testing.assert_allclose(c_tvm.numpy(), c_ref, rtol=rtol, atol=atol)
    return True


def _validate_problem_size(size: int, target: str) -> None:
    # Multiples of 8 match the default k_tile; CUDA sizes should additionally
    # be warp-friendly multiples of 32.
    _require_positive_multiple("size", size, 8)
    if target == "cuda":
        _require_positive_multiple("size", size, 32)

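# Note: tir.Schedule.split tolerates non-divisible extents by predicating the
# tail iterations, so these size checks guard performance (full tiles), not
# compilability.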

def main() -> int:
    parser = argparse.ArgumentParser(description="TensorIR Manual Schedule")
    parser.add_argument("--target", type=str, default="llvm", help="Target (llvm or cuda)")
    parser.add_argument("--size", type=int, default=1024, help="Matrix size (M=N=K)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=20, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return 1

    try:
        _validate_problem_size(args.size, args.target)
    except ValueError as exc:
        print(f"Invalid configuration: {exc}")
        return 1

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 3: TensorIR Manual Schedule")
    print("=" * 60)

    m = n = k = args.size
    print(f"\nMatrix size: {m} x {n} x {k}")
    print(f"Target: {args.target}")

    baseline_result = None
    if args.target == "cuda":
        # The unscheduled module has no thread bindings, so it cannot be built
        # for GPU targets; skip the baseline measurement on CUDA.
        print("\n" + "-" * 40)
        print("Baseline skipped: unscheduled TensorIR only builds for CPU targets.")
    else:
        print("\n" + "-" * 40)
        print("Baseline (no optimization):")
        baseline_sch = tir.Schedule(create_matmul_module(m, n, k))
        baseline_result = benchmark_schedule(
            baseline_sch, args.target, m, n, k, args.warmup, args.bench
        )
        print(f"  Mean latency: {baseline_result.mean_ms:.3f} ms")

    print("\n" + "-" * 40)
    print("Applying schedule primitives...")
    if args.target == "cuda":
        optimized_sch = apply_gpu_schedule(create_matmul_module(m, n, k))
    else:
        optimized_sch = apply_schedule_primitives(create_matmul_module(m, n, k))

    print("\n" + "-" * 40)
    print("Verifying correctness...")
    try:
        verify_correctness(optimized_sch, args.target, m, n, k)
        print("  ✓ Results match numpy reference")
    except AssertionError as exc:
        print(f"  ✗ Verification failed: {exc}")
        return 1

    print("\n" + "-" * 40)
    print("Optimized schedule:")
    optimized_result = benchmark_schedule(
        optimized_sch, args.target, m, n, k, args.warmup, args.bench
    )
    print(f"  Mean latency: {optimized_result.mean_ms:.3f} ms")

    print("\n" + "=" * 60)
    print("Summary:")
    print(f"  Optimized: {optimized_result.mean_ms:.3f} ms")
    if baseline_result is not None:
        print(f"  Baseline:  {baseline_result.mean_ms:.3f} ms")
        print(f"  Speedup:   {baseline_result.mean_ms / optimized_result.mean_ms:.2f}x")

    print("\n" + "-" * 40)
    print("Optimized TensorIR:")
    print(optimized_sch.mod.script())

    return 0


if __name__ == "__main__":
    sys.exit(main())