# -*- coding: utf-8 -*-
"""Utility functions for PyTorch models
"""
# Author: Tiankai Yang <tiankaiy@usc.edu>
# License: BSD 2 clause

import torch
import torch.nn as nn


class TorchDataset(torch.utils.data.Dataset):
    """Dataset that serves rows of an array-like ``X`` (and, optionally,
    labels ``y``) as tensors, standardizing each sample with the supplied
    ``mean`` and ``std`` when both are given.
    """

    def __init__(self, X, y=None, mean=None, std=None, eps=1e-8,
                 X_dtype=torch.float32, y_dtype=torch.float32,
                 return_idx=False):
        self.X = X
        self.y = y
        self.mean = mean
        self.std = std
        self.eps = eps
        self.X_dtype = X_dtype
        self.y_dtype = y_dtype
        self.return_idx = return_idx

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.X[idx, :]

        # Standardize with the precomputed statistics; eps guards against
        # division by zero for constant features.
        if self.mean is not None and self.std is not None:
            sample = (sample - self.mean) / (self.std + self.eps)

        if self.y is not None:
            if self.return_idx:
                return torch.as_tensor(sample, dtype=self.X_dtype), \
                    torch.as_tensor(self.y[idx], dtype=self.y_dtype), idx
            else:
                return torch.as_tensor(sample, dtype=self.X_dtype), \
                    torch.as_tensor(self.y[idx], dtype=self.y_dtype)
        else:
            if self.return_idx:
                return torch.as_tensor(sample, dtype=self.X_dtype), idx
            else:
                return torch.as_tensor(sample, dtype=self.X_dtype)
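
# A minimal TorchDataset usage sketch (illustrative only; the array and the
# DataLoader settings below are assumptions, not part of this module):
#
#     import numpy as np
#     from torch.utils.data import DataLoader
#
#     X = np.random.randn(100, 5).astype(np.float32)
#     dataset = TorchDataset(X, mean=X.mean(axis=0), std=X.std(axis=0))
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     for batch in loader:  # each batch is a float32 tensor of shape (B, 5)
#         pass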


class LinearBlock(nn.Module):
    """
    Linear block with activation and batch normalization

    Parameters
    ----------
    in_features : int
        Number of input features.

    out_features : int
        Number of output features.

    has_act : bool, optional (default=True)
        If True, apply an activation function after the linear layer.

    activation_name : str, optional (default='relu')
        Activation function name. Available functions:
        'elu', 'identity', 'leaky_relu', 'relu', 'sigmoid',
        'softmax', 'softplus', 'tanh'.

    batch_norm : bool, optional (default=False)
        If True, apply batch normalization after the linear layer and
        before the activation function (if `has_act` is True).
        The following four parameters are used only if `batch_norm` is True.
        See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html#batchnorm1d for details.

    bn_eps : float, optional (default=1e-5)
        A value added to the denominator for numerical stability.

    bn_momentum : float, optional (default=0.1)
        The value used for the running_mean and running_var computation.
        Can be set to None for cumulative moving average (i.e. simple average).

    bn_affine : bool, optional (default=True)
        If True, the batch normalization layer has learnable affine parameters.

    bn_track_running_stats : bool, optional (default=True)
        If True, the batch normalization layer tracks running statistics.

    dropout_rate : float, optional (default=0)
        The probability of an element to be zeroed.
        See https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#dropout for details.

    init_type : str, optional (default='kaiming_uniform')
        Initialization type.
        Available types: 'uniform', 'normal', 'constant', 'ones', 'zeros', 'eye',
        'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal',
        'trunc_normal', 'orthogonal', 'sparse'.
        See https://pytorch.org/docs/stable/nn.init.html for details.

    inplace : bool, optional (default=False)
        If set to True, the activation function and dropout are applied in-place.

    activation_params : dict, optional (default=None)
        Additional parameters for the activation function.
        For example, `activation_params={
            'elu_alpha': 1.0,
            'leaky_relu_negative_slope': 0.01}`.

    init_params : dict, optional (default=None)
        Additional parameters for the initialization function.
        For example, `init_params={
            'uniform_a': 0.0,
            'uniform_b': 1.0}`.
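
    Examples
    --------
    A minimal sketch (shapes and values are illustrative)::

        >>> import torch
        >>> block = LinearBlock(4, 8, activation_name='leaky_relu',
        ...                     activation_params={
        ...                         'leaky_relu_negative_slope': 0.1})
        >>> block(torch.randn(2, 4)).shape
        torch.Size([2, 8])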
    """

    def __init__(self, in_features, out_features,
                 has_act=True, activation_name='relu',
                 batch_norm=False, bn_eps=1e-5, bn_momentum=0.1,
                 bn_affine=True, bn_track_running_stats=True,
                 dropout_rate=0,
                 init_type='kaiming_uniform',
                 inplace=False,
                 activation_params=None,
                 init_params=None):
        super(LinearBlock, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.has_act = has_act
        if has_act:
            # `activation_params` defaults to None rather than a mutable
            # dict; treat None as "no extra parameters" and forward only
            # the activation-related keyword arguments.
            self.activation = get_activation_by_name(activation_name,
                                                     inplace=inplace,
                                                     **(activation_params or {}))
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = nn.BatchNorm1d(out_features, eps=bn_eps,
                                     momentum=bn_momentum, affine=bn_affine,
                                     track_running_stats=bn_track_running_stats)
        self.dropout_rate = dropout_rate
        if dropout_rate > 0:
            self.dropout = nn.Dropout(p=dropout_rate, inplace=inplace)
        init_weights(layer=self.linear, name=init_type, **(init_params or {}))

    def forward(self, x):
        x = self.linear(x)
        if self.batch_norm:
            x = self.bn(x)
        if self.has_act:
            x = self.activation(x)
        if self.dropout_rate > 0:
            x = self.dropout(x)
        return x


def get_activation_by_name(name, inplace=False,
                           elu_alpha=1.0,
                           leaky_relu_negative_slope=0.01,
                           softmax_dim=None,
                           softplus_beta=1.0, softplus_threshold=20.0):
    """
    Get activation function by name

    Parameters
    ----------
    name : str
        Activation function name. Available functions:
        'elu', 'identity', 'leaky_relu', 'relu', 'sigmoid',
        'softmax', 'softplus', 'tanh'.

    inplace : bool, optional (default=False)
        If set to True, do the operation in-place.

    elu_alpha : float, optional (default=1.0)
        The alpha value for the ELU formulation.
        See https://pytorch.org/docs/stable/generated/torch.nn.ELU.html#elu for details.

    leaky_relu_negative_slope : float, optional (default=0.01)
        Controls the angle of the negative slope (used for negative input values).
        See https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html#leakyrelu for details.

    softmax_dim : int, optional (default=None)
        The dimension along which Softmax is computed (every slice along dim sums to 1).
        See https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html#softmax for details.

    softplus_beta : float, optional (default=1.0)
        The beta value for the Softplus formulation.
        See https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html#softplus for details.

    softplus_threshold : float, optional (default=20.0)
        Values above this revert to a linear function.
        See https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html#softplus for details.

    Returns
    -------
    activation : torch.nn.Module
        Activation function module.
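
    Examples
    --------
    A minimal sketch::

        >>> act = get_activation_by_name('leaky_relu',
        ...                              leaky_relu_negative_slope=0.2)
        >>> act
        LeakyReLU(negative_slope=0.2)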
    """
    activation_dict = {
        'elu': nn.ELU(alpha=elu_alpha, inplace=inplace),
        'identity': nn.Identity(),
        'leaky_relu': nn.LeakyReLU(negative_slope=leaky_relu_negative_slope,
                                   inplace=inplace),
        'relu': nn.ReLU(inplace=inplace),
        'sigmoid': nn.Sigmoid(),
        'softmax': nn.Softmax(dim=softmax_dim),
        'softplus': nn.Softplus(beta=softplus_beta,
                                threshold=softplus_threshold),
        'tanh': nn.Tanh()
    }

    if name in activation_dict:
        return activation_dict[name]

    raise ValueError(f"{name} is not a valid activation.")


def get_optimizer_by_name(model, name, lr=1e-3, weight_decay=0,
                          adam_eps=1e-8,
                          sgd_momentum=0, sgd_nesterov=False):
    """
    Get optimizer by name

    Parameters
    ----------
    model : torch.nn.Module
        Model to be optimized.

    name : str
        Optimizer name. Available optimizers: 'adam', 'sgd'.
        See https://pytorch.org/docs/stable/optim.html for details.

    lr : float, optional (default=1e-3)
        Learning rate.

    weight_decay : float, optional (default=0)
        Weight decay (L2 penalty).

    adam_eps : float, optional (default=1e-8)
        Term added to the denominator to improve numerical stability.
        See https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam for details.

    sgd_momentum : float, optional (default=0)
        Momentum factor in SGD.
        See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD for details.

    sgd_nesterov : bool, optional (default=False)
        Enables Nesterov momentum.
        See https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD for details.

    Returns
    -------
    optimizer : torch.optim.Optimizer
        The requested optimizer.
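
    Examples
    --------
    A minimal sketch (the linear model is illustrative)::

        >>> model = nn.Linear(4, 1)
        >>> opt = get_optimizer_by_name(model, 'sgd', lr=0.01,
        ...                             sgd_momentum=0.9)
        >>> isinstance(opt, torch.optim.SGD)
        True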
    """
    # Build lazily so that only the requested optimizer is constructed;
    # constructing both eagerly would fail on argument combinations that
    # are invalid for the *other* optimizer (e.g. Nesterov momentum with
    # sgd_momentum=0, even when 'adam' is requested).
    optimizer_dict = {
        'adam': lambda: torch.optim.Adam(model.parameters(), lr=lr,
                                         weight_decay=weight_decay,
                                         eps=adam_eps),
        'sgd': lambda: torch.optim.SGD(model.parameters(), lr=lr,
                                       momentum=sgd_momentum,
                                       weight_decay=weight_decay,
                                       nesterov=sgd_nesterov)
    }

    if name in optimizer_dict:
        return optimizer_dict[name]()

    raise ValueError(f"{name} is not a valid optimizer.")


def get_criterion_by_name(name, reduction='mean',
                          bce_weight=None):
    """
    Get criterion by name

    Parameters
    ----------
    name : str
        Loss function name. Available functions: 'mse', 'mae', 'bce'.
        See https://pytorch.org/docs/stable/nn.html#loss-functions for details.

    reduction : str, optional (default='mean')
        Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
        'none': no reduction will be applied,
        'mean': the sum of the output will be divided by the number of elements in the output,
        'sum': the output will be summed.
        See https://pytorch.org/docs/stable/nn.html#loss-functions for details.

    bce_weight : torch.Tensor, optional (default=None)
        A manual rescaling weight given to the loss of each batch element.
        See https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss for details.

    Returns
    -------
    criterion : torch.nn.Module
        Criterion module.
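
    Examples
    --------
    A minimal sketch::

        >>> criterion = get_criterion_by_name('mse')
        >>> criterion(torch.zeros(3), torch.ones(3)).item()
        1.0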
    """
    criterion_dict = {
        'mse': nn.MSELoss(reduction=reduction),
        'mae': nn.L1Loss(reduction=reduction),
        'bce': nn.BCELoss(reduction=reduction, weight=bce_weight)
    }

    if name in criterion_dict:
        return criterion_dict[name]

    raise ValueError(f"{name} is not a valid criterion.")


def init_weights(layer, name='kaiming_uniform',
                 uniform_a=0.0, uniform_b=1.0,
                 normal_mean=0.0, normal_std=1.0,
                 constant_val=0.0,
                 xavier_gain=1.0,
                 kaiming_a=0, kaiming_mode='fan_in',
                 kaiming_nonlinearity='leaky_relu',
                 trunc_mean=0.0, trunc_std=1.0, trunc_a=-2, trunc_b=2,
                 orthogonal_gain=1.0,
                 sparse_sparsity=None, sparse_std=0.01, sparse_generator=None):
    """
    Initialize weights for a layer

    Parameters
    ----------
    layer : torch.nn.Module
        Layer to be initialized.

    name : str, optional (default='kaiming_uniform')
        Initialization type.
        Available types: 'uniform', 'normal', 'constant', 'ones', 'zeros', 'eye',
        'xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal',
        'trunc_normal', 'orthogonal', 'sparse'.
        See https://pytorch.org/docs/stable/nn.init.html for details.

    uniform_a : float, optional (default=0.0)
        The lower bound of the uniform distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.uniform_ for details.

    uniform_b : float, optional (default=1.0)
        The upper bound of the uniform distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.uniform_ for details.

    normal_mean : float, optional (default=0.0)
        The mean of the normal distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.normal_ for details.

    normal_std : float, optional (default=1.0)
        The standard deviation of the normal distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.normal_ for details.

    constant_val : float, optional (default=0.0)
        The value to fill the tensor with.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.constant_ for details.

    xavier_gain : float, optional (default=1.0)
        An optional scaling factor.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_
        and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_normal_ for details.

    kaiming_a : float, optional (default=0)
        The negative slope of the rectifier used after this layer (only used with 'leaky_relu').
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
        and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details.

    kaiming_mode : str, optional (default='fan_in')
        The mode for kaiming initialization. Available modes: 'fan_in', 'fan_out'.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
        and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details.

    kaiming_nonlinearity : str, optional (default='leaky_relu')
        The non-linear function (nn.functional name); recommended to use only with 'relu' or 'leaky_relu'.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
        and https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_ for details.

    trunc_mean : float, optional (default=0.0)
        The mean of the truncated normal distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details.

    trunc_std : float, optional (default=1.0)
        The standard deviation of the truncated normal distribution.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details.

    trunc_a : float, optional (default=-2)
        The minimum cutoff value.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details.

    trunc_b : float, optional (default=2)
        The maximum cutoff value.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.trunc_normal_ for details.

    orthogonal_gain : float, optional (default=1.0)
        An optional scaling factor.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.orthogonal_ for details.

    sparse_sparsity : float, optional (default=None)
        The fraction of elements in each column to be set to zero.
        Required when `name='sparse'`.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details.

    sparse_std : float, optional (default=0.01)
        The standard deviation of the normal distribution used to generate the non-zero values.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details.

    sparse_generator : torch.Generator, optional (default=None)
        The torch Generator to sample from.
        See https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.sparse_ for details.
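
    Examples
    --------
    A minimal sketch::

        >>> layer = nn.Linear(4, 4)
        >>> init_weights(layer, name='xavier_uniform', xavier_gain=0.5)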
    """
    init_name_dict = {
        'uniform': nn.init.uniform_,
        'normal': nn.init.normal_,
        'constant': nn.init.constant_,
        'ones': nn.init.ones_,
        'zeros': nn.init.zeros_,
        'eye': nn.init.eye_,
        'xavier_uniform': nn.init.xavier_uniform_,
        'xavier_normal': nn.init.xavier_normal_,
        'kaiming_uniform': nn.init.kaiming_uniform_,
        'kaiming_normal': nn.init.kaiming_normal_,
        'trunc_normal': nn.init.trunc_normal_,
        'orthogonal': nn.init.orthogonal_,
        'sparse': nn.init.sparse_
    }

    if name not in init_name_dict:
        raise ValueError(f"{name} is not a valid initialization type.")

    init_fn = init_name_dict[name]
    if name == 'uniform':
        init_fn(layer.weight, a=uniform_a, b=uniform_b)
    elif name == 'normal':
        init_fn(layer.weight, mean=normal_mean, std=normal_std)
    elif name == 'constant':
        init_fn(layer.weight, val=constant_val)
    elif name in ('ones', 'zeros', 'eye'):
        init_fn(layer.weight)
    elif name in ('xavier_uniform', 'xavier_normal'):
        init_fn(layer.weight, gain=xavier_gain)
    elif name in ('kaiming_uniform', 'kaiming_normal'):
        init_fn(layer.weight, a=kaiming_a, mode=kaiming_mode,
                nonlinearity=kaiming_nonlinearity)
    elif name == 'trunc_normal':
        init_fn(layer.weight, mean=trunc_mean, std=trunc_std,
                a=trunc_a, b=trunc_b)
    elif name == 'orthogonal':
        init_fn(layer.weight, gain=orthogonal_gain)
    elif name == 'sparse':
        # nn.init.sparse_ accepts a `generator` keyword in recent PyTorch;
        # pass the documented generator through only when one is supplied,
        # so older versions keep working.
        if sparse_generator is not None:
            init_fn(layer.weight, sparsity=sparse_sparsity,
                    std=sparse_std, generator=sparse_generator)
        else:
            init_fn(layer.weight, sparsity=sparse_sparsity, std=sparse_std)