/ gasnet / encoder.py
encoder.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
  3  import torch.nn as nn
  4  import torch.utils.model_zoo as model_zoo
  5  import math
  6  
  7  
  8  def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  9      return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
 10                       padding=dilation, groups=groups, bias=False, dilation=dilation)
 11  
 12  model_urls = {
 13      'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 14      'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 15      'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 16      'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 17      'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
 18  }
 19  
 20  def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
 21      """
 22      output, low_level_feat:
 23      512, 64
 24      """
 25      print(in_c)
 26      model = ResNet(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
 27      if in_c != 3:
 28          pretrained = False
 29      if pretrained:
 30          model._load_pretrained_model(model_urls['resnet34'])
 31      return model
 32  
 33  def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
 34      """
 35      output, low_level_feat:
 36      512, 256, 128, 64, 64
 37      """
 38      model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, in_c=in_c)
 39      if in_c !=3:
 40          pretrained=False
 41      if pretrained:
 42          model._load_pretrained_model(model_urls['resnet18'])
 43      return model
 44  
 45  def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
 46      """
 47      output, low_level_feat:
 48      2048, 256
 49      """
 50      model = ResNet(Bottleneck, [3, 4, 6, 3], output_stride, BatchNorm, in_c=in_c)
 51      if in_c !=3:
 52          pretrained=False
 53      if pretrained:
 54          model._load_pretrained_model(model_urls['resnet50'])
 55      return model
 56  
 57  
 58  class BasicBlock(nn.Module):
 59      expansion = 1
 60  
 61      def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
 62          super(BasicBlock, self).__init__()
 63  
 64          self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
 65                                 dilation=dilation, padding=dilation, bias=False)
 66          self.bn1 = BatchNorm(planes)
 67          self.relu = nn.ReLU(inplace=True)
 68          # self.do1 = nn.Dropout2d(p=0.2)
 69  
 70          self.conv2 = conv3x3(planes, planes)
 71          self.bn2 = BatchNorm(planes)
 72          self.downsample = downsample
 73          self.stride = stride
 74  
 75      def forward(self, x):
 76          identity = x
 77  
 78          out = self.conv1(x)
 79          out = self.bn1(out)
 80          out = self.relu(out)
 81          # out = self.do1(out)
 82  
 83          out = self.conv2(out)
 84          out = self.bn2(out)
 85  
 86          if self.downsample is not None:
 87              identity = self.downsample(x)
 88  
 89          out += identity
 90          out = self.relu(out)
 91  
 92          return out
 93  
 94  
 95  class Bottleneck(nn.Module):
 96      expansion = 4
 97  
 98      def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
 99          super(Bottleneck, self).__init__()
100          self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
101          self.bn1 = BatchNorm(planes)
102          self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
103                                 dilation=dilation, padding=dilation, bias=False)
104          self.bn2 = BatchNorm(planes)
105          self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
106          self.bn3 = BatchNorm(planes * 4)
107          self.relu = nn.ReLU()
108          self.downsample = downsample
109          self.stride = stride
110          self.dilation = dilation
111  
112      def forward(self, x):
113          residual = x
114  
115          out = self.conv1(x)
116          out = self.bn1(out)
117          out = self.relu(out)
118  
119          out = self.conv2(out)
120          out = self.bn2(out)
121          out = self.relu(out)
122  
123          out = self.conv3(out)
124          out = self.bn3(out)
125  
126          if self.downsample is not None:
127              residual = self.downsample(x)
128  
129          out += residual
130          out = self.relu(out)
131  
132          return out
133  
class ResNet(nn.Module):
    """Dilated ResNet encoder returning the deepest feature map plus three
    intermediate ones.

    The stem convolution uses stride 1, so the only stem downsampling is the
    stride-2 max-pool. The deepest feature map is therefore at
    ``1 / (output_stride / 2)`` of the input resolution (e.g. 1/16 for
    ``output_stride=32``).

    ``layer4`` is always a multi-grid unit of three blocks with dilation
    multipliers ``(1, 2, 4)``; the fourth entry of ``layers`` is not used.

    Args:
        block: Residual block class exposing an ``expansion`` attribute and the
            constructor signature
            ``(inplanes, planes, stride, dilation, downsample, BatchNorm)``.
        layers: Block counts per stage; only ``layers[0:3]`` are used.
        output_stride: 32, 16, 8 or 4; selects per-stage strides/dilations.
        BatchNorm: Normalization layer constructor taking a channel count.
        pretrained: Unused; kept for backward compatibility. Weight loading is
            done through ``_load_pretrained_model``.
        in_c: Number of input channels of the stem convolution.

    Raises:
        NotImplementedError: If ``output_stride`` is not supported.
    """

    # output_stride -> (strides, dilations) for layer1..layer4.
    _STAGE_CONFIG = {
        32: ((1, 2, 2, 2), (1, 1, 1, 1)),
        16: ((1, 2, 2, 1), (1, 1, 1, 2)),
        8: ((1, 2, 1, 1), (1, 1, 2, 4)),
        4: ((1, 1, 1, 1), (1, 2, 4, 8)),
    }
    # Dilation multipliers of the three blocks of the multi-grid layer4.
    _MG_BLOCKS = (1, 2, 4)

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
        # Initialise nn.Module before assigning any attributes.
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.in_c = in_c

        try:
            strides, dilations = self._STAGE_CONFIG[output_stride]
        except (KeyError, TypeError):
            raise NotImplementedError(
                f"Unsupported output_stride={output_stride!r}; "
                f"expected one of {sorted(self._STAGE_CONFIG)}"
            ) from None

        # Stem: stride-1 7x7 conv, so spatial downsampling here comes only
        # from the max-pool.
        self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=self._MG_BLOCKS, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()

    def _make_downsample(self, planes, stride, expansion, BatchNorm):
        """Return a 1x1 conv + norm shortcut projection, or ``None`` when the
        input can be added to the block output unchanged."""
        out_planes = planes * expansion
        if stride == 1 and self.inplanes == out_planes:
            return None
        return nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            BatchNorm(out_planes),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a stage of ``blocks`` blocks sharing one dilation.

        Only the first block applies ``stride`` and the shortcut projection.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        downsample = self._make_downsample(planes, stride, block.expansion, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a multi-grid stage with one block per entry of ``blocks``.

        Block ``i`` uses dilation ``blocks[i] * dilation``. Only the first block
        applies ``stride`` and the shortcut projection. ``self.inplanes`` is
        updated to the stage's output channel count.
        """
        downsample = self._make_downsample(planes, stride, block.expansion, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
                        downsample=downsample, BatchNorm=BatchNorm)]
        self.inplanes = planes * block.expansion
        for rate in blocks[1:]:
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=rate * dilation, BatchNorm=BatchNorm))
        return nn.Sequential(*layers)

    def forward(self, input):
        """Run the encoder.

        Returns:
            ``(x, low_level_feat2, low_level_feat3, low_level_feat4)``: the
            outputs of ``layer4``, ``layer1``, ``layer2`` and ``layer3``.
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        low_level_feat2 = x
        x = self.layer2(x)
        low_level_feat3 = x
        x = self.layer3(x)
        low_level_feat4 = x
        x = self.layer4(x)
        return x, low_level_feat2, low_level_feat3, low_level_feat4

    def _init_weight(self):
        """Initialise conv weights with a fan-out He normal and batch-norm
        layers to identity (weight 1, bias 0)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.modules.batchnorm._BatchNorm):
                # Non-affine norms have weight/bias set to None.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _load_pretrained_model(self, model_path):
        """Copy matching tensors from the checkpoint at ``model_path`` (a URL).

        Checkpoint entries whose key is absent from this model, or whose shape
        differs from the model's tensor, are skipped. Every other model tensor
        keeps its current value.
        """
        pretrain_dict = model_zoo.load_url(model_path)
        state_dict = self.state_dict()
        model_dict = {
            k: v for k, v in pretrain_dict.items()
            if k in state_dict and v.shape == state_dict[k].shape
        }
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
239  
def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
    """Construct the named ResNet encoder.

    Args:
        backbone: ``'resnet50'``, ``'resnet34'`` or ``'resnet18'``.
        output_stride: Forwarded to the factory function.
        BatchNorm: Normalization layer constructor, forwarded to the factory.
        in_c: Number of input channels, forwarded to the factory. The factory's
            own ``pretrained`` default is used.

    Raises:
        NotImplementedError: For any other ``backbone`` name.
    """
    if backbone == 'resnet50':
        return ResNet50(output_stride, BatchNorm, in_c=in_c)
    if backbone == 'resnet34':
        return ResNet34(output_stride, BatchNorm, in_c=in_c)
    if backbone == 'resnet18':
        return ResNet18(output_stride, BatchNorm, in_c=in_c)
    raise NotImplementedError(
        "Unsupported backbone {!r}; expected 'resnet18', 'resnet34' or 'resnet50'".format(backbone)
    )
249