auto_scheduler.py
"""
TVM Auto-Scheduler Tuning module.

This module demonstrates:
1. Using Ansor (Auto-Scheduler) for automatic optimization
2. Extracting tuning tasks from Relay IR
3. Running auto-tuning with configurable trials
4. Applying tuned schedules for optimized inference

Requirements: 2.2
"""

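# Example invocation (flags correspond to the argparse options in main(); assumes a
# CUDA-capable GPU and a TVM build with CUDA enabled):
#   python auto_scheduler.py --model resnet50 --target cuda --num_trials 1000 --tune
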
from __future__ import annotations

import argparse
import os
import sys
from typing import TYPE_CHECKING

import numpy as np
import torch
import torchvision.models as models

from common.benchmark.timer import TimingResult, benchmark_function
from common.utils.path_utils import ensure_project_root_in_path

# Ensure project root is in path for imports
ensure_project_root_in_path()

if TYPE_CHECKING:
    import tvm

# TVM imports
try:
    import tvm
    from tvm import auto_scheduler, relay
    from tvm.contrib import graph_executor

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False



def load_model_and_convert(
    model_name: str = "resnet50", input_shape: tuple[int, ...] = (1, 3, 224, 224)
) -> tuple[tvm.IRModule, dict]:
    """Load PyTorch model and convert to Relay IR."""
    model_map = {
        "resnet50": models.resnet50,
        "resnet18": models.resnet18,
        "mobilenet_v2": models.mobilenet_v2,
    }

    model = model_map[model_name](weights="DEFAULT")
    model.eval()

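    # relay.frontend.from_pytorch consumes a TorchScript module, so the model is
    # traced with a representative input before conversion.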
    example_input = torch.randn(*input_shape)
    with torch.no_grad():
        scripted_model = torch.jit.trace(model, example_input)

    input_infos = [("input0", input_shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)

    return mod, params


def run_auto_scheduler(
    mod: tvm.IRModule,
    params: dict,
    target: str = "cuda",
    num_trials: int = 1000,
    log_file: str = "tuning_logs/resnet50.json",
    early_stopping: int = 600,
) -> tvm.IRModule:
    """
    Run Ansor auto-scheduler for automatic optimization.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        target: Compilation target
        num_trials: Number of tuning trials
        log_file: Path to save tuning log
        early_stopping: Stop if no improvement after this many trials

    Returns:
        The input IRModule; the tuned schedules are recorded in ``log_file``
        and applied later at compile time.
    """
    if not TVM_AVAILABLE:
        raise RuntimeError("TVM is required")

    target = tvm.target.Target(target)

    # Ensure log directory exists (skip if the log file has no directory component)
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    print("Extracting tasks from Relay IR...")
    tasks, task_weights = auto_scheduler.extract_tasks(mod, params, target)

    print(f"Found {len(tasks)} tuning tasks:")
    for i, task in enumerate(tasks):
        print(f"  Task {i}: {task.desc}")

    # Configure tuning options
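    # Each candidate schedule is built and timed on the local machine by LocalRunner:
    # `number` runs are averaged per repeat, `repeat` repeats are recorded, and
    # `min_repeat_ms` adds extra runs so each timed window is long enough to be
    # stable on fast GPU kernels. These values are a common starting point, not
    # tuned recommendations.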
    tuning_options = auto_scheduler.TuningOptions(
        num_measure_trials=num_trials,
        runner=auto_scheduler.LocalRunner(
            timeout=10, number=3, repeat=1, min_repeat_ms=100, enable_cpu_cache_flush=False
        ),
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        early_stopping=early_stopping,
    )

    print(f"\nStarting auto-tuning with {num_trials} trials...")
    print(f"Log file: {log_file}")

    # Run tuning
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tuner.tune(tuning_options)

    print(f"\nTuning complete. Results saved to {log_file}")

    return mod


def compile_with_tuned_schedule(
    mod: tvm.IRModule,
    params: dict,
    target: str = "cuda",
    log_file: str = "tuning_logs/resnet50.json",
) -> tvm.runtime.Module:
    """
    Compile model with tuned schedule.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        target: Compilation target
        log_file: Path to tuning log

    Returns:
        Compiled runtime module
    """
    target = tvm.target.Target(target)

    # Apply tuned schedule
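    # ApplyHistoryBest picks the best measured schedule for each task from the
    # tuning log, and the "relay.backend.use_auto_scheduler" flag makes the Relay
    # backend use those records instead of the default TOPI schedules.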
    with (
        auto_scheduler.ApplyHistoryBest(log_file),
        tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}),
    ):
        lib = relay.build(mod, target=target, params=params)

    return lib


def benchmark_tuned_model(
    lib: tvm.runtime.Module,
    input_shape: tuple[int, ...],
    target: str = "cuda",
    warmup_iters: int = 10,
    bench_iters: int = 100,
) -> TimingResult:
    """Benchmark the tuned model."""
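    # The compiled library's "default" entry creates a graph executor for the given
    # device; GraphModule wraps it with the set_input/run/get_output interface.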
    dev = tvm.device(target, 0)
    module = graph_executor.GraphModule(lib["default"](dev))

    input_data = np.random.randn(*input_shape).astype("float32")
    module.set_input("input0", input_data)

    def run_inference():
        module.run()

    result = benchmark_function(
        run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True
    )

    return result


def main() -> int:
    parser = argparse.ArgumentParser(description="TVM Auto-Scheduler Tuning")
    parser.add_argument("--model", type=str, default="resnet50", help="Model name")
    parser.add_argument("--num_trials", type=int, default=1000, help="Number of tuning trials")
    parser.add_argument(
        "--log_file", type=str, default="tuning_logs/resnet50.json", help="Tuning log file"
    )
    parser.add_argument("--target", type=str, default="cuda", help="Target device")
    parser.add_argument(
        "--tune", action="store_true", help="Force tuning even if a tuning log already exists"
    )
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=100, help="Benchmark iterations")
    args = parser.parse_args()

    if not TVM_AVAILABLE:
        print("TVM is required for this script. Please install TVM.")
        return 1

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 2: Auto-Scheduler Tuning")
    print("=" * 60)

    input_shape = (1, 3, 224, 224)

    # Load and convert model
    print(f"\nLoading {args.model}...")
    mod, params = load_model_and_convert(args.model, input_shape)

    # Run tuning if requested or if no tuning log exists yet
    log_path = os.path.join(os.path.dirname(__file__), args.log_file)

    if args.tune or not os.path.exists(log_path):
        print("\n" + "-" * 40)
        run_auto_scheduler(
            mod, params, target=args.target, num_trials=args.num_trials, log_file=log_path
        )
    else:
        print(f"\nUsing existing tuning log: {log_path}")

    # Compile with tuned schedule
    print("\n" + "-" * 40)
    print("Compiling with tuned schedule...")
    lib = compile_with_tuned_schedule(mod, params, args.target, log_path)

    # Benchmark
    print("\n" + "-" * 40)
    result = benchmark_tuned_model(
        lib, input_shape, args.target, warmup_iters=args.warmup, bench_iters=args.bench
    )

    print("Auto-tuned inference:")
    print(f"  Mean latency: {result.mean_ms:.3f} ms")
    print(f"  Std: {result.std_ms:.3f} ms")
    print(f"  Min: {result.min_ms:.3f} ms")
    print(f"  Max: {result.max_ms:.3f} ms")

    return 0


if __name__ == "__main__":
    sys.exit(main())