attention_segmenter.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import ModuleList

from networks.basic_rnn import BasicRNN
from networks.conv_stack import ConvolutionStack, TransposedConvolutionStack
from networks.gaussian_attention_sampler import GaussianAttentionReader, GaussianAttentionWriter
from visualization.image_visualizer import ImageVisualizer
from image.ptimage import PTImage


class AttentionSegmenter(nn.Module):
    """Recurrent gaussian-attention segmentation network.

    At each timestep the model: decodes gaussian attention parameters from
    the RNN hidden state, reads a glimpse from the input image, encodes the
    glimpse with a conv stack, updates the RNN hidden state from those
    features, decodes a glimpse-sized partial mask with a transposed-conv
    stack, and additively writes the mask back onto a full-resolution canvas.
    """

    def __init__(self, num_classes, inchans=3, att_encoding_size=128, timesteps=10, attn_grid_size=50):
        """
        Args:
            num_classes: number of output mask channels.
            inchans: channel count of the input images.
            att_encoding_size: hidden-state size of the attention RNN.
            timesteps: number of attention read/write steps per forward pass.
            attn_grid_size: side length of the square attention glimpse.
        """
        super(AttentionSegmenter, self).__init__()
        self.num_classes = num_classes
        self.att_encoding_size = att_encoding_size
        self.timesteps = timesteps
        self.attn_grid_size = attn_grid_size

        # Glimpse encoder: interleaved stride-1 / stride-2 3x3 convs,
        # no padding, no final relu.
        self.encoder = ConvolutionStack(inchans, final_relu=False, padding=0)
        self.encoder.append(32, 3, 1)
        self.encoder.append(32, 3, 2)
        self.encoder.append(64, 3, 1)
        self.encoder.append(64, 3, 2)
        self.encoder.append(96, 3, 1)
        self.encoder.append(96, 3, 2)

        # Mask decoder: mirrors the encoder back up to glimpse resolution,
        # ending in num_classes channels.
        self.decoder = TransposedConvolutionStack(96, final_relu=False, padding=0)
        self.decoder.append(96, 3, 2)
        self.decoder.append(64, 3, 1)
        self.decoder.append(64, 3, 2)
        self.decoder.append(32, 3, 1)
        self.decoder.append(32, 3, 2)
        self.decoder.append(self.num_classes, 3, 1)

        self.attn_reader = GaussianAttentionReader()
        self.attn_writer = GaussianAttentionWriter()
        self.att_rnn = BasicRNN(hstate_size=att_encoding_size, output_size=5)
        # Created lazily in init_weights once the flattened hidden-state
        # size is known (it depends on the encoder's output size).
        self.register_parameter('att_decoder_weights', None)

    def init_weights(self, hstate):
        """Lazily create the hidden-state -> attention-parameter weights.

        The linear layer maps the flattened RNN hidden state to the 5
        gaussian attention parameters. Its input width is only known after
        the RNN has been sized by a forward pass, hence the lazy creation.

        Args:
            hstate: current RNN hidden state, first dim is batch size.
        """
        if self.att_decoder_weights is None:
            batch_size = hstate.size(0)
            # 5 outputs: one row per gaussian attention parameter.
            self.att_decoder_weights = nn.Parameter(
                torch.Tensor(5, hstate.nelement() // batch_size))
            # Standard fan-in uniform initialization.
            stdv = 1. / math.sqrt(self.att_decoder_weights.size(1))
            self.att_decoder_weights.data.uniform_(-stdv, stdv)
            if hstate.data.is_cuda:
                self.cuda()

    def forward(self, x):
        """Run `timesteps` attention read/write steps over an image batch.

        Args:
            x: input images of shape (batch, chans, height, width).

        Returns:
            List of timesteps + 1 sigmoided mask canvases, each of shape
            (batch, num_classes, height, width); element 0 is the sigmoid of
            the all-zero initial canvas, element t the canvas after t
            attention writes.
        """
        batch_size, chans, height, width = x.size()

        # The RNN hidden-state size is tied to the encoder's feature size
        # for a glimpse, so probe with a throwaway glimpse-sized pass and
        # then reset the hidden state before the real rollout. zeros (not
        # uninitialized memory) keeps the probe deterministic.
        dummy_glimpse = torch.zeros(batch_size, chans, self.attn_grid_size, self.attn_grid_size)
        if x.is_cuda:
            dummy_glimpse = dummy_glimpse.cuda()
        dummy_feature_map = self.encoder.forward(dummy_glimpse)
        self.att_rnn.forward(dummy_feature_map.view(batch_size, dummy_feature_map.nelement() // batch_size))
        self.att_rnn.reset_hidden_state(batch_size, x.data.is_cuda)

        # Canvas 0 is all zeros; each timestep adds a partial mask to it.
        outputs = []
        init_tensor = torch.zeros(batch_size, self.num_classes, height, width)
        if x.data.is_cuda:
            init_tensor = init_tensor.cuda()
        outputs.append(init_tensor)

        self.init_weights(self.att_rnn.get_hidden_state())

        for t in range(self.timesteps):
            # 1) decode hidden state into gaussian attention parameters,
            #    squashed to (-1, 1) by tanh
            state = self.att_rnn.get_hidden_state()
            gauss_attn_params = torch.tanh(F.linear(state, self.att_decoder_weights))

            # 2) extract an attn_grid_size glimpse from the input
            glimpse = self.attn_reader.forward(x, gauss_attn_params, self.attn_grid_size)

            # visualize first glimpse in batch for all t
            torch_glimpses = torch.chunk(glimpse, batch_size, dim=0)
            ImageVisualizer().set_image(PTImage.from_cwh_torch(torch_glimpses[0].squeeze().data), 'zGlimpse {}'.format(t))

            # 3) encode the glimpse into a feature map
            feature_map = self.encoder.forward(glimpse)
            # Per-layer encoder output sizes, reversed (minus the deepest),
            # tell the transposed-conv decoder the exact sizes to upsample
            # back through; the final target is the glimpse size itself.
            conv_output_dims = self.encoder.get_output_dims()[:-1][::-1]
            conv_output_dims.append(glimpse.size())

            # 4) update hidden state from the flattened feature map
            self.att_rnn.forward(feature_map.view(batch_size, feature_map.nelement() // batch_size))

            # 5) decode the features into a glimpse-sized partial mask
            partial_mask = self.decoder.forward(feature_map, conv_output_dims)

            # 6) write the partial mask additively onto the full-size canvas
            partial_canvas = self.attn_writer.forward(partial_mask, gauss_attn_params, (height, width))
            outputs.append(torch.add(outputs[-1], partial_canvas))

        # return the sigmoided versions of every intermediate canvas
        return [torch.sigmoid(o) for o in outputs]