auto_scheduler.py
1 """ 2 TVM Auto-Scheduler Tuning module. 3 4 This module demonstrates: 5 1. Using Ansor (Auto-Scheduler) for automatic optimization 6 2. Extracting tuning tasks from Relay IR 7 3. Running auto-tuning with configurable trials 8 4. Applying tuned schedules for optimized inference 9 10 Requirements: 2.2 11 """ 12 13 from __future__ import annotations 14 15 import argparse 16 from typing import TYPE_CHECKING 17 18 import numpy as np 19 import torch 20 import torchvision.models as models 21 22 from common.benchmark.timer import TimingResult, benchmark_function 23 from common.utils.path_utils import ensure_project_root_in_path 24 25 # Ensure project root is in path for imports 26 ensure_project_root_in_path() 27 28 if TYPE_CHECKING: 29 import tvm 30 31 # TVM imports 32 try: 33 import tvm 34 from tvm import auto_scheduler, relay 35 from tvm.contrib import graph_executor 36 37 TVM_AVAILABLE = True 38 except ImportError: 39 TVM_AVAILABLE = False 40 41 42 def load_model_and_convert( 43 model_name: str = "resnet50", input_shape: tuple[int, ...] = (1, 3, 224, 224) 44 ) -> tuple[tvm.IRModule, dict]: 45 """Load PyTorch model and convert to Relay IR.""" 46 model_map = { 47 "resnet50": models.resnet50, 48 "resnet18": models.resnet18, 49 "mobilenet_v2": models.mobilenet_v2, 50 } 51 52 model = model_map[model_name](weights="DEFAULT") 53 model.eval() 54 55 example_input = torch.randn(*input_shape) 56 with torch.no_grad(): 57 scripted_model = torch.jit.trace(model, example_input) 58 59 input_infos = [("input0", input_shape)] 60 mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) 61 62 return mod, params 63 64 65 def run_auto_scheduler( 66 mod: tvm.IRModule, 67 params: dict, 68 target: str = "cuda", 69 num_trials: int = 1000, 70 log_file: str = "tuning_logs/resnet50.json", 71 early_stopping: int = 600, 72 ) -> tvm.IRModule: 73 """ 74 Run Ansor auto-scheduler for automatic optimization. 75 76 Args: 77 mod: TVM Relay IRModule 78 params: Model parameters 79 target: Compilation target 80 num_trials: Number of tuning trials 81 log_file: Path to save tuning log 82 early_stopping: Stop if no improvement after this many trials 83 84 Returns: 85 Optimized IRModule 86 """ 87 if not TVM_AVAILABLE: 88 raise RuntimeError("TVM is required") 89 90 target = tvm.target.Target(target) 91 92 # Ensure log directory exists 93 os.makedirs(os.path.dirname(log_file), exist_ok=True) 94 95 print("Extracting tasks from Relay IR...") 96 tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target) 97 98 print(f"Found {len(tasks)} tuning tasks:") 99 for i, task in enumerate(tasks): 100 print(f" Task {i}: {task.desc}") 101 102 # Configure tuning options 103 tuning_options = auto_scheduler.TuningOptions( 104 num_measure_trials=num_trials, 105 runner=auto_scheduler.LocalRunner( 106 timeout=10, number=3, repeat=1, min_repeat_ms=100, enable_cpu_cache_flush=False 107 ), 108 measure_callbacks=[auto_scheduler.RecordToFile(log_file)], 109 early_stopping=early_stopping, 110 ) 111 112 print(f"\nStarting auto-tuning with {num_trials} trials...") 113 print(f"Log file: {log_file}") 114 115 # Run tuning 116 tuner = auto_scheduler.TaskScheduler(tasks, task_weights) 117 tuner.tune(tuning_options) 118 119 print(f"\nTuning complete. Results saved to {log_file}") 120 121 return mod 122 123 124 def compile_with_tuned_schedule( 125 mod: tvm.IRModule, 126 params: dict, 127 target: str = "cuda", 128 log_file: str = "tuning_logs/resnet50.json", 129 ) -> tvm.runtime.Module: 130 """ 131 Compile model with tuned schedule. 132 133 Args: 134 mod: TVM Relay IRModule 135 params: Model parameters 136 target: Compilation target 137 log_file: Path to tuning log 138 139 Returns: 140 Compiled runtime module 141 """ 142 target = tvm.target.Target(target) 143 144 # Apply tuned schedule 145 with ( 146 auto_scheduler.ApplyHistoryBest(log_file), 147 tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}), 148 ): 149 lib = relay.build(mod, target=target, params=params) 150 151 return lib 152 153 154 def benchmark_tuned_model( 155 lib: tvm.runtime.Module, 156 input_shape: tuple[int, ...], 157 target: str = "cuda", 158 warmup_iters: int = 10, 159 bench_iters: int = 100, 160 ) -> TimingResult: 161 """Benchmark the tuned model.""" 162 dev = tvm.device(target, 0) 163 module = graph_executor.GraphModule(lib["default"](dev)) 164 165 input_data = np.random.randn(*input_shape).astype("float32") 166 module.set_input("input0", input_data) 167 168 def run_inference(): 169 module.run() 170 171 result = benchmark_function( 172 run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True 173 ) 174 175 return result 176 177 178 def main() -> None: 179 parser = argparse.ArgumentParser(description="TVM Auto-Scheduler Tuning") 180 parser.add_argument("--model", type=str, default="resnet50", help="Model name") 181 parser.add_argument("--num_trials", type=int, default=1000, help="Number of tuning trials") 182 parser.add_argument( 183 "--log_file", type=str, default="tuning_logs/resnet50.json", help="Tuning log file" 184 ) 185 parser.add_argument("--target", type=str, default="cuda", help="Target device") 186 parser.add_argument("--tune", action="store_true", help="Run tuning (skip if log exists)") 187 parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations") 188 parser.add_argument("--bench", type=int, default=100, help="Benchmark iterations") 189 args = parser.parse_args() 190 191 if not TVM_AVAILABLE: 192 print("TVM is required for this script. Please install TVM.") 193 return 194 195 print("=" * 60) 196 print("TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning") 197 print("=" * 60) 198 199 input_shape = (1, 3, 224, 224) 200 201 # Load and convert model 202 print(f"\nLoading {args.model}...") 203 mod, params = load_model_and_convert(args.model, input_shape) 204 205 # Run tuning if requested or log doesn't exist 206 log_path = os.path.join(os.path.dirname(__file__), args.log_file) 207 208 if args.tune or not os.path.exists(log_path): 209 print("\n" + "-" * 40) 210 run_auto_scheduler( 211 mod, params, target=args.target, num_trials=args.num_trials, log_file=log_path 212 ) 213 else: 214 print(f"\nUsing existing tuning log: {log_path}") 215 216 # Compile with tuned schedule 217 print("\n" + "-" * 40) 218 print("Compiling with tuned schedule...") 219 lib = compile_with_tuned_schedule(mod, params, args.target, log_path) 220 221 # Benchmark 222 print("\n" + "-" * 40) 223 result = benchmark_tuned_model( 224 lib, input_shape, args.target, warmup_iters=args.warmup, bench_iters=args.bench 225 ) 226 227 print("Auto-tuned inference:") 228 print(f" Mean latency: {result.mean_ms:.3f} ms") 229 print(f" Std: {result.std_ms:.3f} ms") 230 print(f" Min: {result.min_ms:.3f} ms") 231 print(f" Max: {result.max_ms:.3f} ms") 232 233 return 0 234 235 236 if __name__ == "__main__": 237 sys.exit(main())