# pytlib/networks/attention_segmenter.py
  1  from __future__ import division
  2  from builtins import range
  3  from past.utils import old_div
  4  import torch
  5  import math
  6  import torch.nn as nn
  7  import torch.nn.functional as F
  8  from torch.nn import ModuleList
  9  from networks.basic_rnn import BasicRNN
 10  from networks.conv_stack import ConvolutionStack,TransposedConvolutionStack
 11  from networks.gaussian_attention_sampler import GaussianAttentionReader,GaussianAttentionWriter
 12  from visualization.image_visualizer import ImageVisualizer
 13  from image.ptimage import PTImage
 14  
 15  class AttentionSegmenter(nn.Module):
 16      def __init__(self,num_classes,inchans=3,att_encoding_size=128,timesteps=10,attn_grid_size=50):
 17          super(AttentionSegmenter, self).__init__()
 18          self.num_classes = num_classes
 19          self.att_encoding_size = att_encoding_size
 20          self.timesteps = timesteps
 21          self.attn_grid_size = attn_grid_size
 22          self.encoder = ConvolutionStack(inchans,final_relu=False,padding=0)
 23          self.encoder.append(32,3,1)
 24          self.encoder.append(32,3,2)
 25          self.encoder.append(64,3,1)
 26          self.encoder.append(64,3,2)
 27          self.encoder.append(96,3,1)
 28          self.encoder.append(96,3,2)
 29  
 30          self.decoder = TransposedConvolutionStack(96,final_relu=False,padding=0)
 31          self.decoder.append(96,3,2)
 32          self.decoder.append(64,3,1)
 33          self.decoder.append(64,3,2)
 34          self.decoder.append(32,3,1)
 35          self.decoder.append(32,3,2)
 36          self.decoder.append(self.num_classes,3,1)
 37  
 38          self.attn_reader = GaussianAttentionReader()
 39          self.attn_writer = GaussianAttentionWriter()
 40          self.att_rnn = BasicRNN(hstate_size=att_encoding_size,output_size=5)
 41          self.register_parameter('att_decoder_weights', None)
 42  
 43      def init_weights(self,hstate):
 44          if self.att_decoder_weights is None:
 45              batch_size = hstate.size(0)
 46              self.att_decoder_weights = nn.Parameter(torch.Tensor(5,old_div(hstate.nelement(),batch_size)))
 47              stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
 48              self.att_decoder_weights.data.uniform_(-stdv, stdv)
 49          if hstate.data.is_cuda:
 50              self.cuda()
 51  
 52      def forward(self, x):
 53          batch_size,chans,height,width = x.size()
 54  
 55          # need to first determine the hidden state size, which is tied to the cnn feature size
 56          dummy_glimpse = torch.Tensor(batch_size,chans,self.attn_grid_size,self.attn_grid_size)
 57          if x.is_cuda:
 58              dummy_glimpse = dummy_glimpse.cuda()
 59          dummy_feature_map = self.encoder.forward(dummy_glimpse)
 60          self.att_rnn.forward(dummy_feature_map.view(batch_size,old_div(dummy_feature_map.nelement(),batch_size)))
 61          self.att_rnn.reset_hidden_state(batch_size,x.data.is_cuda)
 62  
 63          outputs = []
 64          init_tensor = torch.zeros(batch_size,self.num_classes,height,width)
 65          if x.data.is_cuda:
 66              init_tensor = init_tensor.cuda()
 67          outputs.append(init_tensor) 
 68  
 69          self.init_weights(self.att_rnn.get_hidden_state())
 70  
 71          for t in range(self.timesteps):
 72              # 1) decode hidden state to generate gaussian attention parameters
 73              state = self.att_rnn.get_hidden_state()
 74              gauss_attn_params = torch.tanh(F.linear(state,self.att_decoder_weights))
 75  
 76              # 2) extract glimpse
 77              glimpse = self.attn_reader.forward(x,gauss_attn_params,self.attn_grid_size)
 78  
 79              # visualize first glimpse in batch for all t
 80              torch_glimpses = torch.chunk(glimpse,batch_size,dim=0)
 81              ImageVisualizer().set_image(PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data),'zGlimpse {}'.format(t))            
 82  
 83              # 3) use conv stack or resnet to extract features
 84              feature_map = self.encoder.forward(glimpse)
 85              conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
 86              conv_output_dims.append(glimpse.size())
 87              # import ipdb;ipdb.set_trace()
 88  
 89              # 4) update hidden state # think about this connection a bit more
 90              self.att_rnn.forward(feature_map.view(batch_size,old_div(feature_map.nelement(),batch_size)))
 91  
 92              # 5) use deconv network to get partial masks
 93              partial_mask = self.decoder.forward(feature_map,conv_output_dims)
 94  
 95              # 6) write masks additively to mask canvas
 96              partial_canvas = self.attn_writer.forward(partial_mask,gauss_attn_params,(height,width))
 97              outputs.append(torch.add(outputs[-1],partial_canvas))
 98  
 99                  # return the sigmoided versions
100          for i in range(len(outputs)):
101              outputs[i] = torch.sigmoid(outputs[i])
102          return outputs