# 01_TVM_End2End_Optimization/2_auto_scheduler_tuning.py
1 """ 2 TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning 3 4 This script demonstrates: 5 1. Using Ansor (Auto-Scheduler) for automatic optimization 6 2. Extracting tuning tasks from Relay IR 7 3. Running auto-tuning with configurable trials 8 4. Applying tuned schedules for optimized inference 9 10 Requirements: 2.2 11 """ 12 13 import argparse 14 import os 15 import sys 16 17 import numpy as np 18 import torch 19 import torchvision.models as models 20 21 # TVM imports 22 try: 23 import tvm 24 from tvm import auto_scheduler, relay 25 from tvm.contrib import graph_executor 26 27 TVM_AVAILABLE = True 28 except ImportError: 29 TVM_AVAILABLE = False 30 print("Warning: TVM not installed. This script requires TVM.") 31 32 sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) 33 from common.benchmark.timer import TimingResult, benchmark_function 34 35 36 def load_model_and_convert( 37 model_name: str = "resnet50", input_shape: tuple[int, ...] = (1, 3, 224, 224) 38 ) -> tuple["tvm.IRModule", dict]: 39 """Load PyTorch model and convert to Relay IR.""" 40 model_map = { 41 "resnet50": models.resnet50, 42 "resnet18": models.resnet18, 43 "mobilenet_v2": models.mobilenet_v2, 44 } 45 46 model = model_map[model_name](weights="DEFAULT") 47 model.eval() 48 49 example_input = torch.randn(*input_shape) 50 with torch.no_grad(): 51 scripted_model = torch.jit.trace(model, example_input) 52 53 input_infos = [("input0", input_shape)] 54 mod, params = relay.frontend.from_pytorch(scripted_model, input_infos) 55 56 return mod, params 57 58 59 def run_auto_scheduler( 60 mod: "tvm.IRModule", 61 params: dict, 62 target: str = "cuda", 63 num_trials: int = 1000, 64 log_file: str = "tuning_logs/resnet50.json", 65 early_stopping: int = 600, 66 ) -> "tvm.IRModule": 67 """ 68 Run Ansor auto-scheduler for automatic optimization. 69 70 Args: 71 mod: TVM Relay IRModule 72 params: Model parameters 73 target: Compilation target 74 num_trials: Number of tuning trials 75 log_file: Path to save tuning log 76 early_stopping: Stop if no improvement after this many trials 77 78 Returns: 79 Optimized IRModule 80 """ 81 if not TVM_AVAILABLE: 82 raise RuntimeError("TVM is required") 83 84 target = tvm.target.Target(target) 85 86 # Ensure log directory exists 87 os.makedirs(os.path.dirname(log_file), exist_ok=True) 88 89 print("Extracting tasks from Relay IR...") 90 tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target) 91 92 print(f"Found {len(tasks)} tuning tasks:") 93 for i, task in enumerate(tasks): 94 print(f" Task {i}: {task.desc}") 95 96 # Configure tuning options 97 tuning_options = auto_scheduler.TuningOptions( 98 num_measure_trials=num_trials, 99 runner=auto_scheduler.LocalRunner( 100 timeout=10, number=3, repeat=1, min_repeat_ms=100, enable_cpu_cache_flush=False 101 ), 102 measure_callbacks=[auto_scheduler.RecordToFile(log_file)], 103 early_stopping=early_stopping, 104 ) 105 106 print(f"\nStarting auto-tuning with {num_trials} trials...") 107 print(f"Log file: {log_file}") 108 109 # Run tuning 110 tuner = auto_scheduler.TaskScheduler(tasks, task_weights) 111 tuner.tune(tuning_options) 112 113 print(f"\nTuning complete. Results saved to {log_file}") 114 115 return mod 116 117 118 def compile_with_tuned_schedule( 119 mod: "tvm.IRModule", 120 params: dict, 121 target: str = "cuda", 122 log_file: str = "tuning_logs/resnet50.json", 123 ) -> "tvm.runtime.Module": 124 """ 125 Compile model with tuned schedule. 

def compile_with_tuned_schedule(
    mod: "tvm.IRModule",
    params: dict,
    target: str = "cuda",
    log_file: str = "tuning_logs/resnet50.json",
) -> "tvm.runtime.Module":
    """
    Compile the model with the tuned schedule.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        target: Compilation target
        log_file: Path to the tuning log

    Returns:
        Compiled runtime module
    """
    target = tvm.target.Target(target)

    # Apply the best schedules recorded in the tuning log during compilation
    with (
        auto_scheduler.ApplyHistoryBest(log_file),
        tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}),
    ):
        lib = relay.build(mod, target=target, params=params)

    return lib


def benchmark_tuned_model(
    lib: "tvm.runtime.Module",
    input_shape: tuple[int, ...],
    target: str = "cuda",
    warmup_iters: int = 10,
    bench_iters: int = 100,
) -> TimingResult:
    """Benchmark the tuned model."""
    dev = tvm.device(target, 0)
    module = graph_executor.GraphModule(lib["default"](dev))

    input_data = np.random.randn(*input_shape).astype("float32")
    module.set_input("input0", input_data)

    def run_inference():
        module.run()

    result = benchmark_function(
        run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True
    )

    return result


def main() -> None:
    parser = argparse.ArgumentParser(description="TVM Auto-Scheduler Tuning")
    parser.add_argument("--model", type=str, default="resnet50", help="Model name")
    parser.add_argument("--num_trials", type=int, default=1000, help="Number of tuning trials")
    parser.add_argument(
        "--log_file", type=str, default="tuning_logs/resnet50.json", help="Tuning log file"
    )
    parser.add_argument("--target", type=str, default="cuda", help="Target device")
    parser.add_argument(
        "--tune", action="store_true", help="Force tuning even if a log file already exists"
    )
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=100, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning")
    print("=" * 60)

    input_shape = (1, 3, 224, 224)

    # Load and convert the model
    print(f"\nLoading {args.model}...")
    mod, params = load_model_and_convert(args.model, input_shape)

    # Run tuning if forced via --tune or if no log exists yet
    log_path = os.path.join(os.path.dirname(__file__), args.log_file)

    if args.tune or not os.path.exists(log_path):
        print("\n" + "-" * 40)
        run_auto_scheduler(
            mod, params, target=args.target, num_trials=args.num_trials, log_file=log_path
        )
    else:
        print(f"\nUsing existing tuning log: {log_path}")

    # Compile with the tuned schedule
    print("\n" + "-" * 40)
    print("Compiling with tuned schedule...")
    lib = compile_with_tuned_schedule(mod, params, args.target, log_path)

    # Benchmark
    print("\n" + "-" * 40)
    result = benchmark_tuned_model(
        lib, input_shape, args.target, warmup_iters=args.warmup, bench_iters=args.bench
    )

    print("Auto-tuned inference:")
    print(f"  Mean latency: {result.mean_ms:.3f} ms")
    print(f"  Std: {result.std_ms:.3f} ms")
    print(f"  Min: {result.min_ms:.3f} ms")
    print(f"  Max: {result.max_ms:.3f} ms")


if __name__ == "__main__":
    main()
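
# Example invocations (a sketch based on the argparse flags above; paths are
# resolved relative to this script's directory):
#
#   # Tune from scratch (or force a re-tune) with 2000 trials, then benchmark:
#   python 2_auto_scheduler_tuning.py --model resnet50 --num_trials 2000 --tune
#
#   # Compile and benchmark only, assuming tuning_logs/resnet18.json already exists:
#   python 2_auto_scheduler_tuning.py --model resnet18 --log_file tuning_logs/resnet18.json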