/ pyod / models / gaal_base.py
gaal_base.py
 1  # -*- coding: utf-8 -*-
 2  """Base file for Generative Adversarial Active Learning.
 3  Part of the codes are adapted from
 4  https://github.com/leibinghe/GAAL-based-outlier-detection
 5  """
 6  
 7  import math
 8  
 9  try:
10      import torch
11  except ImportError:
12      print('please install torch first')
13  
14  import torch
15  import torch.nn as nn
16  import torch.nn.functional as F
17  
18  
class Discriminator(nn.Module):
    """Two-layer fully connected discriminator with a sigmoid output.

    The class lives at module level (instead of inside
    ``create_discriminator``) so that instances can be pickled: ``pickle``
    looks classes up by their qualified name and cannot resolve classes
    that are local to a function.

    Parameters
    ----------
    latent_size : int
        Number of input features of the first linear layer.
    data_size : int
        Size of the input data; the hidden layer has
        ``ceil(sqrt(data_size))`` units.
    """

    def __init__(self, latent_size, data_size):
        super(Discriminator, self).__init__()
        hidden_size = math.ceil(math.sqrt(data_size))
        self.layer1 = nn.Linear(latent_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, 1)
        # Only the weights are re-initialised here; the biases keep the
        # values assigned by nn.Linear's own initialisation.
        nn.init.kaiming_normal_(self.layer1.weight, mode='fan_in',
                                nonlinearity='relu')
        nn.init.kaiming_normal_(self.layer2.weight, mode='fan_in',
                                nonlinearity='sigmoid')

    def forward(self, x):
        """Apply ``layer1`` + ReLU, then ``layer2`` + sigmoid, to ``x``."""
        hidden = F.relu(self.layer1(x))
        return torch.sigmoid(self.layer2(hidden))


def create_discriminator(latent_size, data_size):
    """
    Create the discriminator of the GAN for a given latent size.

    Parameters
    ----------
    latent_size : int
        The size of the latent space of the generator. It is used as the
        number of input features of the discriminator.
    data_size : int
        Size of the input data. The hidden layer width is
        ``ceil(sqrt(data_size))``.

    Returns
    -------
    discriminator : torch.nn.Module
        A PyTorch model of the discriminator. Its output has one value per
        input row, squashed into (0, 1) by a sigmoid.
    """
    return Discriminator(latent_size, data_size)
53  
54  
class Generator(nn.Module):
    """Two-layer fully connected generator with ReLU activations.

    The class lives at module level (instead of inside
    ``create_generator``) so that instances can be pickled: ``pickle``
    looks classes up by their qualified name and cannot resolve classes
    that are local to a function.

    Parameters
    ----------
    latent_size : int
        Number of input and output features of both linear layers.
    """

    def __init__(self, latent_size):
        super(Generator, self).__init__()
        self.layer1 = nn.Linear(latent_size, latent_size)
        self.layer2 = nn.Linear(latent_size, latent_size)
        # Weights start as the identity matrix; the biases keep the values
        # assigned by nn.Linear's own initialisation.
        nn.init.eye_(self.layer1.weight)
        nn.init.eye_(self.layer2.weight)

    def forward(self, x):
        """Apply ``layer1`` + ReLU, then ``layer2`` + ReLU, to ``x``."""
        hidden = F.relu(self.layer1(x))
        return F.relu(self.layer2(hidden))


def create_generator(latent_size):
    """
    Create the generator of the GAN for a given latent size.

    Parameters
    ----------
    latent_size : int
        The size of the latent space of the generator. Both layers map
        ``latent_size`` features to ``latent_size`` features.

    Returns
    -------
    generator : torch.nn.Module
        A PyTorch model of the generator.
    """
    return Generator(latent_size)
82  
83      return Generator(latent_size)