conv_stack.py
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleList


class ConvolutionStack(nn.Module):
    """A stack of ``Conv2d -> BatchNorm2d -> ReLU`` layers, built up via :meth:`append`.

    Args:
        in_chans: number of channels fed to the first convolution.
        final_relu: if False, the ReLU after the last layer is skipped.
        padding: ``padding`` passed to every ``nn.Conv2d`` in the stack.
    """

    def __init__(self, in_chans, final_relu=True, padding=1):
        super(ConvolutionStack, self).__init__()
        self.convs = ModuleList()
        self.batchnorms = ModuleList()
        self.in_chans = in_chans
        self.final_relu = final_relu
        self.padding = padding
        # Per-layer output sizes recorded by the most recent forward(). Initialised
        # here so get_output_dims() is safe to call before the first forward().
        self.output_dims = []

    def append(self, out_chans, filter_size, stride):
        """Add one conv + batch-norm layer that produces ``out_chans`` channels.

        The new layer's input channel count is chained from the previous layer's
        ``out_channels``, or from ``in_chans`` for the first layer.
        """
        prev_chans = self.convs[-1].out_channels if len(self.convs) else self.in_chans
        self.convs.append(nn.Conv2d(prev_chans, out_chans, filter_size,
                                    stride=stride, padding=self.padding))
        self.batchnorms.append(nn.BatchNorm2d(out_chans))

    def get_output_dims(self):
        """Return the per-layer output sizes recorded by the last forward() call.

        Returns an empty list if forward() has not been run yet.
        """
        return self.output_dims

    def forward(self, x):
        """Run ``x`` through every layer and return the final activation.

        Each layer's output ``size()`` (taken after the optional ReLU) is appended
        to ``self.output_dims``, which is reset at the start of every call.
        """
        self.output_dims = []
        last = len(self.convs) - 1
        for i, (conv, bn) in enumerate(zip(self.convs, self.batchnorms)):
            x = bn(conv(x))
            # ReLU on every layer except possibly the last one.
            if i < last or self.final_relu:
                x = F.relu(x)
            self.output_dims.append(x.size())
        return x


class TransposedConvolutionStack(nn.Module):
    """A stack of ``ConvTranspose2d -> BatchNorm2d -> ReLU`` layers, built up via :meth:`append`.

    Args:
        in_chans: number of channels fed to the first transposed convolution.
        final_relu: if False, the ReLU after the last layer is skipped.
        padding: ``padding`` passed to every ``nn.ConvTranspose2d`` in the stack.
    """

    def __init__(self, in_chans, final_relu=True, padding=1):
        super(TransposedConvolutionStack, self).__init__()
        self.convs = ModuleList()
        self.batchnorms = ModuleList()
        self.in_chans = in_chans
        # Kept for attribute compatibility; forward() does not populate it.
        self.output_dims = []
        self.final_relu = final_relu
        self.padding = padding

    def append(self, out_chans, filter_size, stride):
        """Add one transposed-conv + batch-norm layer producing ``out_chans`` channels.

        The new layer's input channel count is chained from the previous layer's
        ``out_channels``, or from ``in_chans`` for the first layer.
        """
        prev_chans = self.convs[-1].out_channels if len(self.convs) else self.in_chans
        self.convs.append(nn.ConvTranspose2d(prev_chans, out_chans, filter_size,
                                             stride=stride, padding=self.padding))
        self.batchnorms.append(nn.BatchNorm2d(out_chans))

    def forward(self, x, output_dims=None):
        """Run ``x`` through every layer and return the final activation.

        Args:
            x: input tensor; its last two dimensions are used as the spatial size.
            output_dims: optional sequence with one target size per layer, passed
                as ``output_size`` to that layer's ``ConvTranspose2d``. When None or
                empty, each layer targets its input's spatial size multiplied by
                the layer's stride. PyTorch rejects target sizes that the layer's
                kernel/stride/padding cannot produce.

        Raises:
            AssertionError: if ``output_dims`` is non-empty and its length differs
                from the number of layers.
        """
        if output_dims:
            assert len(output_dims) == len(self.convs), \
                "number of output_dims must match number of tconvs!"
        last = len(self.convs) - 1
        for i, (tconv, bn) in enumerate(zip(self.convs, self.batchnorms)):
            if output_dims:
                target = output_dims[i]
            else:
                # Default: upsample each spatial dim by exactly the layer's stride.
                target = [x.shape[-2] * tconv.stride[0], x.shape[-1] * tconv.stride[1]]
            x = bn(tconv(x, output_size=target))
            # ReLU on every layer except possibly the last one.
            if i < last or self.final_relu:
                x = F.relu(x)
        return x