torch_utility.py
1 # -*- coding: utf-8 -*- 2 """Utility functions for PyTorch models 3 """ 4 # Author: Tiankai Yang <tiankaiy@usc.edu> 5 # License: BSD 2 clause 6 7 import torch 8 import torch.nn as nn 9 10 11 class TorchDataset(torch.utils.data.Dataset): 12 def __init__(self, X, y=None, mean=None, std=None, eps=1e-8, 13 X_dtype=torch.float32, y_dtype=torch.float32, 14 return_idx=False): 15 self.X = X 16 self.y = y 17 self.mean = mean 18 self.std = std 19 self.eps = eps 20 self.X_dtype = X_dtype 21 self.y_dtype = y_dtype 22 self.return_idx = return_idx 23 24 def __len__(self): 25 return len(self.X) 26 27 def __getitem__(self, idx): 28 if torch.is_tensor(idx): 29 idx = idx.tolist() 30 sample = self.X[idx, :] 31 32 if self.mean is not None and self.std is not None: 33 sample = (sample - self.mean) / (self.std + self.eps) 34 35 if self.y is not None: 36 if self.return_idx: 37 return torch.as_tensor(sample, dtype=self.X_dtype), \ 38 torch.as_tensor(self.y[idx], dtype=self.y_dtype), idx 39 else: 40 return torch.as_tensor(sample, dtype=self.X_dtype), \ 41 torch.as_tensor(self.y[idx], dtype=self.y_dtype) 42 else: 43 if self.return_idx: 44 return torch.as_tensor(sample, dtype=self.X_dtype), idx 45 else: 46 return torch.as_tensor(sample, dtype=self.X_dtype) 47 48 49 class LinearBlock(nn.Module): 50 """ 51 Linear block with activation and batch normalization 52 53 Parameters 54 ---------- 55 in_features : int 56 Number of input features 57 58 out_features : int 59 Number of output features. 60 61 has_act : bool, optional (default=True) 62 If True, apply activation function after linear layer. 63 64 activation_name : str, optional (default='relu') 65 Activation function name. Available functions: 66 'elu', 'identity', 'leaky_relu', 'relu', 'sigmoid', 67 'softmax', 'softplus', 'tanh'. 68 69 batch_norm : bool, optional (default=True) 70 If True, apply batch normalization after activation function if `has_act` is True, 71 or after linear layer if `has_act` is False. 72 The following four parameters are used only if `batch_norm` is True. 73 See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html#batchnorm1d for details. 74 75 bn_eps : float, optional (default=1e-5) 76 A value added to the denominator for numerical stability 77 78 bn_momentum : float, optional (default=0.1) 79 The value used for the running_mean and running_var computation. 80 Can be set to None for cumulative moving average (i.e. simple average) 81 82 bn_affine : bool, optional (default=True) 83 A boolean value that when set to 'True', this module has learnable affine parameters. 84 85 bn_track_running_stats : bool, optional (default=True) 86 Batch normalization track_running_stats. 87 88 dropout_rate : float, optional (default=0) 89 The probability of an element to be zeroed. 90 See https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#dropout for details. 91 92 init_type : str, optional (default='kaiming_uniform') 93 Initialization type. 94 Available types: 'uniform', 'normal', 'constant', 'ones', 'zeros', 'eye', 'dirac', 95 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'trunc_normal', 96 'orthogonal', 'sparse'. 97 See https://pytorch.org/docs/stable/nn.init.html for details. 98 99 inplace : bool, optional (default=False) 100 If set to True, activation function and dropout are applied in-place. 101 102 activation_params : dict, optional (default=None) 103 Additional parameters for activation function. 104 For example, `activation_params={ 105 'elu_alpha': 1.0, 106 'leaky_relu_negative_slope': 0.01}`. 107 108 init_params : dict, optional (default=None) 109 Additional parameters for initialization function. 110 For example, `init_params={ 111 'uniform_a': 0.0, 112 'uniform_b': 1.0}`. 113 """ 114 115 def __init__(self, in_features, out_features, 116 has_act=True, activation_name='relu', 117 batch_norm=False, bn_eps=1e-5, bn_momentum=0.1, 118 bn_affine=True, bn_track_running_stats=True, 119 dropout_rate=0, 120 init_type='kaiming_uniform', 121 inplace=False, 122 activation_params: dict = {}, 123 init_params: dict = {}): 124 super(LinearBlock, self).__init__() 125 self.linear = nn.Linear(in_features, out_features) 126 self.has_act = has_act 127 if has_act: 128 # only use the variable about activation function in **kwargs 129 self.activation = get_activation_by_name(activation_name, 130 inplace=inplace, 131 **activation_params) 132 self.batch_norm = batch_norm 133 if batch_norm: 134 self.bn = nn.BatchNorm1d(out_features, eps=bn_eps, 135 momentum=bn_momentum, affine=bn_affine, 136 track_running_stats=bn_track_running_stats) 137 self.dropout_rate = dropout_rate 138 if dropout_rate > 0: 139 self.dropout = nn.Dropout(p=dropout_rate, inplace=inplace) 140 init_weights(layer=self.linear, name=init_type, **init_params) 141 142 def forward(self, x): 143 x = self.linear(x) 144 if self.batch_norm: 145 x = self.bn(x) 146 if self.has_act: 147 x = self.activation(x) 148 if self.dropout_rate > 0: 149 x = self.dropout(x) 150 return x 151 152 153 def get_activation_by_name(name, inplace=False, 154 elu_alpha=1.0, 155 leaky_relu_negative_slope=0.01, 156 softmax_dim=None, 157 softplus_beta=1.0, softplus_threshold=20.0): 158 """ 159 Get activation function by name 160 161 Parameters 162 ---------- 163 name : str 164 Activation function name. Available functions: 165 'elu', 'identity', 'leaky_relu', 'relu', 'sigmoid', 166 'softmax', 'softplus', 'tanh'. 167 168 inplace : bool, optional (default=False) 169 If set to True, do the operation in-place. 170 171 elu_alpha : float, optional (default=1.0) 172 The alpha value for the ELU formulation. 173 See https://pytorch.org/docs/stable/generated/torch.nn.ELU.html#elu for details. 174 175 leaky_relu_negative_slope : float, optional (default=0.01) 176 Controls the angle of the negative slope (which is used for negative inputs values). 177 See https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html#leakyrelu for details. 178 179 softmax_dim : int, optional (default=None) 180 A dimension along which Softmax will be computed (so every slice along dim will sum to 1). 181 See https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html#softmax for details. 182 183 softplus_beta : float, optional (default=1.0) 184 The beta value for the Softplus formulation. 185 See https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html#softplus for details. 186 187 softplus_threshold : float, optional (default=20.0) 188 Values above this revert to a linear function 189 See https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html#softplus for details. 190 191 Returns 192 ------- 193 activation : torch.nn.Module 194 Activation function module 195 """ 196 activation_dict = { 197 'elu': nn.ELU(alpha=elu_alpha, inplace=inplace), 198 'identity': nn.Identity(), 199 'leaky_relu': nn.LeakyReLU(negative_slope=leaky_relu_negative_slope, 200 inplace=inplace), 201 'relu': nn.ReLU(inplace=inplace), 202 'sigmoid': nn.Sigmoid(), 203 'softmax': nn.Softmax(dim=softmax_dim), 204 'softplus': nn.Softplus(beta=softplus_beta, 205 threshold=softplus_threshold), 206 'tanh': nn.Tanh() 207 } 208 209 if name in activation_dict.keys(): 210 return activation_dict[name] 211 212 else: 213 raise ValueError(f"{name} is not a valid activation.") 214 215 216 def get_optimizer_by_name(model, name, lr=1e-3, weight_decay=0, 217 adam_eps=1e-8, 218 sgd_momentum=0, sgd_nesterov=False): 219 """ 220 Get optimizer by name 221 222 Parameters 223 ---------- 224 model : torch.nn.Module 225 Model to be optimized. 226 227 name : str 228 Optimizer name. Available optimizers: 'adam', 'sgd'. 229 See https://pytorch.org/docs/stable/optim.html for details. 230 231 lr : float, optional (default=1e-3) 232 Learning rate. 233 234 weight_decay : float, optional (default=0) 235 Weight decay (L2 penalty). 236 237 adam_eps : float, optional (default=1e-8) 238 Term added to the denominator to improve numerical stability. 239 See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam for details. 240 241 sgd_momentum : float, optional (default=0) 242 Momentum factor in SGD. 243 See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD for details. 244 245 sgd_nesterov : bool, optional (default=False) 246 Enables Nesterov momentum. 247 See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD for details. 248 249 Returns 250 ------- 251 optimizer : torch.optim.Optimizer 252 Optimizer 253 """ 254 optimizer_dict = { 255 'adam': torch.optim.Adam(model.parameters(), lr=lr, 256 weight_decay=weight_decay, eps=adam_eps), 257 'sgd': torch.optim.SGD(model.parameters(), lr=lr, 258 momentum=sgd_momentum, 259 weight_decay=weight_decay, 260 nesterov=sgd_nesterov) 261 } 262 263 if name in optimizer_dict.keys(): 264 return optimizer_dict[name] 265 266 else: 267 raise ValueError(f"{name} is not a valid optimizer.") 268 269 270 def get_criterion_by_name(name, reduction='mean', 271 bce_weight=None): 272 """ 273 Get criterion by name 274 275 Parameters 276 ---------- 277 name : str 278 Loss function name. Available functions: 'mse', 'mae', 'bce'. 279 See https://pytorch.org/docs/stable/nn.html#loss-functions for details. 280 281 reduction : str, optional (default='mean') 282 Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 283 'none': no reduction will be applied, 284 'mean': the sum of the output will be divided by the number of elements in the output, 285 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, 286 and in the meantime, specifying either of those two args will override reduction. Default: 'mean' 287 See https://pytorch.org/docs/stable/nn.html#loss-functions for details. 288 289 bce_weight : torch.Tensor, optional (default=None) 290 A manual rescaling weight given to the loss of each batch element. 291 See https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss for details. 292 293 Returns 294 ------- 295 criterion : torch.nn.Module 296 Criterion module. 297 """ 298 criterion_dict = { 299 'mse': nn.MSELoss(reduction=reduction), 300 'mae': nn.L1Loss(reduction=reduction), 301 'bce': nn.BCELoss(reduction=reduction, weight=bce_weight) 302 } 303 304 if name in criterion_dict.keys(): 305 return criterion_dict[name] 306 307 else: 308 raise ValueError(f"{name} is not a valid criterion.") 309 310 311 def init_weights(layer, name='kaiming_uniform', 312 uniform_a=0.0, uniform_b=1.0, 313 normal_mean=0.0, normal_std=1.0, 314 constant_val=0.0, 315 xavier_gain=1.0, 316 kaiming_a=0, kaiming_mode='fan_in', 317 kaiming_nonlinearity='leaky_relu', 318 trunc_mean=0.0, trunc_std=1.0, trunc_a=-2, trunc_b=2, 319 orthogonal_gain=1.0, 320 sparse_sparsity=None, sparse_std=0.01, sparse_generator=None): 321 """ 322 Initialize weights for a layer 323 324 Parameters 325 ---------- 326 layer : torch.nn.Module 327 Layer to be initialized. 328 329 name : str, optional (default='kaiming_uniform') 330 Initialization type. 331 Available types: 'uniform', 'normal', 'constant', 'ones', 'zeros', 'eye', 'dirac', 332 'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'trunc_normal', 333 'orthogonal', 'sparse'. 334 See https://pytorch.org/docs/stable/nn.init.html for details. 335 336 uniform_a : float, optional (default=0.0) 337 The lower bound for the uniform distribution. 338 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.uniform_ for details. 339 340 uniform_b : float, optional (default=1.0) 341 The upper bound for the uniform distribution. 342 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.uniform_ for details. 343 344 normal_mean : float, optional (default=0.0) 345 The mean of the normal distribution. 346 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.normal_ for details. 347 348 normal_std : float, optional (default=1.0) 349 The standard deviation of the normal distribution. 350 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.normal_ for details. 351 352 constant_val : float, optional (default=0.0) 353 The value to fill the tensor with. 354 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.constant_ for details. 355 356 xavier_gain : float, optional (default=1.0) 357 An optional scaling factor. 358 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_ 359 and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_normal_ for details. 360 361 kaiming_a : float, optional (default=0) 362 The negative slope of the rectifier used after this layer (only used with 'leaky_relu') 363 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_ 364 and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details. 365 366 kaiming_mode : str, optional (default='fan_in') 367 The mode for kaiming initialization. Available modes: 'fan_in', 'fan_out'. 368 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_ 369 and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details. 370 371 kaiming_nonlinearity : str, optional (default='leaky_relu') 372 The non-linear function (nn.functional name), recommended to use only with 'relu' or 'leaky_relu'. 373 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_ 374 and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details. 375 376 trunc_mean : float, optional (default=0.0) 377 The mean value of the truncated normal distribution. 378 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details. 379 380 trunc_std : float, optional (default=1.0) 381 The standard deviation of the truncated normal distribution. 382 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details. 383 384 trunc_a : float, optional (default=-2) 385 The minimum cutoff value. 386 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details. 387 388 trunc_b : float, optional (default=2) 389 The maximum cutoff value. 390 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details. 391 392 orthogonal_gain : float, optional (default=1.0) 393 The optional scaling factor 394 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.orthogonal_ for details. 395 396 sparse_sparsity : float, optional (default=None) 397 This parameter must be provided if used! 398 The fraction of elements in each column to be set to zero. 399 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details. 400 401 sparse_std : float, optional (default=0.01) 402 The standard deviation of the normal distribution used to generate the non-zero values 403 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details. 404 405 sparse_generator : Optional[Generator] (default=None) 406 The torch Generator to sample from. 407 See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details. 408 """ 409 init_name_dict = { 410 'uniform': nn.init.uniform_, 411 'normal': nn.init.normal_, 412 'constant': nn.init.constant_, 413 'ones': nn.init.ones_, 414 'zeros': nn.init.zeros_, 415 'eye': nn.init.eye_, 416 'xavier_uniform': nn.init.xavier_uniform_, 417 'xavier_normal': nn.init.xavier_normal_, 418 'kaiming_uniform': nn.init.kaiming_uniform_, 419 'kaiming_normal': nn.init.kaiming_normal_, 420 'trunc_normal': nn.init.trunc_normal_, 421 'orthogonal': nn.init.orthogonal_, 422 'sparse': nn.init.sparse_ 423 } 424 425 if name in init_name_dict.keys(): 426 if name == 'uniform': 427 init_name_dict[name](layer.weight, a=uniform_a, b=uniform_b) 428 elif name == 'normal': 429 init_name_dict[name](layer.weight, mean=normal_mean, 430 std=normal_std) 431 elif name == 'constant': 432 init_name_dict[name](layer.weight, val=constant_val) 433 elif name == 'ones': 434 init_name_dict[name](layer.weight) 435 elif name == 'zeros': 436 init_name_dict[name](layer.weight) 437 elif name == 'eye': 438 init_name_dict[name](layer.weight) 439 elif name == 'xavier_uniform': 440 init_name_dict[name](layer.weight, gain=xavier_gain) 441 elif name == 'xavier_normal': 442 init_name_dict[name](layer.weight, gain=xavier_gain) 443 elif name == 'kaiming_uniform': 444 init_name_dict[name](layer.weight, a=kaiming_a, mode=kaiming_mode, 445 nonlinearity=kaiming_nonlinearity) 446 elif name == 'kaiming_normal': 447 init_name_dict[name](layer.weight, a=kaiming_a, mode=kaiming_mode, 448 nonlinearity=kaiming_nonlinearity) 449 elif name == 'trunc_normal': 450 init_name_dict[name](layer.weight, mean=trunc_mean, std=trunc_std, 451 a=trunc_a, b=trunc_b) 452 elif name == 'orthogonal': 453 init_name_dict[name](layer.weight, gain=orthogonal_gain) 454 elif name == 'sparse': 455 init_name_dict[name](layer.weight, sparsity=sparse_sparsity, 456 std=sparse_std) 457 else: 458 raise ValueError(f"{name} is not a valid initialization type.")