/ gasnet / decoder.py
decoder.py
 1  import sys
 2  sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
 3  import torch
 4  import torch.nn as nn
 5  import torch.nn.functional as F
 6  
 7  class DR(nn.Module):
 8      def __init__(self, in_d, out_d):
 9          super(DR, self).__init__()
10          self.in_d = in_d
11          self.out_d = out_d
12          self.conv1 = nn.Conv2d(self.in_d, self.out_d, 1, bias=False)
13          self.bn1 = nn.BatchNorm2d(self.out_d)
14          self.relu = nn.ReLU()
15  
16      def forward(self, input):
17          x = self.conv1(input)
18          x = self.bn1(x)
19          x = self.relu(x)
20          return x
21  
class Decoder_GASN(nn.Module):
    """Decoder that fuses four feature maps and projects them to ``fc`` channels.

    Each input is reduced to 64 channels by a ``DR`` block, the reduced maps
    are resized to the spatial size of ``low_level_feat2`` and concatenated
    (4 x 64 = 256 channels), refined residually by the dilated ``FAM`` conv
    stack, and finally projected by ``last_conv``.

    Args:
        fc: number of output channels of the final 1x1 convolution.
        BatchNorm: normalization layer factory, called as ``BatchNorm(num_features)``.
    """

    def __init__(self, fc, BatchNorm):
        super(Decoder_GASN, self).__init__()
        self.fc = fc
        # Channel reducers. Input widths must match what forward() receives:
        # x -> 512, low_level_feat2/3/4 -> 64/128/256 channels.
        self.dr2 = DR(64, 64)
        self.dr3 = DR(128, 64)
        self.dr4 = DR(256, 64)
        self.dr5 = DR(512, 64)
        self.last_conv = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv2d(128, self.fc, kernel_size=1, stride=1, padding=0, bias=False),
            BatchNorm(self.fc),
            nn.ReLU(),
        )
        # FAM maps 256 -> 512 -> 512 -> 512 -> 256 channels with padding equal
        # to the dilation, so its output shape equals its input shape and can
        # be added back residually in forward().
        self.FAM_feat = [512, 512, 512, 256]
        self.FAM = FAM_layers(self.FAM_feat, in_channels=256, dilation=True)
        # NOTE(review): not used by forward(). Kept so existing checkpoints
        # containing "FAM_output_layer.*" keys still load with strict=True.
        # Its parameters never receive gradients, so DistributedDataParallel
        # needs find_unused_parameters=True while this layer exists.
        self.FAM_output_layer = nn.Conv2d(256, 256, kernel_size=1)
        # Must run after every submodule is registered: _init_weight only
        # visits modules already present in self.modules() at call time.
        self._init_weight()

    def forward(self, x, low_level_feat2, low_level_feat3, low_level_feat4):
        """Fuse the four feature maps and return the decoder output.

        Args:
            x: 512-channel feature map (reduced by ``dr5``).
            low_level_feat2: 64-channel feature map; its spatial size becomes
                the output spatial size.
            low_level_feat3: 128-channel feature map.
            low_level_feat4: 256-channel feature map.

        Returns:
            Tensor with ``self.fc`` channels at the spatial size of ``low_level_feat2``.
        """
        x2 = self.dr2(low_level_feat2)
        x3 = self.dr3(low_level_feat3)
        x4 = self.dr4(low_level_feat4)
        x = self.dr5(x)

        # Reduce channels first, then resize: interpolating 64-channel maps is cheaper.
        target_size = x2.size()[2:]
        x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        x3 = F.interpolate(x3, size=target_size, mode='bilinear', align_corners=True)
        x4 = F.interpolate(x4, size=target_size, mode='bilinear', align_corners=True)

        x = torch.cat((x, x2, x3, x4), dim=1)
        # Residual refinement; FAM preserves the (N, 256, H, W) shape.
        x = x + self.FAM(x)
        return self.last_conv(x)

    def _init_weight(self):
        """Re-initialise parameters of all registered submodules.

        Every ``nn.Conv2d`` weight gets Kaiming-normal init (conv biases keep
        their default init); every ``nn.BatchNorm2d`` gets weight 1, bias 0.
        Normalization layers from ``BatchNorm`` that are not ``nn.BatchNorm2d``
        instances are left untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False BN layers have no weight/bias to reset.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
69  
def build_decoderGASN(fc, BatchNorm):
    """Factory returning a freshly constructed ``Decoder_GASN``.

    Args:
        fc: number of output channels of the decoder.
        BatchNorm: normalization layer factory forwarded to the decoder.
    """
    decoder = Decoder_GASN(fc, BatchNorm)
    return decoder
72  
def FAM_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Build a sequential stack of 3x3 conv (+ optional BN) + ReLU stages.

    Args:
        cfg: iterable of output channel counts, one per conv stage.
        in_channels: channel count entering the first conv.
        batch_norm: if true, insert ``nn.BatchNorm2d`` after each conv.
        dilation: if true, use dilation 3 for every conv, otherwise 1.

    Returns:
        ``nn.Sequential`` of [conv, (bn,) relu] per entry of ``cfg``. Padding
        equals the dilation, so with stride 1 the spatial size is preserved.
    """
    widths = list(cfg)
    rate = 3 if dilation else 1
    modules = []
    # Pair each stage's input width with its output width.
    for c_in, c_out in zip([in_channels] + widths, widths):
        modules.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=rate, dilation=rate))
        if batch_norm:
            modules.append(nn.BatchNorm2d(c_out))
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)