# -*- coding: utf-8 -*-
"""Deep Isolation Forest for Anomaly Detection (DIF)
"""
# Author: Hongzuo Xu <hongzuoxu@126.com>
# License: BSD 2 clause


import numpy as np

try:
    import torch
except ImportError:
    print('please install torch first')

from torch.utils.data import DataLoader

from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from .base import BaseDetector
from ..utils.torch_utility import get_activation_by_name


class DIF(BaseDetector):
    """Deep Isolation Forest (DIF) is an extension of iForest. It uses an
    ensemble of deep representations to achieve non-linear isolation of the
    original data space. See :cite:`xu2023dif` for details.

    Parameters
    ----------
    batch_size : int, optional (default=1000)
        Number of samples per mini-batch when computing the deep
        representations (the networks are randomly initialized, so no
        gradient updates are performed).

    representation_dim : int, optional (default=20)
        Dimensionality of the representation space.

    hidden_neurons : list, optional (default=[500, 100])
        The number of neurons per hidden layer, so the network has the
        structure [n_features, hidden_neurons[0], hidden_neurons[1], ...,
        representation_dim].

    hidden_activation : str, optional (default='tanh')
        Activation function to use for hidden layers.
        All hidden layers are forced to use the same type of activation.
        See https://pytorch.org/docs/stable/nn.html for details.
        Currently only
        'relu': nn.ReLU()
        'sigmoid': nn.Sigmoid()
        'tanh': nn.Tanh()
        are supported. See pyod/utils/torch_utility.py for details.

    skip_connection : bool, optional (default=False)
        If True, apply skip connections in the neural network structure.

    n_ensemble : int, optional (default=50)
        The number of deep representation ensemble members.

    n_estimators : int, optional (default=6)
        The number of isolation trees in the isolation forest built on
        each representation.

    max_samples : int, optional (default=256)
        The number of samples to draw from X to train each base isolation
        tree.

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set,
        i.e. the proportion of outliers in the data set. Used when fitting to
        define the threshold on the decision function.

    random_state : int or None, optional (default=None)
        If int, random_state is the seed used by the random number
        generator;
        if None, the random number generator is the RandomState instance
        used by `np.random`.

    device : 'cuda', 'cpu', or None, optional (default=None)
        If 'cuda', use GPU acceleration in torch;
        if 'cpu', use CPU in torch;
        if None, automatically determine whether a GPU is available.

    Attributes
    ----------
    net_lst : list of torch.nn.Module
        The list of representation neural networks.

    iForest_lst : list of IsolationForest
        The list of instantiated iForest models.

    x_reduced_lst : list of numpy array
        The list of training data representations.

    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data.
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : int, either 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.
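
    Examples
    --------
    A minimal usage sketch (the toy data below is illustrative only)::

        import numpy as np
        from pyod.models.dif import DIF

        X_train = np.random.RandomState(42).randn(200, 10)

        clf = DIF(n_ensemble=5, random_state=42)  # small ensemble for speed
        clf.fit(X_train)

        train_scores = clf.decision_scores_           # scores of X_train
        test_scores = clf.decision_function(X_train[:10])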
    """

    def __init__(self,
                 batch_size=1000,
                 representation_dim=20,
                 hidden_neurons=None,
                 hidden_activation='tanh',
                 skip_connection=False,
                 n_ensemble=50,
                 n_estimators=6,
                 max_samples=256,
                 contamination=0.1,
                 random_state=None,
                 device=None):
        super(DIF, self).__init__(contamination=contamination)
        self.batch_size = batch_size
        self.representation_dim = representation_dim
        self.hidden_activation = hidden_activation
        self.skip_connection = skip_connection
        self.hidden_neurons = hidden_neurons

        self.n_ensemble = n_ensemble
        self.n_estimators = n_estimators
        self.max_samples = max_samples

        self.random_state = random_state
        self.device = device

        self.minmax_scaler = None

        # create default calculation device (use GPU if available)
        if self.device is None:
            self.device = torch.device(
                "cuda:0" if torch.cuda.is_available() else "cpu")

        # set random seed
        if self.random_state is not None:
            torch.manual_seed(self.random_state)
            torch.cuda.manual_seed(self.random_state)
            torch.cuda.manual_seed_all(self.random_state)
            np.random.seed(self.random_state)

        # default hidden layer sizes
        if self.hidden_neurons is None:
            self.hidden_neurons = [500, 100]

    def fit(self, X, y=None):
        """Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : Ignored
            Not used, present for API consistency by convention.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # validate inputs X and y (optional)
        X = check_array(X)
        self._set_n_classes(y)

        n_samples, n_features = X.shape[0], X.shape[1]

        # conduct min-max normalization before feeding into neural networks;
        # keep the raw X around because decision_function applies the
        # fitted scaler itself
        self.minmax_scaler = MinMaxScaler()
        x_norm = self.minmax_scaler.fit_transform(X)

        # prepare neural network parameters
        network_params = {
            'n_features': n_features,
            'n_hidden': self.hidden_neurons,
            'n_output': self.representation_dim,
            'activation': self.hidden_activation,
            'skip_connection': self.skip_connection
        }

        # build the ensemble: one randomly initialized network and one
        # isolation forest per ensemble member
        self.net_lst = []
        self.iForest_lst = []
        self.x_reduced_lst = []
        ensemble_seeds = np.random.randint(0, 100000, self.n_ensemble)
        for i in range(self.n_ensemble):
            # instantiate the network and seed the RNG used by the weight
            # re-initialization below
            net = MLPnet(**network_params).to(self.device)
            torch.manual_seed(ensemble_seeds[i])

            # initialize network weights from a standard Gaussian
            for name, param in net.named_parameters():
                if name.endswith('weight'):
                    torch.nn.init.normal_(param, mean=0., std=1.)

            x_reduced = self._deep_representation(net, x_norm)

            # save network and representations
            self.x_reduced_lst.append(x_reduced)
            self.net_lst.append(net)

            # perform iForest upon the representations
            self.iForest_lst.append(
                IsolationForest(n_estimators=self.n_estimators,
                                max_samples=self.max_samples,
                                random_state=ensemble_seeds[i])
            )
            self.iForest_lst[i].fit(x_reduced)

        # decision_function re-applies the scaler, so pass the raw data
        self.decision_scores_ = self.decision_function(X)
        self._process_decision_scores()
        return self

    def decision_function(self, X):
        """Predict raw anomaly score of X using the fitted detector.

        The anomaly score of an input sample is computed based on different
        detector algorithms. For consistency, outliers are assigned with
        larger anomaly scores.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """
        check_is_fitted(self, ['net_lst', 'iForest_lst', 'x_reduced_lst'])
        X = check_array(X)

        # conduct min-max normalization before feeding into neural networks
        X = self.minmax_scaler.transform(X)

        testing_n_samples = X.shape[0]
        score_lst = np.zeros([self.n_ensemble, testing_n_samples])

        # score with each ensemble member, then average
        for i in range(self.n_ensemble):
            # transform testing data to representation
            x_reduced = self._deep_representation(self.net_lst[i], X)

            # calculate outlier scores
            scores = _cal_score(x_reduced, self.iForest_lst[i])
            score_lst[i] = scores

        final_scores = np.average(score_lst, axis=0)
        return final_scores

    def _deep_representation(self, net, X):
        """Map X into the network's representation space batch by batch,
        then standardize the result and squash it with tanh."""
        x_reduced = []

        with torch.no_grad():
            loader = DataLoader(X, batch_size=self.batch_size,
                                drop_last=False, pin_memory=True,
                                shuffle=False)
            for batch_x in loader:
                batch_x = batch_x.float().to(self.device)
                batch_x_reduced = net(batch_x)
                x_reduced.append(batch_x_reduced)

        x_reduced = torch.cat(x_reduced).data.cpu().numpy()
        x_reduced = StandardScaler().fit_transform(x_reduced)
        x_reduced = np.tanh(x_reduced)
        return x_reduced


class MLPnet(torch.nn.Module):
    """MLP that produces the (randomly initialized) deep representations."""

    def __init__(self, n_features, n_hidden=[500, 100], n_output=20,
                 activation='tanh', bias=False, batch_norm=False,
                 skip_connection=False):
        super(MLPnet, self).__init__()
        self.skip_connection = skip_connection
        self.n_output = n_output

        num_layers = len(n_hidden)

        # broadcast a single activation name to all hidden layers;
        # the output layer gets no activation
        if isinstance(activation, str):
            activation = [activation] * num_layers
            activation.append(None)

        assert len(activation) == len(
            n_hidden) + 1, 'activation and n_hidden are not matched'

        self.layers = []
        for i in range(num_layers + 1):
            in_channels, out_channels = \
                self.get_in_out_channels(i, num_layers, n_features,
                                         n_hidden, n_output, skip_connection)
            self.layers += [
                LinearBlock(in_channels, out_channels,
                            bias=bias, batch_norm=batch_norm,
                            activation=activation[i],
                            skip_connection=skip_connection if i != num_layers else False)
            ]
        self.network = torch.nn.Sequential(*self.layers)

    def forward(self, x):
        x = self.network(x)
        return x

    @staticmethod
    def get_in_out_channels(i, num_layers, n_features, n_hidden, n_output,
                            skip_connection):
        # with skip connections, each block after the first receives the
        # concatenation of the raw input and all previous hidden outputs,
        # so its input width grows accordingly
        if skip_connection is False:
            in_channels = n_features if i == 0 else n_hidden[i - 1]
            out_channels = n_output if i == num_layers else n_hidden[i]
        else:
            in_channels = n_features if i == 0 else np.sum(
                n_hidden[:i]) + n_features
            out_channels = n_output if i == num_layers else n_hidden[i]
        return in_channels, out_channels


class LinearBlock(torch.nn.Module):
    """A linear layer followed by an optional activation, optional batch
    normalization, and an optional skip connection."""

    def __init__(self, in_channels, out_channels,
                 activation='tanh', bias=False, batch_norm=False,
                 skip_connection=False):
        super(LinearBlock, self).__init__()

        self.skip_connection = skip_connection

        self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias)

        if activation is not None:
            self.act_layer = get_activation_by_name(activation)
        else:
            self.act_layer = torch.nn.Identity()

        self.batch_norm = batch_norm
        if batch_norm is True:
            dim = out_channels
            self.bn_layer = torch.nn.BatchNorm1d(dim, affine=bias)

    def forward(self, x):
        x1 = self.linear(x)
        x1 = self.act_layer(x1)

        if self.batch_norm is True:
            x1 = self.bn_layer(x1)

        if self.skip_connection:
            # concatenate the block input with its output along the
            # feature dimension
            x1 = torch.cat([x, x1], dim=1)

        return x1


def _cal_score(xx, clf):
    """Compute deviation-augmented iForest anomaly scores for the
    representations ``xx`` with a fitted IsolationForest ``clf``."""
    depths = np.zeros((xx.shape[0], len(clf.estimators_)))
    depth_sum = np.zeros(xx.shape[0])
    deviations = np.zeros((xx.shape[0], len(clf.estimators_)))
    leaf_samples = np.zeros((xx.shape[0], len(clf.estimators_)))

    for ii, estimator_tree in enumerate(clf.estimators_):
        tree = estimator_tree.tree_
        n_node = tree.node_count

        if n_node == 1:
            continue

        # get the feature and threshold of each node in the iTree;
        # in feature_lst, -2 indicates a leaf node
        feature_lst, threshold_lst = tree.feature.copy(), tree.threshold.copy()

        # compute depth and score
        leaves_index = estimator_tree.apply(xx)
        node_indicator = estimator_tree.decision_path(xx)

        # the number of training samples in each node of the tree
        n_node_samples = estimator_tree.tree_.n_node_samples

        # node_indicator is a sparse matrix with shape (n_samples, n_nodes)
        # marking the path of each input sample;
        # every tree level contributes one non-zero element per sample,
        # so the row-wise sum is the depth of the sample
        n_samples_leaf = n_node_samples[leaves_index]
        d = (np.ravel(node_indicator.sum(axis=1)) + _average_path_length(
            n_samples_leaf) - 1.0)
        depths[:, ii] = d
        depth_sum += d

        # decision path of data matrix xx as a dense array
        node_indicator = np.array(node_indicator.todense())

        # build a matrix with shape [n_samples, n_nodes] holding the
        # feature value of each sample at each node;
        # leaf nodes are set to -2
        value_mat = np.array([xx[i][feature_lst] for i in range(xx.shape[0])])
        value_mat[:, np.where(feature_lst == -2)[0]] = -2
        th_mat = np.array([threshold_lst for _ in range(xx.shape[0])])

        mat = np.abs(value_mat - th_mat) * node_indicator

        # average distance to the split thresholds along each sample's path
        exist = (mat != 0)
        dev = mat.sum(axis=1) / (exist.sum(axis=1) + 1e-6)
        deviations[:, ii] = dev
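
    # standard iForest aggregation: s(x) = 2 ** (-E[h(x)] / c(max_samples)),
    # where E[h(x)] is the mean path length over all trees and c(.) is the
    # expected path length of an unsuccessful BST search
    # (_average_path_length); the score is then weighted by the mean
    # deviation from the split thresholds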
    scores = 2 ** (-depth_sum / (len(clf.estimators_) * _average_path_length(
        [clf.max_samples_])))
    deviation = np.mean(deviations, axis=1)
    # note: leaf_samples is never populated, so the disabled variant below
    # would only add a constant term
    leaf_sample = (clf.max_samples_ - np.mean(leaf_samples,
                                              axis=1)) / clf.max_samples_

    scores = scores * deviation
    # scores = scores * deviation * leaf_sample
    return scores


def _average_path_length(n_samples_leaf):
    """
    The average path length of an n_samples iTree, which equals the average
    path length of an unsuccessful BST search, since the latter has the same
    structure as an isolation tree.

    Parameters
    ----------
    n_samples_leaf : array-like of shape (n_samples,)
        The number of training samples in each test sample leaf, for
        each estimator.

    Returns
    -------
    average_path_length : ndarray of shape (n_samples,)
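
    Notes
    -----
    For a leaf holding ``n > 2`` training samples, the formula implemented
    below is ``c(n) = 2 * (ln(n - 1) + euler_gamma) - 2 * (n - 1) / n``,
    with the base cases ``c(n <= 1) = 0`` and ``c(2) = 1``.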
    """

    n_samples_leaf = check_array(n_samples_leaf, ensure_2d=False)

    n_samples_leaf_shape = n_samples_leaf.shape
    n_samples_leaf = n_samples_leaf.reshape((1, -1))
    average_path_length = np.zeros(n_samples_leaf.shape)

    mask_1 = n_samples_leaf <= 1
    mask_2 = n_samples_leaf == 2
    not_mask = ~np.logical_or(mask_1, mask_2)

    average_path_length[mask_1] = 0.
    average_path_length[mask_2] = 1.
    average_path_length[not_mask] = (
            2.0 * (np.log(n_samples_leaf[not_mask] - 1.0) + np.euler_gamma)
            - 2.0 * (n_samples_leaf[not_mask] - 1.0) / n_samples_leaf[not_mask]
    )

    return average_path_length.reshape(n_samples_leaf_shape)