# model_loader.py
"""
Model loading utilities for AI System Optimization Series.

Provides a unified interface for loading models from various frameworks
(PyTorch via torchvision, ONNX via onnxruntime) and for exporting
PyTorch models to ONNX.
"""

import math
from dataclasses import dataclass
from typing import Any


@dataclass
class ModelInfo:
    """Metadata describing a loaded model."""

    # Model identifier: torchvision model name or ONNX file path.
    name: str
    # Source framework: "pytorch" or "onnx".
    framework: str
    # Expected input tensor shape. NOTE(review): for ONNX models the tuple
    # may contain symbolic dimension names (strings), not only ints.
    input_shape: tuple[int, ...]
    # Total parameter count (exact for PyTorch, derived from initializer
    # dims for ONNX).
    num_parameters: int
    # Parameter/input dtype as a string (e.g. "float32" or an ONNX type name).
    dtype: str


def load_model(
    model_name: str = "resnet50",
    framework: str = "pytorch",
    pretrained: bool = True,
    device: str = "cuda",
) -> tuple[Any, ModelInfo]:
    """
    Load a model from the specified framework.

    Args:
        model_name: Name of the model (e.g., "resnet50", "bert-base").
            For the "onnx" framework this is interpreted as a file path.
        framework: Framework to use ("pytorch", "onnx")
        pretrained: Whether to load pretrained weights (PyTorch only)
        device: Device to load model on (PyTorch only)

    Returns:
        Tuple of (model, ModelInfo)

    Raises:
        ValueError: If `framework` is not one of the supported values.
    """
    if framework == "pytorch":
        return _load_pytorch_model(model_name, pretrained, device)
    elif framework == "onnx":
        return _load_onnx_model(model_name)
    else:
        raise ValueError(f"Unsupported framework: {framework}")


def _load_pytorch_model(model_name: str, pretrained: bool, device: str) -> tuple[Any, ModelInfo]:
    """Load a torchvision model onto `device` in eval mode.

    Raises:
        ImportError: If torch/torchvision are not installed.
        ValueError: If `model_name` is not in the supported set.
    """
    try:
        import torch  # noqa: F401
        import torchvision.models as models
    except ImportError as e:
        # Chain the original error so the missing package is visible.
        raise ImportError("PyTorch and torchvision are required") from e

    # Map model names to torchvision constructor functions.
    model_map = {
        "resnet50": models.resnet50,
        "resnet18": models.resnet18,
        "resnet101": models.resnet101,
        "vgg16": models.vgg16,
        "mobilenet_v2": models.mobilenet_v2,
        "efficientnet_b0": models.efficientnet_b0,
    }

    if model_name not in model_map:
        raise ValueError(
            f"Unknown model: {model_name}. Available: {list(model_map.keys())}"
        )

    # "DEFAULT" selects the best available pretrained weights in modern
    # torchvision; None means randomly initialized.
    weights = "DEFAULT" if pretrained else None
    model = model_map[model_name](weights=weights)
    model = model.to(device)
    model.eval()

    num_params = sum(p.numel() for p in model.parameters())

    # All models in model_map are ImageNet classifiers; standard input shape.
    input_shape = (1, 3, 224, 224)

    info = ModelInfo(
        name=model_name,
        framework="pytorch",
        input_shape=input_shape,
        num_parameters=num_params,
        dtype="float32",
    )

    return model, info


def _load_onnx_model(model_path: str) -> tuple[Any, ModelInfo]:
    """Load an ONNX model into an onnxruntime InferenceSession.

    Prefers CUDA execution when available, falling back to CPU.

    Raises:
        ImportError: If onnx/onnxruntime are not installed.
        RuntimeError: If no supported execution provider is available.
    """
    try:
        import onnx
        import onnxruntime as ort
    except ImportError as e:
        raise ImportError("onnx and onnxruntime are required") from e

    # Load the model graph separately from the session so we can inspect
    # initializers for the parameter count.
    onnx_model = onnx.load(model_path)

    # Create inference session with the best available provider.
    available_providers = ort.get_available_providers()
    if "CUDAExecutionProvider" in available_providers:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    elif "CPUExecutionProvider" in available_providers:
        providers = ["CPUExecutionProvider"]
    else:
        raise RuntimeError(
            "No supported ONNX Runtime execution providers available. "
            f"Available providers: {available_providers}"
        )

    session = ort.InferenceSession(model_path, providers=providers)

    # Shape of the first graph input (may contain symbolic dims).
    input_info = session.get_inputs()[0]
    input_shape = tuple(input_info.shape)

    # Count parameters as the product of ALL dims of each initializer.
    # (The previous dims[0] * dims[1] approximation undercounted 3-D/4-D
    # weights such as conv kernels and crashed on scalar initializers,
    # whose dims list is empty; math.prod of an empty sequence is 1,
    # which correctly counts a scalar as one parameter.)
    num_params = sum(
        math.prod(init.dims) for init in onnx_model.graph.initializer
    )

    info = ModelInfo(
        name=model_path,
        framework="onnx",
        input_shape=input_shape,
        num_parameters=num_params,
        dtype=input_info.type,
    )

    return session, info


def get_model_info(model: Any, framework: str = "pytorch") -> dict[str, Any]:
    """
    Get detailed information about a model.

    Args:
        model: The model object (a torch.nn.Module or an
            onnxruntime.InferenceSession, matching `framework`)
        framework: Framework the model is from

    Returns:
        Dictionary with model information. Introspection failures are
        reported under the "error" key rather than raised, so this is
        safe to call for logging/diagnostics.
    """
    info: dict[str, Any] = {"framework": framework}

    if framework == "pytorch":
        try:
            info["num_parameters"] = sum(p.numel() for p in model.parameters())
            info["trainable_parameters"] = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )
            # Hoist the first parameter once instead of restarting the
            # iterator for each attribute.
            first_param = next(model.parameters())
            info["device"] = str(first_param.device)
            info["dtype"] = str(first_param.dtype)
        except Exception as e:
            info["error"] = str(e)

    elif framework == "onnx":
        try:
            inputs = model.get_inputs()
            outputs = model.get_outputs()
            info["inputs"] = [(i.name, i.shape, i.type) for i in inputs]
            info["outputs"] = [(o.name, o.shape, o.type) for o in outputs]
            info["providers"] = model.get_providers()
        except Exception as e:
            info["error"] = str(e)

    return info


def export_to_onnx(
    model: Any, output_path: str, input_shape: tuple[int, ...], opset_version: int = 14
) -> str:
    """
    Export a PyTorch model to ONNX format.

    Args:
        model: PyTorch model
        output_path: Path to save ONNX model
        input_shape: Input tensor shape (including batch dimension)
        opset_version: ONNX opset version

    Returns:
        Path to saved ONNX model

    Raises:
        ImportError: If PyTorch is not installed.
    """
    try:
        import torch
    except ImportError as e:
        raise ImportError("PyTorch is required for ONNX export") from e

    model.eval()
    # Trace with a dummy input on the model's own device so the export
    # does not mix devices. NOTE(review): assumes the model has at least
    # one parameter — confirm for parameter-free models.
    device = next(model.parameters()).device
    dummy_input = torch.randn(*input_shape, device=device)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=opset_version,
        input_names=["input"],
        output_names=["output"],
        # Keep the batch dimension dynamic so the exported model accepts
        # any batch size at inference time.
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    return output_path