unet.py
# Full assembly of the sub-parts to form the complete network.
# Adapted from https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

import torch.nn.functional as F
import torch.nn as nn
from networks.unet_parts import up, down, inconv, outconv


# Shallow U-Net: three down-sampling stages mirrored by three up-sampling stages.
class UNet(nn.Module):
    """Shallow U-Net that returns its output together with its bottleneck features.

    Args:
        n_channels: number of channels of the input passed to ``inconv``.
        n_classes: number of output channels requested from the final
            ``outconv`` head.
        base_channels: width ``c`` of the first stage. Encoder widths are
            ``c, 2c, 4c, 4c``. The default of 32 reproduces the original
            hard-coded 32/64/128/128 layout exactly, so submodule names,
            parameter shapes and existing state_dicts are unchanged.
    """

    def __init__(self, n_channels, n_classes, base_channels=32):
        super().__init__()
        c = base_channels

        # Encoder. down3 keeps the width at 4c, so the bottleneck has 4c channels.
        self.inc = inconv(n_channels, c)
        self.down1 = down(c, 2 * c)
        self.down2 = down(2 * c, 4 * c)
        self.down3 = down(4 * c, 4 * c)

        # Decoder. Each up() input width equals the width of the tensor being
        # upsampled plus the width of its skip tensor (e.g. up1: 4c from x4 +
        # 4c from x3 = 8c). This is consistent with up() concatenating the two;
        # see networks.unet_parts.
        self.up1 = up(8 * c, 2 * c)
        self.up2 = up(4 * c, c)
        self.up3 = up(2 * c, c)
        self.outc = outconv(c, n_classes)

        # Width of the bottleneck tensor returned by forward(). It is kept next
        # to the layer definitions so it cannot drift from down3's output width.
        self._feature_channels = 4 * c

    def feature_channels(self):
        """Return the channel count of the bottleneck features returned by forward()."""
        # getattr fallback: whole-model pickles saved before base_channels
        # existed lack this attribute (unpickling skips __init__). Those
        # models always used the 128-wide bottleneck.
        return getattr(self, "_feature_channels", 128)

    def forward(self, x):
        """Run the network.

        Args:
            x: input batch, presumably shaped (N, n_channels, H, W). TODO:
                confirm this, including any divisibility requirement on H and
                W imposed by the three down() stages in networks.unet_parts.

        Returns:
            Tuple ``(out, features)``:
                - ``out`` is the output of the final ``outconv`` head
                  (presumably n_classes channels).
                - ``features`` is the bottleneck tensor produced by ``down3``,
                  with ``feature_channels()`` channels.
        """
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        # Decoder: each up() step is paired with the matching encoder output.
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        return self.outc(x), x4