decoder.py
import sys
# NOTE(review): machine-specific path hack. Nothing visible in this module
# imports from it; kept because other modules may rely on the side effect.
sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
import torch
import torch.nn as nn
import torch.nn.functional as F


class DR(nn.Module):
    """Channel-reduction unit: 1x1 convolution (no bias) -> normalization -> ReLU.

    Args:
        in_d: Number of input channels.
        out_d: Number of output channels.
        BatchNorm: Normalization layer factory, called as ``BatchNorm(out_d)``.
            Defaults to ``nn.BatchNorm2d``, which matches the previous behaviour.
    """

    def __init__(self, in_d, out_d, BatchNorm=nn.BatchNorm2d):
        super(DR, self).__init__()
        self.in_d = in_d
        self.out_d = out_d
        self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
        self.bn1 = BatchNorm(self.out_d)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Apply conv1 -> bn1 -> relu to ``input`` of shape (N, in_d, H, W)."""
        return self.relu(self.bn1(self.conv1(input)))


class Decoder_GASN(nn.Module):
    """Fuse four feature maps and decode them to ``fc`` channels.

    Each input is reduced to 64 channels by a ``DR`` unit. The results are
    bilinearly resized to the spatial size of ``low_level_feat2`` and
    concatenated into 256 channels. FAM's output is then added back to that
    tensor as a residual, and ``last_conv`` maps the result to ``fc``
    channels.

    Args:
        fc: Number of output channels.
        BatchNorm: Normalization layer factory, called as ``BatchNorm(n)``.
            It is used for every normalization layer in the DR units and in
            ``last_conv``.
        backbone_channels: Input channel counts of
            ``(low_level_feat2, low_level_feat3, low_level_feat4, x)``.
            Defaults to ``(64, 128, 256, 512)``, the previously hard-coded values.

    Raises:
        ValueError: If ``backbone_channels`` does not have exactly 4 entries.
    """

    def __init__(self, fc, BatchNorm, backbone_channels=(64, 128, 256, 512)):
        super(Decoder_GASN, self).__init__()
        backbone_channels = tuple(backbone_channels)
        if len(backbone_channels) != 4:
            raise ValueError(
                "backbone_channels must have 4 entries "
                f"(feat2, feat3, feat4, x), got {backbone_channels!r}"
            )
        c2, c3, c4, c5 = backbone_channels
        self.fc = fc
        # Every branch is reduced to 64 channels, so the concatenation is
        # 4 * 64 = 256 channels. last_conv and FAM are sized for that.
        self.dr2 = DR(c2, 64, BatchNorm)
        self.dr3 = DR(c3, 64, BatchNorm)
        self.dr4 = DR(c4, 64, BatchNorm)
        self.dr5 = DR(c5, 64, BatchNorm)
        self.last_conv = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(128, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
            BatchNorm(self.fc),
            nn.ReLU(),
        )
        # NOTE(review): _init_weight runs before FAM / FAM_output_layer are
        # created, so those convs keep PyTorch's default init instead of
        # kaiming_normal_. Left unchanged to preserve training behaviour;
        # confirm whether this is intended.
        self._init_weight()
        self.FAM_feat = [512, 512, 512, 256]
        # The last FAM width (256) must equal the concatenated width so the
        # residual addition in forward() is shape-compatible.
        self.FAM = FAM_layers(self.FAM_feat, in_channels=256, dilation=True)
        # NOTE(review): not referenced in forward(). Kept so existing
        # checkpoints (state_dict keys) still load.
        self.FAM_output_layer = nn.Conv2d(256, 256, kernel_size=1)

    def forward(self, x, low_level_feat2, low_level_feat3, low_level_feat4):
        """Decode the four feature maps.

        Args:
            x: 4-D tensor (N, backbone_channels[3], H5, W5).
            low_level_feat2: 4-D tensor (N, backbone_channels[0], H, W). Its
                spatial size (H, W) determines the output size.
            low_level_feat3: 4-D tensor (N, backbone_channels[1], H3, W3).
            low_level_feat4: 4-D tensor (N, backbone_channels[2], H4, W4).

        Returns:
            Tensor of shape (N, fc, H, W). It is non-negative because
            last_conv ends in ReLU.
        """
        x2 = self.dr2(low_level_feat2)
        x3 = self.dr3(low_level_feat3)
        x4 = self.dr4(low_level_feat4)
        x = self.dr5(x)

        size = x2.size()[2:]
        x = F.interpolate(x, size=size, mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, size=size, mode='bilinear', align_corners=True)
        x4 = F.interpolate(x4, size=size, mode='bilinear', align_corners=True)

        x = torch.cat((x, x2, x3, x4), dim=1)  # (N, 256, H, W)
        # FAM preserves the spatial size (padding == dilation) and ends at
        # 256 channels, so its output can be added back residually.
        x = x + self.FAM(x)
        return self.last_conv(x)

    def _init_weight(self):
        """Initialize Conv2d weights with kaiming_normal_ and BatchNorm2d with weight=1, bias=0.

        Only submodules registered at the time of the call are affected.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)


def build_decoderGASN(fc, BatchNorm, backbone_channels=(64, 128, 256, 512)):
    """Construct a ``Decoder_GASN``; arguments are forwarded unchanged."""
    return Decoder_GASN(fc, BatchNorm, backbone_channels=backbone_channels)


def FAM_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a sequential stack of 3x3 convolutions described by ``cfg``.

    Each entry ``v`` of ``cfg`` adds ``Conv2d(prev, v, 3)``, then optionally
    ``BatchNorm2d(v)``, then an in-place ReLU.

    Args:
        cfg: Output channel count of each successive conv layer.
        in_channels: Channel count fed to the first conv.
        batch_norm: Insert ``nn.BatchNorm2d`` after every conv when True.
        dilation: Use dilation 3 (with padding 3) when True, otherwise 1.

    Returns:
        ``nn.Sequential`` containing the layers in order.
    """
    d_rate = 3 if dilation else 1
    layers = []
    for v in cfg:
        # padding == dilation keeps H and W unchanged for a 3x3 kernel, stride 1.
        conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
        if batch_norm:
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
        else:
            layers += [conv2d, nn.ReLU(inplace=True)]
        in_channels = v
    return nn.Sequential(*layers)