01_TVM_End2End_Optimization/2_auto_scheduler_tuning.py
"""
TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning

This script demonstrates:
1. Using Ansor (Auto-Scheduler) for automatic optimization
2. Extracting tuning tasks from Relay IR
3. Running auto-tuning with configurable trials
4. Applying tuned schedules for optimized inference

Requirements: 2.2
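
Example usage:
    python 2_auto_scheduler_tuning.py --model resnet50 --num_trials 1000 --tune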
"""

import argparse
import os
import sys

import numpy as np
import torch
import torchvision.models as models

# TVM imports
try:
    import tvm
    from tvm import auto_scheduler, relay
    from tvm.contrib import graph_executor

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False
    print("Warning: TVM not installed. This script requires TVM.")
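
# Prepend the parent of this script's directory to sys.path so the shared
# common.benchmark helpers resolve when the script is run directly.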
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from common.benchmark.timer import TimingResult, benchmark_function


def load_model_and_convert(
    model_name: str = "resnet50", input_shape: tuple[int, ...] = (1, 3, 224, 224)
) -> tuple["tvm.IRModule", dict]:
    """Load PyTorch model and convert to Relay IR."""
    model_map = {
        "resnet50": models.resnet50,
        "resnet18": models.resnet18,
        "mobilenet_v2": models.mobilenet_v2,
    }

    model = model_map[model_name](weights="DEFAULT")
    model.eval()
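
    # Trace the model with a fixed example input; relay.frontend.from_pytorch
    # consumes the resulting TorchScript graph plus the input name and shape.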
    example_input = torch.randn(*input_shape)
    with torch.no_grad():
        scripted_model = torch.jit.trace(model, example_input)

    input_infos = [("input0", input_shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)

    return mod, params


def run_auto_scheduler(
    mod: "tvm.IRModule",
    params: dict,
    target: str = "cuda",
    num_trials: int = 1000,
    log_file: str = "tuning_logs/resnet50.json",
    early_stopping: int = 600,
) -> "tvm.IRModule":
    """
    Run Ansor auto-scheduler for automatic optimization.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        target: Compilation target
        num_trials: Total number of measurement trials shared across all tasks
        log_file: Path to save tuning log
        early_stopping: Stop if no improvement after this many trials

    Returns:
        The input IRModule; tuning results are written to log_file
    """
    if not TVM_AVAILABLE:
        raise RuntimeError("TVM is required")

    target = tvm.target.Target(target)

    # Ensure log directory exists
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    print("Extracting tasks from Relay IR...")
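    # extract_tasks splits the Relay graph into independently tunable subgraphs
    # (one search task per distinct fused workload); task_weights records how
    # often each subgraph appears, so the scheduler can budget trials accordingly.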
    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target)

    print(f"Found {len(tasks)} tuning tasks:")
    for i, task in enumerate(tasks):
        print(f"  Task {i}: {task.desc}")

    # Configure tuning options
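    # num_measure_trials is the total measurement budget shared across all tasks.
    # Each candidate schedule is executed `number` times per measurement, and
    # min_repeat_ms pads very short kernels so timings stay stable. For CUDA
    # targets the TVM tutorials typically measure through
    # auto_scheduler.LocalRPCMeasureContext; LocalRunner is kept here for simplicity.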
    tuning_options = auto_scheduler.TuningOptions(
        num_measure_trials=num_trials,
        runner=auto_scheduler.LocalRunner(
            timeout=10, number=3, repeat=1, min_repeat_ms=100, enable_cpu_cache_flush=False
        ),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        early_stopping=early_stopping,
    )

    print(f"\nStarting auto-tuning with {num_trials} trials...")
    print(f"Log file: {log_file}")

    # Run tuning
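    # The TaskScheduler interleaves tuning of all tasks, steering the shared
    # trial budget toward the tasks expected to improve end-to-end latency most.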
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tuning_options)

    print(f"\nTuning complete. Results saved to {log_file}")

    return mod


def compile_with_tuned_schedule(
    mod: "tvm.IRModule",
    params: dict,
    target: str = "cuda",
    log_file: str = "tuning_logs/resnet50.json",
) -> "tvm.runtime.Module":
    """
    Compile model with tuned schedule.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        target: Compilation target
        log_file: Path to tuning log

    Returns:
        Compiled runtime module
    """
    target = tvm.target.Target(target)
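
    # ApplyHistoryBest picks the best-measured schedule for each task from the
    # log; the PassContext flag tells the Relay backend to use those records
    # instead of its default schedules.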
    with (
        auto_scheduler.ApplyHistoryBest(log_file),
        tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}),
    ):
        lib = relay.build(mod, target=target, params=params)

    return lib


def benchmark_tuned_model(
    lib: "tvm.runtime.Module",
    input_shape: tuple[int, ...],
    target: str = "cuda",
    warmup_iters: int = 10,
    bench_iters: int = 100,
) -> TimingResult:
    """Benchmark the tuned model."""
    dev = tvm.device(target, 0)
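    # lib["default"] is the factory for the compiled graph; calling it with a
    # device returns an executor module with the weights already bound.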
    module = graph_executor.GraphModule(lib["default"](dev))

    input_data = np.random.randn(*input_shape).astype("float32")
    module.set_input("input0", input_data)

    def run_inference():
        module.run()

    result = benchmark_function(
        run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True
    )

    return result


def main() -> None:
    parser = argparse.ArgumentParser(description="TVM Auto-Scheduler Tuning")
    parser.add_argument("--model", type=str, default="resnet50", help="Model name")
    parser.add_argument("--num_trials", type=int, default=1000, help="Number of tuning trials")
    parser.add_argument(
        "--log_file", type=str, default="tuning_logs/resnet50.json", help="Tuning log file"
    )
    parser.add_argument("--target", type=str, default="cuda", help="Target device")
    parser.add_argument("--tune", action="store_true", help="Force tuning even if log exists")
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=100, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning")
    print("=" * 60)

    input_shape = (1, 3, 224, 224)

    # Load and convert model
    print(f"\nLoading {args.model}...")
    mod, params = load_model_and_convert(args.model, input_shape)

    # Run tuning if requested or log doesn't exist
    log_path = os.path.join(os.path.dirname(__file__), args.log_file)

    if args.tune or not os.path.exists(log_path):
        print("\n" + "-" * 40)
        run_auto_scheduler(
            mod, params, target=args.target, num_trials=args.num_trials, log_file=log_path
        )
    else:
        print(f"\nUsing existing tuning log: {log_path}")

    # Compile with tuned schedule
    print("\n" + "-" * 40)
    print("Compiling with tuned schedule...")
    lib = compile_with_tuned_schedule(mod, params, args.target, log_path)

    # Benchmark
    print("\n" + "-" * 40)
    result = benchmark_tuned_model(
        lib, input_shape, args.target, warmup_iters=args.warmup, bench_iters=args.bench
    )

    print("Auto-tuned inference:")
    print(f"  Mean latency: {result.mean_ms:.3f} ms")
    print(f"  Std: {result.std_ms:.3f} ms")
    print(f"  Min: {result.min_ms:.3f} ms")
    print(f"  Max: {result.max_ms:.3f} ms")


if __name__ == "__main__":
    main()