baseline.py
  1  """
  2  TVM End-to-End Optimization - Step 1: Import and Baseline
  3  
  4  This module provides:
  5  1. Loading pretrained PyTorch models (ResNet50, etc.)
  6  2. Converting them to TVM Relay IR
  7  3. Running baseline inference without optimization
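
Typical command-line usage (the PyTorch baseline runs even if TVM is not installed):

    python baseline.py --model resnet50 --device cuda --warmup 10 --bench 100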
  8  """
  9  
 10  from __future__ import annotations
 11  
 12  import argparse
 13  from typing import TYPE_CHECKING

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models

from common.benchmark.timer import TimingResult, benchmark_function
from common.utils.path_utils import ensure_project_root_in_path

# Ensure project root is in path for imports
ensure_project_root_in_path()

if TYPE_CHECKING:
    import tvm

# TVM imports (optional, graceful fallback)
try:
    import tvm
    from tvm import relay
    from tvm.contrib import graph_executor

    TVM_AVAILABLE = True
except ImportError:
    TVM_AVAILABLE = False


def load_pytorch_model(
    model_name: str = "resnet50", pretrained: bool = True, device: str = "cuda"
) -> tuple[nn.Module, tuple[int, ...]]:
    """Load a pretrained PyTorch model.

    Args:
        model_name: Name of the model to load (resnet50, resnet18, resnet101, mobilenet_v2)
        pretrained: Whether to load pretrained weights
        device: Device to load model on

    Returns:
        Tuple of (model, input_shape)
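
    Example:
        model, input_shape = load_pytorch_model("resnet18", device="cpu")  # input_shape == (1, 3, 224, 224)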
    """
    model_map = {
        "resnet50": models.resnet50,
        "resnet18": models.resnet18,
        "resnet101": models.resnet101,
        "mobilenet_v2": models.mobilenet_v2,
    }

    if model_name not in model_map:
        raise ValueError(f"Unknown model: {model_name}")

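    # torchvision's weights API: the string "DEFAULT" resolves to the recommended pretrained weights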
    weights = "DEFAULT" if pretrained else None
    model = model_map[model_name](weights=weights)
    model = model.to(device)
    model.eval()

    # Standard ImageNet input shape
    input_shape = (1, 3, 224, 224)

    print(f"Loaded {model_name} on {device}")
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model, input_shape


def convert_to_relay(
    model: nn.Module, input_shape: tuple[int, ...], input_name: str = "input0"
) -> tuple[tvm.IRModule, dict]:
    """Convert PyTorch model to TVM Relay IR.

    Args:
        model: PyTorch model
        input_shape: Input tensor shape
        input_name: Name for the input tensor

    Returns:
        Tuple of (Relay IRModule, parameters dict)

    Raises:
        RuntimeError: If TVM is not installed
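
    Example (requires a working TVM installation):
        model, shape = load_pytorch_model("resnet18", device="cpu")
        mod, params = convert_to_relay(model, shape)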
    """
    if not TVM_AVAILABLE:
        raise RuntimeError("TVM is required for Relay conversion")

    # Create example input
    device = next(model.parameters()).device
    example_input = torch.randn(*input_shape, device=device)

    # Trace the model
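    # (torch.jit.trace records the ops run for this fixed example input; a model with
    #  data-dependent control flow would need torch.jit.script instead)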
    with torch.no_grad():
        scripted_model = torch.jit.trace(model, example_input)

    # Convert to Relay
    input_infos = [(input_name, input_shape)]
    mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)

    print("Converted to Relay IR")
    print(f"  Input: {input_name} {input_shape}")

    return mod, params


def run_baseline_inference(
    mod: tvm.IRModule,
    params: dict,
    input_shape: tuple[int, ...],
    target: str = "cuda",
    warmup_iters: int = 10,
    bench_iters: int = 100,
) -> TimingResult:
    """Run baseline inference without optimization.

    Args:
        mod: TVM Relay IRModule
        params: Model parameters
        input_shape: Input tensor shape
        target: Compilation target (cuda or llvm)
        warmup_iters: Number of warmup iterations
        bench_iters: Number of benchmark iterations

    Returns:
        TimingResult with latency statistics

    Raises:
        RuntimeError: If TVM is not installed
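
    Example (mod and params from convert_to_relay):
        result = run_baseline_inference(mod, params, (1, 3, 224, 224), target="llvm")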
    """
    if not TVM_AVAILABLE:
        raise RuntimeError("TVM is required for inference")

    # Compile without optimization
    tvm_target = tvm.target.Target(target)

    with tvm.transform.PassContext(opt_level=0):
        lib = relay.build(mod, target=tvm_target, params=params)

    # Create runtime
    dev = tvm.device(target, 0)
    module = graph_executor.GraphModule(lib["default"](dev))

    # Prepare input
    input_data = np.random.randn(*input_shape).astype("float32")
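    # The input key below must match the input_name used during Relay conversion ("input0" by default)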
    module.set_input("input0", input_data)

    # Benchmark
    def run_inference():
        module.run()

    result = benchmark_function(
        run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True
    )

    print("Baseline inference (opt_level=0):")
    print(f"  Mean latency: {result.mean_ms:.3f} ms")
    print(f"  Std: {result.std_ms:.3f} ms")

    return result


def run_pytorch_baseline(
    model: nn.Module, input_shape: tuple[int, ...], warmup_iters: int = 10, bench_iters: int = 100
) -> TimingResult:
    """Run PyTorch eager mode baseline.

    Args:
        model: PyTorch model
        input_shape: Input tensor shape
        warmup_iters: Number of warmup iterations
        bench_iters: Number of benchmark iterations

    Returns:
        TimingResult with latency statistics
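
    Example:
        model, shape = load_pytorch_model("resnet18", device="cpu")
        timing = run_pytorch_baseline(model, shape, warmup_iters=2, bench_iters=10)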
    """
    device = next(model.parameters()).device
    input_data = torch.randn(*input_shape, device=device)

    def run_inference():
        with torch.no_grad():
            _ = model(input_data)

    result = benchmark_function(
        run_inference, warmup_iters=warmup_iters, bench_iters=bench_iters, sync_cuda=True
    )

    print("PyTorch eager mode baseline:")
    print(f"  Mean latency: {result.mean_ms:.3f} ms")
    print(f"  Std: {result.std_ms:.3f} ms")

    return result


def main() -> int:
    """Main entry point for CLI."""
    parser = argparse.ArgumentParser(description="TVM Import and Baseline")
    parser.add_argument("--model", type=str, default="resnet50", help="Model name")
    parser.add_argument("--device", type=str, default="cuda", help="Device")
    parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
    parser.add_argument("--bench", type=int, default=100, help="Benchmark iterations")
    args = parser.parse_args()

    # Check CUDA availability
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = "cpu"

    print("=" * 60)
    print("TVM End-to-End Optimization - Step 1: Import and Baseline")
    print("=" * 60)

    # Load PyTorch model
    model, input_shape = load_pytorch_model(args.model, device=args.device)

    # PyTorch baseline
    print("\n" + "-" * 40)
    pytorch_result = run_pytorch_baseline(
        model, input_shape, warmup_iters=args.warmup, bench_iters=args.bench
    )

    # TVM conversion and baseline
    if TVM_AVAILABLE:
        print("\n" + "-" * 40)
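        # Move the model to CPU for tracing; the resulting Relay module is target-independent
        # and is compiled for the chosen TVM target below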
        mod, params = convert_to_relay(model.cpu(), input_shape)

        target = "cuda" if args.device == "cuda" else "llvm"
        print("\n" + "-" * 40)
        tvm_result = run_baseline_inference(
            mod,
            params,
            input_shape,
            target=target,
            warmup_iters=args.warmup,
            bench_iters=args.bench,
        )

        # Summary
        print("\n" + "=" * 60)
        print("Summary:")
        print(f"  PyTorch Eager: {pytorch_result.mean_ms:.3f} ms")
        print(f"  TVM Baseline:  {tvm_result.mean_ms:.3f} ms")
        print(f"  Speedup:       {pytorch_result.mean_ms / tvm_result.mean_ms:.2f}x")
        return 0
    else:
        print("\nTVM not available. Install TVM to run Relay conversion.")
        return 1


if __name__ == "__main__":
    sys.exit(main())