gaal_base.py
# -*- coding: utf-8 -*-
"""Base file for Generative Adversarial Active Learning.

Parts of the code are adapted from
https://github.com/leibinghe/GAAL-based-outlier-detection
"""

import math

try:
    import torch
except ImportError as err:
    # Fail with an actionable message. The previous version printed a hint
    # and then crashed on a second, unguarded ``import torch`` anyway.
    raise ImportError('please install torch first') from err

import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    """Two-layer MLP mapping a ``latent_size``-wide input to a value in (0, 1).

    Defined at module level (rather than inside ``create_discriminator``)
    so that instances can be pickled, e.g. by ``torch.save(model)``.

    Parameters
    ----------
    latent_size : int
        Width of the input features.
    data_size : int
        Determines the hidden width, ``ceil(sqrt(data_size))``.
    """

    def __init__(self, latent_size, data_size):
        super().__init__()
        hidden_size = math.ceil(math.sqrt(data_size))
        self.layer1 = nn.Linear(latent_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, 1)
        # He-normal weight init, gain chosen for the activation that
        # follows each layer (``calculate_gain('sigmoid')`` is 1).
        nn.init.kaiming_normal_(self.layer1.weight, mode='fan_in',
                                nonlinearity='relu')
        nn.init.kaiming_normal_(self.layer2.weight, mode='fan_in',
                                nonlinearity='sigmoid')

    def forward(self, x):
        """Return a ``(..., 1)`` tensor of sigmoid scores for input ``x``."""
        x = F.relu(self.layer1(x))
        return torch.sigmoid(self.layer2(x))


class Generator(nn.Module):
    """Two-layer ReLU MLP mapping ``latent_size`` -> ``latent_size``.

    Defined at module level (rather than inside ``create_generator``) so
    that instances can be pickled, e.g. by ``torch.save(model)``.

    Parameters
    ----------
    latent_size : int
        Width of both the input and the output of the generator.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.layer1 = nn.Linear(latent_size, latent_size)
        self.layer2 = nn.Linear(latent_size, latent_size)
        # Identity weight init. Biases keep PyTorch's default random init,
        # so the untrained generator is only close to an identity map on
        # non-negative inputs.
        nn.init.eye_(self.layer1.weight)
        nn.init.eye_(self.layer2.weight)

    def forward(self, x):
        """Map ``x`` through both layers."""
        x = F.relu(self.layer1(x))
        # The final ReLU makes every generated feature non-negative.
        return F.relu(self.layer2(x))


def create_discriminator(latent_size, data_size):
    """
    Create the discriminator of the GAN for a given latent size.

    Parameters
    ----------
    latent_size : int
        Input width of the discriminator. This matches the generator's
        output width, because ``create_generator`` maps
        ``latent_size -> latent_size``.
    data_size : int
        Sets the hidden layer width to ``ceil(sqrt(data_size))``.
        Presumably the number of training samples; confirm against callers.

    Returns
    -------
    discriminator : torch.nn.Module
        A PyTorch model of the discriminator whose output is in (0, 1).
    """
    return Discriminator(latent_size, data_size)


def create_generator(latent_size):
    """
    Create the generator of the GAN for a given latent size.

    Parameters
    ----------
    latent_size : int
        The size of the latent space of the generator. This is also its
        output width.

    Returns
    -------
    generator : torch.nn.Module
        A PyTorch model of the generator.
    """
    return Generator(latent_size)