"""CDNet_L: siamese change-detection network (GAS-Net variant).

A shared backbone encodes two co-registered images, CBAM attention refines
the multi-scale encoder features of each branch, a shared decoder produces a
per-branch feature map, and a small convolutional head fuses the two branches
into a single-channel per-pixel change probability map.
"""
import os
import sys

# NOTE(review): hard-coded project path and GPU selection are machine-specific
# module-level side effects; prefer setting PYTHONPATH / CUDA_VISIBLE_DEVICES
# in the environment instead. Kept for backward compatibility. The env var is
# set before any CUDA initialization so device masking takes effect.
sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

from gasnet.encoder import build_backbone
from gasnet.decoder import build_decoderGASN
from gasnet.data_utils import get_transform


class ChannelAttention(nn.Module):
    """Channel attention (CBAM).

    Squeezes the spatial dimensions with both average and max pooling,
    passes each descriptor through a shared bottleneck MLP (1x1 convs),
    sums the two outputs, and gates with a sigmoid.

    Args:
        in_planes: number of input channels.
        ratio: channel-reduction factor of the bottleneck MLP.
    """

    def __init__(self, in_planes, ratio=8):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP: in_planes -> in_planes // ratio -> in_planes.
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (B, C, 1, 1) sigmoid attention map for input x."""
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    """Spatial attention (CBAM).

    Reduces the channel dimension with per-pixel mean and max, stacks the
    two maps, and convolves them into a single sigmoid attention map.

    Args:
        kernel_size: convolution kernel size; must be 3 or 7.
    """

    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1  # "same" padding for either size

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (B, 1, H, W) sigmoid attention map for input x."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel then spatial attention."""

    def __init__(self, in_planes, ratio=8, kernel_size=3):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply channel attention, then spatial attention, as residual-free gates."""
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x


class CDNet_L(nn.Module):
    """Siamese change-detection model with CBAM-refined encoder features.

    Args:
        backbone: encoder architecture name passed to ``build_backbone``.
        output_stride: encoder output stride passed to ``build_backbone``.
        f_c: decoder feature-channel count passed to ``build_decoderGASN``.
        freeze_bn: if True, put all BatchNorm2d layers in eval mode at init.
        in_c: number of input image channels.
    """

    def __init__(self, backbone='resnet18', output_stride=16, f_c=64, freeze_bn=False, in_c=3):
        super(CDNet_L, self).__init__()
        BatchNorm = nn.BatchNorm2d

        self.transform = get_transform(convert=True, normalize=True)

        # Encoder and decoder are shared between the two input branches.
        self.backbone = build_backbone(backbone, output_stride, BatchNorm, in_c)
        self.decoder = build_decoderGASN(f_c, BatchNorm)

        # NOTE(review): cbam0/cbam1 are never used in forward(); kept so
        # state_dict keys stay compatible with existing checkpoints.
        self.cbam0 = CBAM(64)
        self.cbam1 = CBAM(64)

        # One attention block per encoder stage; channel counts match the
        # resnet18 stage widths (64/128/256/512).
        self.cbam2 = CBAM(64)
        self.cbam3 = CBAM(128)
        self.cbam4 = CBAM(256)
        self.cbam5 = CBAM(512)

        if freeze_bn:
            self.freeze_bn()

        # Fusion head: upsample 2x, then reduce the concatenated branch
        # features (128 ch) down to a single-channel change map.
        # align_corners=False is the PyTorch default behavior; making it
        # explicit avoids the "align_corners" UserWarning.
        self.conv_final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, padding=0),
        )

    def forward(self, hr_img1, hr_img2):
        """Return a (B, 1, H', W') change-probability map for an image pair.

        Both images pass through the shared encoder/decoder; the decoded
        features are concatenated channel-wise, fused by ``conv_final``,
        and squashed with a sigmoid.
        """
        x_1, f2_1, f3_1, f4_1 = self.backbone(hr_img1)
        x_2, f2_2, f3_2, f4_2 = self.backbone(hr_img2)

        # CBAM blocks are shared across branches, one per feature scale.
        x1 = self.decoder(self.cbam5(x_1), self.cbam2(f2_1), self.cbam3(f3_1), self.cbam4(f4_1))
        x2 = self.decoder(self.cbam5(x_2), self.cbam2(f2_2), self.cbam3(f3_2), self.cbam4(f4_2))

        concat_feature = self.conv_final(torch.cat([x1, x2], dim=1))

        output = torch.sigmoid(concat_feature)
        return output

    def freeze_bn(self):
        """Put every BatchNorm2d in eval mode (use running stats, no updates).

        NOTE(review): a later call to ``model.train()`` flips these layers
        back to training mode; re-call this method after train() if the
        freeze must persist.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()