anogan.py
1 # -*- coding: utf-8 -*- 2 """Anomaly Detection with Generative Adversarial Networks (AnoGAN) 3 Paper: https://arxiv.org/pdf/1703.05921.pdf 4 Note, that this is another implementation of AnoGAN as the one from https://github.com/fuchami/ANOGAN 5 """ 6 # Author: Michiel Bongaerts (but not author of the AnoGAN method) 7 # License: BSD 2 clause 8 9 import matplotlib.pyplot as plt 10 import numpy as np 11 import pandas as pd 12 13 try: 14 import torch 15 except ImportError: 16 print('please install torch first') 17 18 import torch 19 import torch.nn as nn 20 import torch.optim as optim 21 from sklearn.preprocessing import StandardScaler 22 from sklearn.utils import check_array 23 from sklearn.utils.validation import check_is_fitted 24 from torch.utils.data import DataLoader, TensorDataset 25 26 from .base import BaseDetector 27 from ..utils.torch_utility import get_activation_by_name 28 from ..utils.utility import check_parameter 29 30 31 class Generator(nn.Module): 32 def __init__(self, latent_dim_G, n_features, G_layers, dropout_rate, 33 activation_hidden, output_activation): 34 super(Generator, self).__init__() 35 self.latent_dim = latent_dim_G 36 self.n_features = n_features 37 self.layers = G_layers 38 self.dropout_rate = dropout_rate 39 self.activation_hidden = activation_hidden 40 self.output_activation = output_activation 41 42 self.model = self._build_generator() 43 44 def _build_generator(self): 45 layers = [nn.Dropout(self.dropout_rate), nn.Linear( 46 self.latent_dim, self.layers[0]), 47 get_activation_by_name(self.activation_hidden)] 48 for i in range(1, len(self.layers)): 49 layers.extend([ 50 nn.Dropout(self.dropout_rate), 51 nn.Linear(self.layers[i - 1], self.layers[i]), 52 get_activation_by_name(self.activation_hidden) 53 ]) 54 layers.append(nn.Linear(self.layers[-1], self.n_features)) 55 if self.output_activation: 56 layers.append(get_activation_by_name(self.output_activation)) 57 return nn.Sequential(*layers) 58 59 def forward(self, x): 60 return self.model(x) 61 62 63 class Discriminator(nn.Module): 64 def __init__(self, n_features, D_layers, dropout_rate, activation_hidden): 65 super(Discriminator, self).__init__() 66 self.n_features = n_features 67 self.layers = D_layers 68 self.dropout_rate = dropout_rate 69 self.activation_hidden = activation_hidden 70 71 self.model = self._build_discriminator() 72 73 def _build_discriminator(self): 74 layers = [nn.Dropout(self.dropout_rate), nn.Linear( 75 self.n_features, self.layers[0]), 76 get_activation_by_name(self.activation_hidden)] 77 for i in range(1, len(self.layers)): 78 layers.extend([ 79 nn.Dropout(self.dropout_rate), 80 nn.Linear(self.layers[i - 1], self.layers[i]), 81 get_activation_by_name(self.activation_hidden) 82 ]) 83 layers.extend([nn.Linear(self.layers[-1], 1), nn.Sigmoid()]) 84 return nn.Sequential(*layers) 85 86 def forward(self, x): 87 return self.model(x) 88 89 90 class QueryModel(nn.Module): 91 def __init__(self, generator, discriminator, latent_dim_G): 92 super(QueryModel, self).__init__() 93 self.generator = generator 94 self.discriminator = discriminator 95 self.z_gamma_layer = nn.Linear(latent_dim_G, latent_dim_G) 96 97 def forward(self, query_sample): 98 z_gamma = self.z_gamma_layer(query_sample) 99 sample_gen = self.generator(z_gamma) 100 sample_disc_latent = self.discriminator(sample_gen) 101 return z_gamma, sample_gen, sample_disc_latent 102 103 104 class AnoGAN(BaseDetector): 105 """Anomaly Detection with Generative Adversarial Networks (AnoGAN). 106 See the original paper "Unsupervised anomaly detection with generative 107 adversarial networks to guide marker discovery". 108 109 See :cite:`schlegl2017unsupervised` for details. 110 111 Parameters 112 ---------- 113 114 output_activation : str, optional (default=None) 115 Activation function to use for output layer. 116 117 118 activation_hidden : str, optional (default='tanh') 119 Activation function to use for output layer. 120 121 epochs : int, optional (default=500) 122 Number of epochs to train the model. 123 124 batch_size : int, optional (default=32) 125 Number of samples per gradient update. 126 127 dropout_rate : float in (0., 1), optional (default=0.2) 128 The dropout to be used across all layers. 129 130 G_layers : list, optional (default=[20,10,3,10,20]) 131 List that indicates the number of nodes per hidden layer for the 132 generator. Thus, [10,10] indicates 2 hidden layers having each 10 nodes. 133 134 D_layers : list, optional (default=[20,10,5]) 135 List that indicates the number of nodes per hidden layer for the 136 discriminator. Thus, [10,10] indicates 2 hidden layers having each 10 137 nodes. 138 139 140 learning_rate: float in (0., 1), optional (default=0.001) 141 learning rate of training the network 142 143 index_D_layer_for_recon_error: int, optional (default = 1) 144 This is the index of the hidden layer in the discriminator for which 145 the reconstruction error will be determined between query sample and 146 the sample created from the latent space. 147 148 learning_rate_query: float in (0., 1), optional (default=0.001) 149 learning rate for the backpropagation steps needed to find a point in 150 the latent space of the generator that approximate the query sample 151 152 153 epochs_query: int, optional (default=20) 154 Number of epochs to approximate the query sample in the latent space 155 of the generator 156 157 preprocessing : bool, optional (default=True) 158 If True, apply standardization on the data. 159 160 verbose : int, optional (default=1) 161 Verbosity mode. 162 - 0 = silent 163 - 1 = progress bar 164 165 contamination : float in (0., 0.5), optional (default=0.1) 166 The amount of contamination of the data set, i.e. 167 the proportion of outliers in the data set. When fitting this is used 168 to define the threshold on the decision function. 169 170 Attributes 171 ---------- 172 173 decision_scores_ : numpy array of shape (n_samples,) 174 The outlier scores of the training data [0,1]. 175 The higher, the more abnormal. Outliers tend to have higher 176 scores. This value is available once the detector is 177 fitted. 178 179 threshold_ : float 180 The threshold is based on ``contamination``. It is the 181 ``n_samples * contamination`` most abnormal samples in 182 ``decision_scores_``. The threshold is calculated for generating 183 binary outlier labels. 184 185 labels_ : int, either 0 or 1 186 The binary labels of the training data. 0 stands for inliers 187 and 1 for outliers/anomalies. It is generated by applying 188 ``threshold_`` on ``decision_scores_``. 189 """ 190 191 def __init__(self, activation_hidden='tanh', dropout_rate=0.2, 192 latent_dim_G=2, 193 G_layers=[20, 10, 3, 10, 20], verbose=0, D_layers=[20, 10, 5], 194 index_D_layer_for_recon_error=1, epochs=500, 195 preprocessing=False, 196 learning_rate=0.001, learning_rate_query=0.01, 197 epochs_query=20, 198 batch_size=32, output_activation=None, contamination=0.1, 199 device=None): 200 super(AnoGAN, self).__init__(contamination=contamination) 201 202 self.activation_hidden = activation_hidden 203 self.dropout_rate = dropout_rate 204 self.latent_dim_G = latent_dim_G 205 self.G_layers = G_layers 206 self.D_layers = D_layers 207 self.index_D_layer_for_recon_error = index_D_layer_for_recon_error 208 self.output_activation = output_activation 209 self.contamination = contamination 210 self.epochs = epochs 211 self.learning_rate = learning_rate 212 self.learning_rate_query = learning_rate_query 213 self.epochs_query = epochs_query 214 self.preprocessing = preprocessing 215 self.batch_size = batch_size 216 self.verbose = verbose 217 218 self.hist_loss_generator = [] 219 self.hist_loss_discriminator = [] 220 221 self.device = device 222 223 check_parameter(dropout_rate, 0, 1, 224 param_name='dropout_rate', include_left=True) 225 226 def plot_learning_curves(self, start_ind=0, 227 window_smoothening=10): # pragma: no cover 228 fig = plt.figure(figsize=(12, 5)) 229 l_gen = pd.Series(self.hist_loss_generator[start_ind:]).rolling( 230 window_smoothening).mean() 231 l_disc = pd.Series(self.hist_loss_discriminator[start_ind:]).rolling( 232 window_smoothening).mean() 233 234 ax = fig.add_subplot(1, 2, 1) 235 ax.plot(range(len(l_gen)), l_gen) 236 ax.set_title('Generator') 237 ax.set_ylabel('Loss') 238 ax.set_xlabel('Iter') 239 240 ax = fig.add_subplot(1, 2, 2) 241 ax.plot(range(len(l_disc)), l_disc) 242 ax.set_title('Discriminator') 243 ax.set_ylabel('Loss') 244 ax.set_xlabel('Iter') 245 246 plt.show() 247 248 def fit(self, X, y=None): 249 """Fit detector. y is ignored in unsupervised methods. 250 251 Parameters 252 ---------- 253 X : numpy array of shape (n_samples, n_features) 254 The input samples. 255 256 y : Ignored 257 Not used, present for API consistency by convention. 258 259 Returns 260 ------- 261 self : object 262 Fitted estimator. 263 """ 264 # validate inputs X and y (optional) 265 X = check_array(X) 266 self._set_n_classes(y) 267 268 # Verify and construct the hidden units 269 self.n_samples_, self.n_features_ = X.shape 270 271 # Standardize data for better performance 272 if self.preprocessing: 273 self.scaler_ = StandardScaler() 274 X_norm = self.scaler_.fit_transform(X) 275 else: 276 X_norm = np.copy(X) 277 X_norm = torch.tensor(X_norm, dtype=torch.float32) 278 # train the discriminator and generator 279 self.generator = Generator(latent_dim_G=self.latent_dim_G, 280 n_features=self.n_features_, 281 G_layers=self.G_layers, 282 dropout_rate=self.dropout_rate, 283 activation_hidden=self.activation_hidden, 284 output_activation=self.output_activation) 285 self.discriminator = Discriminator(n_features=self.n_features_, 286 D_layers=self.D_layers, 287 dropout_rate=self.dropout_rate, 288 activation_hidden=self.activation_hidden) 289 290 self.generator.to(self.device) 291 self.discriminator.to(self.device) 292 293 optimizer_g = optim.Adam( 294 self.generator.parameters(), lr=self.learning_rate) 295 optimizer_d = optim.Adam( 296 self.discriminator.parameters(), lr=self.learning_rate) 297 298 dataset = TensorDataset(X_norm) 299 dataloader = DataLoader( 300 dataset, batch_size=self.batch_size, shuffle=True) 301 302 for n in range(self.epochs): 303 if n % 100 == 0 and n != 0 and self.verbose == 1: 304 print(f'Train iter: {n}') 305 306 self.generator.train() 307 self.discriminator.train() 308 for X_train_ in dataloader: 309 X_train_sel = X_train_[0].to(self.device) 310 latent_noise = torch.rand(X_train_sel.size( 311 0), self.latent_dim_G, dtype=torch.float32).to(self.device) 312 313 generated_data = self.generator(latent_noise) 314 real_output = self.discriminator(X_train_sel) 315 fake_output = self.discriminator(generated_data.detach()) 316 317 loss_D_real = nn.BCELoss()(real_output, torch.ones_like( 318 real_output) * 0.9).to(self.device) 319 loss_D_fake = nn.BCELoss()(fake_output, 320 torch.zeros_like(fake_output)).to( 321 self.device) 322 loss_D = loss_D_real + loss_D_fake 323 optimizer_d.zero_grad() 324 loss_D.backward() 325 optimizer_d.step() 326 327 fake_output = self.discriminator(generated_data) 328 loss_G = nn.BCELoss()(fake_output, 329 torch.ones_like(fake_output)).to( 330 self.device) 331 optimizer_g.zero_grad() 332 loss_G.backward() 333 optimizer_g.step() 334 335 self.hist_loss_discriminator.append(loss_D.item()) 336 self.hist_loss_generator.append(loss_G.item()) 337 338 # Instantiate and train the query model 339 self.generator.eval() 340 self.discriminator.eval() 341 self.query_model = QueryModel( 342 self.generator, self.discriminator, self.latent_dim_G).to( 343 self.device) 344 optimizer_query = optim.Adam( 345 self.query_model.parameters(), lr=self.learning_rate_query) 346 scores = [] 347 # For each sample, use a few backpropagation steps to obtain a point in the latent space that best resembles the query sample 348 self.query_model.train() 349 for i in range(X_norm.shape[0]): 350 if self.verbose == 1: 351 print('query sample {} / {}'.format(i + 1, X_norm.shape[0])) 352 353 query_sample = X_norm[[i],].to(self.device) 354 assert (query_sample.shape[0] == 1) 355 assert (query_sample.shape[1] == self.n_features_) 356 357 # Make pseudo input (just zeros) 358 zeros = torch.zeros((1, self.latent_dim_G)).to(self.device) 359 360 # build model for back-propagating a approximate latent space where 361 # reconstruction with query sample is optimal 362 for i in range(self.epochs_query): 363 if i % 25 == 0 and self.verbose == 1: 364 print('iter:', i) 365 366 z, sample_gen, sample_disc_latent = self.query_model(zeros) 367 with torch.no_grad(): 368 sample_disc_latent_original = self.discriminator( 369 query_sample) 370 # Reconstruction loss generator 371 loss_recon_gen = torch.mean(torch.mean( 372 torch.abs(query_sample - sample_gen), axis=-1)) 373 # Reconstruction loss latent space of discrimator 374 loss_recon_disc = torch.mean(torch.mean( 375 torch.abs( 376 sample_disc_latent_original - sample_disc_latent), 377 axis=-1)) 378 total_loss = loss_recon_gen + loss_recon_disc 379 380 optimizer_query.zero_grad() 381 total_loss.backward() 382 optimizer_query.step() 383 # Predict on X itself and calculate the reconstruction error as 384 # the outlier scores. 385 scores.append(total_loss.item()) 386 387 self.decision_scores_ = np.array(scores) 388 self._process_decision_scores() 389 return self 390 391 def decision_function(self, X): 392 """Predict raw anomaly score of X using the fitted detector. 393 394 The anomaly score of an input sample is computed based on different 395 detector algorithms. For consistency, outliers are assigned with 396 larger anomaly scores. 397 398 Parameters 399 ---------- 400 X : numpy array of shape (n_samples, n_features) 401 The training input samples. Sparse matrices are accepted only 402 if they are supported by the base estimator. 403 404 Returns 405 ------- 406 anomaly_scores : numpy array of shape (n_samples,) 407 The anomaly score of the input samples. 408 """ 409 check_is_fitted(self, ['decision_scores_']) 410 X = check_array(X) 411 412 if self.preprocessing: 413 X_norm = self.scaler_.transform(X) 414 else: 415 X_norm = np.copy(X) 416 417 X_norm = torch.tensor(X_norm, dtype=torch.float32) 418 419 # Predict on X 420 pred_scores = [] 421 422 self.query_model.eval() 423 with torch.no_grad(): 424 for i in range(X_norm.shape[0]): 425 if self.verbose == 1: 426 print( 427 'query sample {} / {}'.format(i + 1, X_norm.shape[0])) 428 429 query_sample = X_norm[[i],].to(self.device) 430 431 zeros = torch.zeros((1, self.latent_dim_G)).to(self.device) 432 z, sample_gen, sample_disc_latent = self.query_model(zeros) 433 sample_disc_latent_original = self.discriminator(query_sample) 434 435 loss_recon_gen = torch.mean(torch.mean( 436 torch.abs(query_sample - sample_gen), axis=-1)) 437 loss_recon_disc = torch.mean(torch.mean( 438 torch.abs( 439 sample_disc_latent_original - sample_disc_latent), 440 axis=-1)) 441 total_loss = loss_recon_gen + loss_recon_disc 442 pred_scores.append(total_loss.item()) 443 444 return np.array(pred_scores)