# common/utils/model_loader.py
  1  """
  2  Model loading utilities for AI System Optimization Series.
  3  Provides unified interface for loading models from various frameworks.
  4  """
  5  
import math
from dataclasses import dataclass
from typing import Any
  8  
  9  
 10  @dataclass
 11  class ModelInfo:
 12      """Information about a loaded model."""
 13  
 14      name: str
 15      framework: str
 16      input_shape: tuple[int, ...]
 17      num_parameters: int
 18      dtype: str
 19  
 20  
 21  def load_model(
 22      model_name: str = "resnet50",
 23      framework: str = "pytorch",
 24      pretrained: bool = True,
 25      device: str = "cuda",
 26  ) -> tuple[Any, ModelInfo]:
 27      """
 28      Load a model from the specified framework.
 29  
 30      Args:
 31          model_name: Name of the model (e.g., "resnet50", "bert-base")
 32          framework: Framework to use ("pytorch", "onnx")
 33          pretrained: Whether to load pretrained weights
 34          device: Device to load model on
 35  
 36      Returns:
 37          Tuple of (model, ModelInfo)
 38      """
 39      if framework == "pytorch":
 40          return _load_pytorch_model(model_name, pretrained, device)
 41      elif framework == "onnx":
 42          return _load_onnx_model(model_name)
 43      else:
 44          raise ValueError(f"Unsupported framework: {framework}")
 45  
 46  
 47  def _load_pytorch_model(model_name: str, pretrained: bool, device: str) -> tuple[Any, ModelInfo]:
 48      """Load a PyTorch model."""
 49      try:
 50          import torch  # noqa: F401
 51          import torchvision.models as models
 52      except ImportError:
 53          raise ImportError("PyTorch and torchvision are required")
 54  
 55      # Map model names to torchvision functions
 56      model_map = {
 57          "resnet50": models.resnet50,
 58          "resnet18": models.resnet18,
 59          "resnet101": models.resnet101,
 60          "vgg16": models.vgg16,
 61          "mobilenet_v2": models.mobilenet_v2,
 62          "efficientnet_b0": models.efficientnet_b0,
 63      }
 64  
 65      if model_name not in model_map:
 66          raise ValueError(f"Unknown model: {model_name}. Available: {list(model_map.keys())}")
 67  
 68      # Load model
 69      weights = "DEFAULT" if pretrained else None
 70      model = model_map[model_name](weights=weights)
 71      model = model.to(device)
 72      model.eval()
 73  
 74      # Get model info
 75      num_params = sum(p.numel() for p in model.parameters())
 76  
 77      # Default input shape for ImageNet models
 78      input_shape = (1, 3, 224, 224)
 79  
 80      info = ModelInfo(
 81          name=model_name,
 82          framework="pytorch",
 83          input_shape=input_shape,
 84          num_parameters=num_params,
 85          dtype="float32",
 86      )
 87  
 88      return model, info
 89  
 90  
 91  def _load_onnx_model(model_path: str) -> tuple[Any, ModelInfo]:
 92      """Load an ONNX model."""
 93      try:
 94          import onnx
 95          import onnxruntime as ort
 96      except ImportError:
 97          raise ImportError("onnx and onnxruntime are required")
 98  
 99      # Load ONNX model
100      onnx_model = onnx.load(model_path)
101  
102      # Create inference session
103      available_providers = ort.get_available_providers()
104      if "CUDAExecutionProvider" in available_providers:
105          providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
106      elif "CPUExecutionProvider" in available_providers:
107          providers = ["CPUExecutionProvider"]
108      else:
109          raise RuntimeError(
110              "No supported ONNX Runtime execution providers available. "
111              f"Available providers: {available_providers}"
112          )
113  
114      session = ort.InferenceSession(model_path, providers=providers)
115  
116      # Get input shape
117      input_info = session.get_inputs()[0]
118      input_shape = tuple(input_info.shape)
119  
120      # Count parameters (approximate from model size)
121      num_params = sum(
122          init.dims[0] * (init.dims[1] if len(init.dims) > 1 else 1)
123          for init in onnx_model.graph.initializer
124      )
125  
126      info = ModelInfo(
127          name=model_path,
128          framework="onnx",
129          input_shape=input_shape,
130          num_parameters=num_params,
131          dtype=input_info.type,
132      )
133  
134      return session, info
135  
136  
def get_model_info(model: Any, framework: str = "pytorch") -> dict[str, Any]:
    """
    Get detailed information about a model.

    Args:
        model: The model object
        framework: Framework the model is from

    Returns:
        Dictionary with model information. Inspection is best-effort:
        any failure is recorded under an "error" key instead of raised.
    """
    details: dict[str, Any] = {"framework": framework}

    if framework == "pytorch":
        try:
            # Accumulate total and trainable counts in a single pass.
            total = 0
            trainable = 0
            for param in model.parameters():
                count = param.numel()
                total += count
                if param.requires_grad:
                    trainable += count
            details["num_parameters"] = total
            details["trainable_parameters"] = trainable
            # Device/dtype taken from the first parameter tensor.
            first = next(model.parameters())
            details["device"] = str(first.device)
            details["dtype"] = str(first.dtype)
        except Exception as exc:
            details["error"] = str(exc)

    elif framework == "onnx":
        try:
            details["inputs"] = [(t.name, t.shape, t.type) for t in model.get_inputs()]
            details["outputs"] = [(t.name, t.shape, t.type) for t in model.get_outputs()]
            details["providers"] = model.get_providers()
        except Exception as exc:
            details["error"] = str(exc)

    return details
172  
173  
174  def export_to_onnx(
175      model: Any, output_path: str, input_shape: tuple[int, ...], opset_version: int = 14
176  ) -> str:
177      """
178      Export a PyTorch model to ONNX format.
179  
180      Args:
181          model: PyTorch model
182          output_path: Path to save ONNX model
183          input_shape: Input tensor shape
184          opset_version: ONNX opset version
185  
186      Returns:
187          Path to saved ONNX model
188      """
189      try:
190          import torch
191      except ImportError:
192          raise ImportError("PyTorch is required for ONNX export")
193  
194      model.eval()
195      device = next(model.parameters()).device
196      dummy_input = torch.randn(*input_shape, device=device)
197  
198      torch.onnx.export(
199          model,
200          dummy_input,
201          output_path,
202          opset_version=opset_version,
203          input_names=["input"],
204          output_names=["output"],
205          dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
206      )
207  
208      return output_path