# draw.py
1 from __future__ import division 2 from builtins import range 3 from past.utils import old_div 4 from networks.basic_rnn import BasicRNN 5 import torch.nn as nn 6 import torch.nn.functional as F 7 import torch 8 import math 9 10 class DRAW(nn.Module): 11 def __init__(self,q_size=10, 12 encoding_size=128, 13 timesteps=10, 14 training=True, 15 use_attention=False, 16 grid_size=5): 17 super(DRAW, self).__init__() 18 self.training = training 19 self.encoding_size = encoding_size 20 self.q_size = q_size 21 self.use_attention = use_attention 22 self.timesteps = timesteps 23 # use equal encoding and decoding size 24 self.encoder_rnn = BasicRNN(hstate_size=self.encoding_size,output_size=self.encoding_size) 25 self.decoder_rnn = BasicRNN(hstate_size=self.encoding_size,output_size=self.encoding_size) 26 self.register_parameter('decoder_linear_weights', None) 27 self.register_parameter('encoding_mu_weights', None) 28 self.register_parameter('encoding_logvar_weights', None) 29 self.filter_linear_layer = nn.Linear(self.encoding_size,5) 30 self.grid_size = grid_size 31 self.minclamp = 1e-8 32 self.maxclamp = 1e8 33 34 def initialize(self,x): 35 batch_size = x.size(0) 36 # we use attention, the decoder producers a patch of grid_size x grid_size 37 # else it produces an output of the original image size 38 if self.use_attention: 39 self.decoder_linear_weights = nn.Parameter(torch.Tensor(self.grid_size*self.grid_size,self.encoding_size)) 40 else: 41 self.decoder_linear_weights = nn.Parameter(torch.Tensor(old_div(x.nelement(),batch_size),self.encoding_size)) 42 43 stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1)) 44 self.decoder_linear_weights.data.uniform_(-stdv, stdv) 45 46 self.encoding_mu_weights = nn.Parameter(torch.Tensor(self.q_size,self.encoding_size)) 47 stdv = 1. 
/ math.sqrt(self.encoding_mu_weights.size(1)) 48 self.encoding_mu_weights.data.uniform_(-stdv, stdv) 49 50 self.encoding_logvar_weights = nn.Parameter(torch.Tensor(self.q_size,self.encoding_size)) 51 stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1)) 52 self.encoding_logvar_weights.data.uniform_(-stdv, stdv) 53 if x.data.is_cuda: 54 self.cuda() 55 56 # selects where to sample from the input image, no attention version 57 # dims is 2*W*H 58 def read(self,x,x_hat,dec_state): 59 return torch.cat((x,x_hat),1) 60 61 # generate two sets of filterbanks 62 # 1) batch x N x W (Fx) 63 # 2) batch x N x H (Fy) 64 def generate_filter_matrices(self,gx,gy,sigma2,delta): 65 N = self.grid_size 66 grid_points = torch.arange(0,N).view((1,N,1)) 67 a = torch.arange(0,self.image_w).view((1,1,-1)) 68 b = torch.arange(0,self.image_h).view((1,1,-1)) 69 if gx.data.is_cuda: 70 grid_points = grid_points.cuda() 71 a = a.cuda() 72 b = b.cuda() 73 74 # gx is Bx1, grid is (1xNx1), so this is a broadcast op -> BxNx1 75 mux = gx.view((-1,1,1)) + (grid_points.float() - old_div(N,2) - 0.5) * delta.view((-1,1,1)) 76 muy = gy.view((-1,1,1)) + (grid_points.float() - old_div(N,2) - 0.5) * delta.view((-1,1,1)) 77 78 s2 = sigma2.view((-1,1,1)) 79 fx = torch.exp(old_div(-(a.float()-mux).pow(2),(2*s2))) 80 fy = torch.exp(old_div(-(b.float()-muy).pow(2),(2*s2))) 81 # normalize 82 fx = old_div(fx,torch.clamp(torch.sum(fx,2,keepdim=True),self.minclamp,self.maxclamp)) 83 fy = old_div(fy,torch.clamp(torch.sum(fy,2,keepdim=True),self.minclamp,self.maxclamp)) 84 return fx,fy 85 86 def generate_filter_params(self,state): 87 filter_vector = self.filter_linear_layer(state) 88 _gx,_gy,log_sigma2,log_delta,loggamma = filter_vector.split(1,1) 89 gx=old_div((self.image_w+1),2)*(_gx+1) 90 gy=old_div((self.image_h+1),2)*(_gy+1) 91 sigma2=torch.exp(log_sigma2) 92 delta=old_div((max(self.image_w,self.image_h)-1),(self.grid_size-1))*torch.exp(log_delta) 93 gamma=torch.exp(loggamma) 94 return gx,gy,sigma2,delta,gamma 
95 96 def read_w_att(self,x,x_hat,dec_state): 97 batch_size = x.size()[0] 98 99 # 1) linear to convert dec_state into batchx5 params gx,gy,logsigma2,logdelta,loggamma 100 # 2) convert to gaussian parameters 101 gx,gy,sigma2,delta,gamma = self.generate_filter_params(dec_state) 102 103 # 3) generate filter matrices 104 fx,fy = self.generate_filter_matrices(gx,gy,sigma2,delta) 105 106 # 4) apply filter matrices to get glimpses 107 output = gamma.view(-1,1,1)*torch.bmm(torch.bmm(fy,x.view(batch_size,self.image_h,self.image_w)),torch.transpose(fx,1,2)) 108 output_hat = gamma.view(-1,1,1)*torch.bmm(torch.bmm(fy,x_hat.view(batch_size,self.image_h,self.image_w)),torch.transpose(fx,1,2)) 109 output_total = torch.cat((output.view(batch_size,self.grid_size*self.grid_size),output_hat.view(batch_size,self.grid_size*self.grid_size)),1) 110 return output_total 111 112 # write takes use from "encoding space" to image space 113 def write(self,decoding): 114 return F.linear(decoding,self.decoder_linear_weights) 115 116 def write_w_att(self,decoding): 117 batch_size = decoding.size()[0] 118 write_patch = F.linear(decoding,self.decoder_linear_weights).view(batch_size,self.grid_size,self.grid_size) 119 gx,gy,sigma2,gamma,delta = self.generate_filter_params(decoding) 120 fx,fy = self.generate_filter_matrices(gx,gy,sigma2,delta) 121 output = (old_div(1,gamma)).view(-1,1,1)*torch.bmm(torch.bmm(fy.transpose(1,2),write_patch),fx) 122 return output 123 124 # this converts the encoding into both a mu and logvar vector 125 def sampleZ(self,encoding): 126 mu = F.linear(encoding,self.encoding_mu_weights) 127 logvar = F.linear(encoding,self.encoding_logvar_weights) 128 return self.reparameterize(mu, logvar),mu,logvar 129 130 def reparameterize(self, mu, logvar): 131 if self.training: 132 std = logvar.mul(0.5).exp_() 133 eps = std.data.new(std.size()).normal_() 134 return eps.mul(std).add_(mu) 135 else: 136 return mu 137 138 # takes an input, returns the sequence of outputs, mus, and logvars 139 
def forward(self,x): 140 # flatten x to 1-d, except for batch dimension 141 xview = x.view(x.size()[0],old_div(x.nelement(),x.size()[0])) 142 # assume bchw dims 143 self.image_w = x.size(3) 144 self.image_h = x.size(2) 145 batch_size = x.size()[0] 146 147 if self.decoder_linear_weights is None: 148 self.initialize(xview) 149 150 # zero out initial states 151 self.encoder_rnn.reset_hidden_state(batch_size,x.data.is_cuda) 152 self.decoder_rnn.reset_hidden_state(batch_size,x.data.is_cuda) 153 outputs,mus,logvars = [],[],[] 154 155 init_tensor = torch.zeros(x.size()) 156 if x.data.is_cuda: 157 init_tensor = init_tensor.cuda() 158 outputs.append(init_tensor) 159 160 if self.use_attention: 161 read_fn = self.read_w_att 162 write_fn = self.write_w_att 163 else: 164 read_fn = self.read 165 write_fn = self.write 166 167 for t in range(0,self.timesteps): 168 # import ipdb;ipdb.set_trace() 169 # Step 1: diff the input against the prev output 170 x_hat = xview - torch.sigmoid(outputs[t].view(xview.size())) 171 # Step 2: read 172 rvec = read_fn(xview,x_hat,self.decoder_rnn.get_hidden_state()) 173 # Step 3: encoder rnn 174 # note the dimensions of r doesn't have to match with the decoding size because 175 # we are just concating 2 dim-1 tensors, which is kind of wierd, but ok... 176 cat = torch.cat((rvec,self.decoder_rnn.get_hidden_state().view(batch_size,self.encoding_size)),1) 177 encoding = self.encoder_rnn.forward(cat) 178 # Step 4: sample z 179 z,mu,logvar = self.sampleZ(encoding) 180 # store the mu and logvar for the loss function 181 mus.append(mu) 182 logvars.append(logvar) 183 184 # Step 5: decoder rnn 185 decoding = self.decoder_rnn.forward(z) 186 # Step 6: write to canvas, (in the original dimensions of the input) 187 outputs.append(torch.add(outputs[-1],write_fn(decoding).view(x.size()))) 188 189 # return the sigmoided versions 190 for i in range(len(outputs)): 191 outputs[i] = torch.sigmoid(outputs[i]) 192 return outputs, mus, logvars 193