/ pytlib / networks / unet.py
unet.py
 1  # full assembly of the sub-parts to form the complete net
 2  # code from https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
 3  
 4  import torch.nn.functional as F
 5  import torch.nn as nn
 6  from networks.unet_parts import up, down, inconv, outconv
 7  
 8  # shallow unet 
 9  class UNet(nn.Module):
10      def __init__(self, n_channels, n_classes):
11          super(UNet, self).__init__()
12          self.inc = inconv(n_channels, 32)
13          self.down1 = down(32, 64)
14          self.down2 = down(64, 128)
15          self.down3 = down(128, 128)
16          self.up1 = up(256, 64)
17          self.up2 = up(128, 32)
18          self.up3 = up(64, 32)
19          self.outc = outconv(32, n_classes)
20  
21      def feature_channels(self):
22          return 128
23  
24      def forward(self, x):
25          x1 = self.inc(x)
26          x2 = self.down1(x1)
27          x3 = self.down2(x2)
28          x4 = self.down3(x3)
29          x = self.up1(x4, x3)
30          x = self.up2(x, x2)
31          x = self.up3(x, x1)
32          x = self.outc(x)        
33          return x, x4