# gasnet/CDNet_L.py
import sys

# NOTE(review): hard-coded, machine-specific project root added as an import
# side effect — breaks on any other machine. Prefer packaging or PYTHONPATH;
# kept as-is here to preserve behaviour.
sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')

import torch
import torch.nn as nn
import torch.nn.functional as F
from gasnet.encoder import build_backbone
from gasnet.decoder import build_decoderGASN
from gasnet.data_utils import get_transform
import os
# NOTE(review): pins the whole process to GPU index 1 at import time;
# device selection arguably belongs in the launch script, not the model module.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
 12  
 13  class ChannelAttention(nn.Module):
 14      def __init__(self, in_planes, ratio=8):
 15          super(ChannelAttention, self).__init__()
 16          self.avg_pool = nn.AdaptiveAvgPool2d(1)
 17          self.max_pool = nn.AdaptiveMaxPool2d(1)
 18  
 19          self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
 20          self.relu1 = nn.ReLU()
 21          self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
 22          self.sigmoid = nn.Sigmoid()
 23  
 24      def forward(self, x):
 25          avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
 26          max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
 27          out = avg_out + max_out
 28          return self.sigmoid(out)
 29  
 30  class SpatialAttention(nn.Module):
 31      def __init__(self, kernel_size=3):
 32          super(SpatialAttention, self).__init__()
 33  
 34          assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
 35          padding = 3 if kernel_size == 7 else 1
 36  
 37          self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
 38          self.sigmoid = nn.Sigmoid()
 39  
 40      def forward(self, x):
 41          avg_out = torch.mean(x, dim=1, keepdim=True)
 42          max_out, _ = torch.max(x, dim=1, keepdim=True)
 43          x = torch.cat([avg_out, max_out], dim=1)
 44          x = self.conv1(x)
 45          return self.sigmoid(x)
 46  
 47  class CBAM(nn.Module):
 48      def __init__(self, in_planes, ratio=8, kernel_size=3):
 49          super(CBAM, self).__init__()
 50          self.ca = ChannelAttention(in_planes, ratio)
 51          self.sa = SpatialAttention(kernel_size)
 52      def forward(self, x):
 53          x = self.ca(x) * x
 54          x = self.sa(x) * x
 55          return x
 56  
class CDNet_L(nn.Module):
    """Siamese change-detection network.

    Both input images pass through a shared backbone and a shared decoder
    (each feature level gated by a CBAM block); the two decoded feature maps
    are concatenated and reduced by a small conv head to a single-channel
    sigmoid change map.
    """

    def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
        # backbone: backbone architecture name passed to build_backbone.
        # output_stride: backbone output stride passed to build_backbone.
        # f_c: decoder feature-channel count passed to build_decoderGASN.
        # freeze_bn: if True, put all BatchNorm2d layers in eval mode.
        # in_c: number of input image channels.
        super(CDNet_L, self).__init__()
        BatchNorm = nn.BatchNorm2d

        # Preprocessing transform (conversion + normalisation) from data_utils;
        # not applied inside forward(), so presumably used by callers — confirm.
        self.transform = get_transform(convert=True, normalize=True)

        self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
        self.decoder = build_decoderGASN(f_c, BatchNorm)

        # NOTE(review): cbam0 and cbam1 are registered here but never used in
        # forward(); they still add parameters to the state_dict. Removing them
        # would break existing checkpoints — confirm before cleaning up.
        self.cbam0 = CBAM(64)
        self.cbam1 = CBAM(64)

        # CBAM gates for the four backbone feature levels used by the decoder.
        # Channel counts (64/128/256/512) assume resnet18/34-style widths —
        # TODO confirm for other backbones.
        self.cbam2 = CBAM(64)
        self.cbam3 = CBAM(128)
        self.cbam4 = CBAM(256)
        self.cbam5 = CBAM(512)

        if freeze_bn:
            self.freeze_bn()

        # Fusion head: upsample 2x, then 128 -> 64 -> 32 -> 1 channels.
        # Input is the channel-wise concatenation of the two decoder outputs,
        # so the decoder output is presumably 64 channels each — confirm.
        self.conv_final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, padding=0),
        )

    def forward(self, hr_img1, hr_img2):
        """Return a single-channel sigmoid change map for the image pair."""
        # Shared-weight (siamese) feature extraction for both time points.
        x_1, f2_1, f3_1, f4_1 = self.backbone(hr_img1)
        x_2, f2_2, f3_2, f4_2 = self.backbone(hr_img2)

        # Each feature level is gated by the CBAM matching its channel count
        # before entering the shared decoder.
        x1 = self.decoder(self.cbam5(x_1), self.cbam2(f2_1), self.cbam3(f3_1), self.cbam4(f4_1))
        x2 = self.decoder(self.cbam5(x_2), self.cbam2(f2_2), self.cbam3(f3_2), self.cbam4(f4_2))

        # Fuse the two branches by concatenation and reduce to one channel.
        concat_feature = self.conv_final(torch.cat([x1, x2], dim=1))

        # Sigmoid output in (0, 1); the training loss is presumably BCE
        # (not BCEWithLogits) — confirm against the training script.
        output = torch.sigmoid(concat_feature)
        return output

    def freeze_bn(self):
        """Put every BatchNorm2d layer in eval mode (use running statistics).

        NOTE(review): this does not set requires_grad=False, so BN affine
        parameters can still be updated by the optimizer — confirm intent.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()