# pytlib/networks/conv_stack.py
 1  import torch.nn as nn
 2  import torch.nn.functional as F
 3  from torch.nn import ModuleList
 4  
 5  class ConvolutionStack(nn.Module):
 6      def __init__(self,in_chans,final_relu=True,padding=1):
 7          super(ConvolutionStack, self).__init__()
 8          self.convs = ModuleList()
 9          self.batchnorms = ModuleList()
10          self.in_chans = in_chans
11          self.final_relu = final_relu
12          self.padding = padding
13  
14      def append(self,out_chans,filter_size,stride):
15          if len(self.convs)==0:
16              self.convs.append(nn.Conv2d(self.in_chans, out_chans, filter_size, stride=stride, padding=self.padding))
17          else:
18              self.convs.append(nn.Conv2d(self.convs[-1].out_channels, out_chans, filter_size, stride=stride, padding=self.padding))
19          self.batchnorms.append(nn.BatchNorm2d(out_chans))
20  
21      def get_output_dims(self):
22          return self.output_dims
23  
24      def forward(self, x):
25          self.output_dims = []
26  
27          for i,c in enumerate(self.convs):
28              # lrelu = nn.LeakyReLU(0.2,inplace=True)
29              # x = lrelu(c(x))
30              x = c(x)
31              x = self.batchnorms[i](x)
32              if i<len(self.convs)-1 or self.final_relu:
33                  x = F.relu(x)
34              self.output_dims.append(x.size())
35          return x
36  
class TransposedConvolutionStack(nn.Module):
    """A stack of ConvTranspose2d -> BatchNorm2d -> ReLU layers built incrementally.

    Mirrors ConvolutionStack but upsamples. Layers are added with
    :meth:`append`; forward() optionally takes explicit per-layer output
    sizes to resolve ConvTranspose2d output ambiguity.
    """

    def __init__(self, in_chans, final_relu=True, padding=1):
        super(TransposedConvolutionStack, self).__init__()
        self.convs = ModuleList()
        self.batchnorms = ModuleList()
        self.in_chans = in_chans
        self.output_dims = []
        self.final_relu = final_relu
        self.padding = padding

    def append(self, out_chans, filter_size, stride):
        """Append a ConvTranspose2d + BatchNorm2d pair, chaining channel counts."""
        in_chans = self.in_chans if len(self.convs) == 0 else self.convs[-1].out_channels
        self.convs.append(nn.ConvTranspose2d(in_chans, out_chans, filter_size, stride=stride, padding=self.padding))
        self.batchnorms.append(nn.BatchNorm2d(out_chans))

    def forward(self, x, output_dims=None):
        """Run the stack on x.

        Args:
            x: input tensor of shape (N, C, H, W).
            output_dims: optional list of target output sizes, one per layer,
                passed to each ConvTranspose2d as ``output_size``. When omitted
                (or empty), each layer targets spatial dims scaled by its stride.

        Note: the default was previously a mutable ``[]`` — replaced with None
        to avoid shared-state bugs; falsy values behave identically.
        """
        if output_dims:
            assert len(output_dims) == len(self.convs), "number of output_dims must match number of tconvs!"
        for i, c in enumerate(self.convs):
            if not output_dims:
                # Default target: multiply spatial dims by this layer's stride.
                output_dim = [x.shape[2] * c.stride[0], x.shape[3] * c.stride[1]]
            else:
                output_dim = output_dims[i]
            x = c(x, output_size=output_dim)
            x = self.batchnorms[i](x)
            # ReLU after every layer except possibly the last (controlled by final_relu).
            if i < len(self.convs) - 1 or self.final_relu:
                x = F.relu(x)
        return x