# encoder.py
"""ResNet encoder backbones (ResNet-18/34/50) with a configurable output stride.

``build_backbone`` is the entry point. Every backbone's ``forward`` returns the
layer4 features followed by the layer1, layer2 and layer3 features.
"""
import sys
# NOTE(review): machine-specific path. It is kept because other modules may rely
# on this import-time side effect; prefer installing the project as a package.
sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    ``padding=dilation`` keeps the spatial size unchanged when ``stride == 1``.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


# Checkpoint URLs consumed by ``ResNet._load_pretrained_model``, keyed by
# architecture name. resnet101 / resnet152 currently have no factory below.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def _build_resnet(arch, block, layers, output_stride, BatchNorm, pretrained, in_c):
    """Construct a ``ResNet`` and optionally load the ``model_urls[arch]`` checkpoint.

    Args:
        arch: key into ``model_urls``.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: per-stage block counts; ``layers[3]`` is not used (see ``ResNet``).
        output_stride: one of 32, 16, 8, 4 (see ``ResNet``).
        BatchNorm: normalization layer class, called as ``BatchNorm(channels)``.
        pretrained: whether to load the checkpoint. It is forced off when
            ``in_c != 3``, presumably because the checkpoint's stem conv is
            3-channel and cannot be loaded into a different ``in_c``.
        in_c: number of input channels of the stem convolution.

    Returns:
        The constructed ``ResNet``.
    """
    model = ResNet(block, layers, output_stride, BatchNorm, in_c=in_c)
    if pretrained and in_c == 3:
        model._load_pretrained_model(model_urls[arch])
    return model


def ResNet34(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """Build a ResNet-34 encoder (``BasicBlock``, ``[3, 4, 6, 3]``).

    ``forward`` returns tensors with 512, 64, 128 and 256 channels, in the
    order (layer4, layer1, layer2, layer3). Arguments are as in ``_build_resnet``.
    """
    return _build_resnet('resnet34', BasicBlock, [3, 4, 6, 3], output_stride,
                         BatchNorm, pretrained, in_c)


def ResNet18(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """Build a ResNet-18 encoder (``BasicBlock``, ``[2, 2, 2, 2]``).

    ``forward`` returns tensors with 512, 64, 128 and 256 channels, in the
    order (layer4, layer1, layer2, layer3). Arguments are as in ``_build_resnet``.

    NOTE(review): layer4 is always built with three multi-grid blocks, so it has
    one more block than the config suggests. The extra block's weights are
    presumably absent from the resnet18 checkpoint, so it keeps its random init.
    """
    return _build_resnet('resnet18', BasicBlock, [2, 2, 2, 2], output_stride,
                         BatchNorm, pretrained, in_c)


def ResNet50(output_stride, BatchNorm=nn.BatchNorm2d, pretrained=True, in_c=3):
    """Build a ResNet-50 encoder (``Bottleneck``, ``[3, 4, 6, 3]``).

    ``forward`` returns tensors with 2048, 256, 512 and 1024 channels, in the
    order (layer4, layer1, layer2, layer3). Arguments are as in ``_build_resnet``.
    """
    return _build_resnet('resnet50', Bottleneck, [3, 4, 6, 3], output_stride,
                         BatchNorm, pretrained, in_c)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convs (``expansion == 1``).

    Only the first conv is strided and dilated. The second uses the
    ``conv3x3`` defaults (stride 1, dilation 1).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn1 = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm(planes)
        # Optional projection applied to the shortcut when shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block; output width is ``planes * expansion``."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(out_planes)
        self.relu = nn.ReLU()
        # Optional projection applied to the shortcut when shape changes.
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += residual
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet encoder with a dilated, multi-grid last stage.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``.
        layers: block counts. Only ``layers[0:3]`` are used; layer4 always has
            three blocks with dilation multipliers ``(1, 2, 4)``.
        output_stride: key into ``_STAGE_CONFIG`` (32, 16, 8 or 4).
        BatchNorm: normalization layer class, called as ``BatchNorm(channels)``.
        pretrained: accepted for backward compatibility; not used here
            (checkpoints are loaded by the factory functions).
        in_c: number of input channels of the stem convolution.

    Raises:
        NotImplementedError: if ``output_stride`` is not supported.

    NOTE(review): the stem conv uses ``stride=1``, so the actual input-to-layer4
    reduction is ``output_stride / 2`` (e.g. 8 for ``output_stride=16``).
    Confirm this is intended.
    """

    # output_stride -> (strides, dilations) for layer1..layer4.
    _STAGE_CONFIG = {
        32: ((1, 2, 2, 2), (1, 1, 1, 1)),
        16: ((1, 2, 2, 1), (1, 1, 1, 2)),
        8: ((1, 2, 1, 1), (1, 1, 2, 4)),
        4: ((1, 1, 1, 1), (1, 2, 4, 8)),
    }

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True, in_c=3):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.in_c = in_c
        # Multi-grid dilation multipliers for the blocks of layer4.
        blocks = [1, 2, 4]
        if output_stride not in self._STAGE_CONFIG:
            raise NotImplementedError(
                f"output_stride must be one of {sorted(self._STAGE_CONFIG)}, got {output_stride!r}")
        strides, dilations = self._STAGE_CONFIG[output_stride]

        # Modules
        self.conv1 = nn.Conv2d(self.in_c, 64, kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self._init_weight()

    def _make_downsample(self, block, planes, stride, BatchNorm):
        """Return a 1x1 conv + norm shortcut projection, or ``None`` when the
        stage's first block keeps both resolution and channel count."""
        out_planes = planes * block.expansion
        if stride == 1 and self.inplanes == out_planes:
            return None
        return nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            BatchNorm(out_planes),
        )

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a stage of ``blocks`` residual blocks; only the first one is strided."""
        downsample = self._make_downsample(block, planes, stride, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)
                      for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        """Build a multi-grid stage: block ``i`` uses dilation ``blocks[i] * dilation``.

        ``blocks`` is a sequence of dilation multipliers, not a block count.
        Only the first block is strided.
        """
        downsample = self._make_downsample(block, planes, stride, BatchNorm)
        layers = [block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
                        downsample=downsample, BatchNorm=BatchNorm)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes, stride=1, dilation=rate * dilation,
                            BatchNorm=BatchNorm)
                      for rate in blocks[1:])
        return nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(layer4_out, layer1_out, layer2_out, layer3_out)``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(input))))
        feat1 = self.layer1(x)
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        out = self.layer4(feat3)
        return out, feat1, feat2, feat3

    def _init_weight(self):
        """He-normal init for convs (std = sqrt(2 / (k_h * k_w * out_channels))).

        ``nn.BatchNorm2d`` layers get weight 1 and bias 0. Other normalization
        classes keep their own default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self, model_path):
        """Download the checkpoint at URL ``model_path`` and copy the entries
        whose keys exist in this model.

        Keys missing from the checkpoint keep their current values, and extra
        checkpoint keys are ignored. Shapes are not checked, so a mismatch
        raises from ``load_state_dict``.
        """
        pretrain_dict = model_zoo.load_url(model_path)
        state_dict = self.state_dict()
        state_dict.update({k: v for k, v in pretrain_dict.items() if k in state_dict})
        self.load_state_dict(state_dict)


def build_backbone(backbone, output_stride, BatchNorm, in_c=3):
    """Build the encoder named ``backbone`` ('resnet50', 'resnet34' or 'resnet18').

    Raises:
        NotImplementedError: for any other backbone name.
    """
    factories = {'resnet50': ResNet50, 'resnet34': ResNet34, 'resnet18': ResNet18}
    factory = factories.get(backbone)
    if factory is None:
        raise NotImplementedError(
            f"unsupported backbone {backbone!r}; choose one of {sorted(factories)}")
    return factory(output_stride, BatchNorm, in_c=in_c)