# pytlib/networks/draw.py
  1  from __future__ import division
  2  from builtins import range
  3  from past.utils import old_div
  4  from networks.basic_rnn import BasicRNN
  5  import torch.nn as nn
  6  import torch.nn.functional as F
  7  import torch
  8  import math
  9  
 10  class DRAW(nn.Module):
 11      def __init__(self,q_size=10,
 12                        encoding_size=128,
 13                        timesteps=10,
 14                        training=True,
 15                        use_attention=False,
 16                        grid_size=5):
 17          super(DRAW, self).__init__()
 18          self.training = training
 19          self.encoding_size = encoding_size
 20          self.q_size = q_size
 21          self.use_attention = use_attention
 22          self.timesteps = timesteps
 23          # use equal encoding and decoding size
 24          self.encoder_rnn = BasicRNN(hstate_size=self.encoding_size,output_size=self.encoding_size)
 25          self.decoder_rnn = BasicRNN(hstate_size=self.encoding_size,output_size=self.encoding_size)
 26          self.register_parameter('decoder_linear_weights', None)
 27          self.register_parameter('encoding_mu_weights', None)
 28          self.register_parameter('encoding_logvar_weights', None)
 29          self.filter_linear_layer = nn.Linear(self.encoding_size,5)
 30          self.grid_size = grid_size
 31          self.minclamp = 1e-8
 32          self.maxclamp = 1e8
 33  
 34      def initialize(self,x):
 35          batch_size = x.size(0)
 36          # we use attention, the decoder producers a patch of grid_size x grid_size
 37          # else it produces an output of the original image size
 38          if self.use_attention:
 39              self.decoder_linear_weights = nn.Parameter(torch.Tensor(self.grid_size*self.grid_size,self.encoding_size))
 40          else:
 41              self.decoder_linear_weights = nn.Parameter(torch.Tensor(old_div(x.nelement(),batch_size),self.encoding_size))
 42          
 43          stdv = 1. / math.sqrt(self.decoder_linear_weights.size(1))
 44          self.decoder_linear_weights.data.uniform_(-stdv, stdv)        
 45          
 46          self.encoding_mu_weights = nn.Parameter(torch.Tensor(self.q_size,self.encoding_size))
 47          stdv = 1. / math.sqrt(self.encoding_mu_weights.size(1))
 48          self.encoding_mu_weights.data.uniform_(-stdv, stdv)   
 49  
 50          self.encoding_logvar_weights = nn.Parameter(torch.Tensor(self.q_size,self.encoding_size))
 51          stdv = 1. / math.sqrt(self.encoding_logvar_weights.size(1))
 52          self.encoding_logvar_weights.data.uniform_(-stdv, stdv)   
 53          if x.data.is_cuda:
 54              self.cuda()
 55  
 56      # selects where to sample from the input image, no attention version
 57      # dims is 2*W*H
 58      def read(self,x,x_hat,dec_state):
 59          return torch.cat((x,x_hat),1)
 60  
 61      # generate two sets of filterbanks 
 62      # 1) batch x N x W (Fx)
 63      # 2) batch x N x H (Fy)
 64      def generate_filter_matrices(self,gx,gy,sigma2,delta):
 65          N = self.grid_size
 66          grid_points = torch.arange(0,N).view((1,N,1))
 67          a = torch.arange(0,self.image_w).view((1,1,-1))
 68          b = torch.arange(0,self.image_h).view((1,1,-1))
 69          if gx.data.is_cuda:
 70              grid_points = grid_points.cuda()
 71              a = a.cuda()
 72              b = b.cuda()
 73  
 74          # gx is Bx1, grid is (1xNx1), so this is a broadcast op -> BxNx1
 75          mux = gx.view((-1,1,1)) + (grid_points.float() - old_div(N,2) - 0.5) * delta.view((-1,1,1))
 76          muy = gy.view((-1,1,1)) + (grid_points.float() - old_div(N,2) - 0.5) * delta.view((-1,1,1))
 77  
 78          s2 = sigma2.view((-1,1,1))
 79          fx = torch.exp(old_div(-(a.float()-mux).pow(2),(2*s2)))
 80          fy = torch.exp(old_div(-(b.float()-muy).pow(2),(2*s2)))
 81          # normalize
 82          fx = old_div(fx,torch.clamp(torch.sum(fx,2,keepdim=True),self.minclamp,self.maxclamp))
 83          fy = old_div(fy,torch.clamp(torch.sum(fy,2,keepdim=True),self.minclamp,self.maxclamp))
 84          return fx,fy
 85  
 86      def generate_filter_params(self,state):
 87          filter_vector = self.filter_linear_layer(state)
 88          _gx,_gy,log_sigma2,log_delta,loggamma = filter_vector.split(1,1)
 89          gx=old_div((self.image_w+1),2)*(_gx+1)
 90          gy=old_div((self.image_h+1),2)*(_gy+1)
 91          sigma2=torch.exp(log_sigma2)
 92          delta=old_div((max(self.image_w,self.image_h)-1),(self.grid_size-1))*torch.exp(log_delta)       
 93          gamma=torch.exp(loggamma)
 94          return gx,gy,sigma2,delta,gamma        
 95  
 96      def read_w_att(self,x,x_hat,dec_state):
 97          batch_size = x.size()[0]
 98  
 99          # 1) linear to convert dec_state into batchx5 params gx,gy,logsigma2,logdelta,loggamma
100          # 2) convert to gaussian parameters
101          gx,gy,sigma2,delta,gamma = self.generate_filter_params(dec_state)
102  
103          # 3) generate filter matrices
104          fx,fy = self.generate_filter_matrices(gx,gy,sigma2,delta)
105  
106          # 4) apply filter matrices to get glimpses
107          output = gamma.view(-1,1,1)*torch.bmm(torch.bmm(fy,x.view(batch_size,self.image_h,self.image_w)),torch.transpose(fx,1,2))
108          output_hat = gamma.view(-1,1,1)*torch.bmm(torch.bmm(fy,x_hat.view(batch_size,self.image_h,self.image_w)),torch.transpose(fx,1,2))
109          output_total = torch.cat((output.view(batch_size,self.grid_size*self.grid_size),output_hat.view(batch_size,self.grid_size*self.grid_size)),1)
110          return output_total
111  
112      # write takes use from "encoding space" to image space
113      def write(self,decoding):
114          return F.linear(decoding,self.decoder_linear_weights)
115  
116      def write_w_att(self,decoding):
117          batch_size = decoding.size()[0]
118          write_patch = F.linear(decoding,self.decoder_linear_weights).view(batch_size,self.grid_size,self.grid_size)
119          gx,gy,sigma2,gamma,delta = self.generate_filter_params(decoding)
120          fx,fy = self.generate_filter_matrices(gx,gy,sigma2,delta)
121          output = (old_div(1,gamma)).view(-1,1,1)*torch.bmm(torch.bmm(fy.transpose(1,2),write_patch),fx)
122          return output
123  
124      # this converts the encoding into both a mu and logvar vector
125      def sampleZ(self,encoding):
126          mu = F.linear(encoding,self.encoding_mu_weights)
127          logvar = F.linear(encoding,self.encoding_logvar_weights)
128          return self.reparameterize(mu, logvar),mu,logvar
129  
130      def reparameterize(self, mu, logvar):
131          if self.training:
132            std = logvar.mul(0.5).exp_()
133            eps = std.data.new(std.size()).normal_()
134            return eps.mul(std).add_(mu)
135          else:
136            return mu
137  
138      # takes an input, returns the sequence of outputs, mus, and logvars
139      def forward(self,x):
140          # flatten x to 1-d, except for batch dimension
141          xview = x.view(x.size()[0],old_div(x.nelement(),x.size()[0]))
142          # assume bchw dims
143          self.image_w = x.size(3)
144          self.image_h = x.size(2)
145          batch_size = x.size()[0]
146  
147          if self.decoder_linear_weights is None:
148              self.initialize(xview)
149  
150          # zero out initial states
151          self.encoder_rnn.reset_hidden_state(batch_size,x.data.is_cuda)
152          self.decoder_rnn.reset_hidden_state(batch_size,x.data.is_cuda)
153          outputs,mus,logvars = [],[],[]
154  
155          init_tensor = torch.zeros(x.size())
156          if x.data.is_cuda:
157              init_tensor = init_tensor.cuda()
158          outputs.append(init_tensor)
159  
160          if self.use_attention:
161              read_fn = self.read_w_att
162              write_fn = self.write_w_att
163          else:
164              read_fn = self.read
165              write_fn = self.write
166  
167          for t in range(0,self.timesteps):
168              # import ipdb;ipdb.set_trace()
169              # Step 1: diff the input against the prev output
170              x_hat = xview - torch.sigmoid(outputs[t].view(xview.size()))
171              # Step 2: read
172              rvec = read_fn(xview,x_hat,self.decoder_rnn.get_hidden_state())
173              # Step 3: encoder rnn
174              # note the dimensions of r doesn't have to match with the decoding size because
175              # we are just concating 2 dim-1 tensors, which is kind of wierd, but ok...
176              cat = torch.cat((rvec,self.decoder_rnn.get_hidden_state().view(batch_size,self.encoding_size)),1)
177              encoding = self.encoder_rnn.forward(cat)
178              # Step 4: sample z
179              z,mu,logvar = self.sampleZ(encoding)
180              # store the mu and logvar for the loss function
181              mus.append(mu)
182              logvars.append(logvar)
183  
184              # Step 5: decoder rnn
185              decoding = self.decoder_rnn.forward(z)
186              # Step 6: write to canvas, (in the original dimensions of the input)
187              outputs.append(torch.add(outputs[-1],write_fn(decoding).view(x.size())))
188  
189          # return the sigmoided versions
190          for i in range(len(outputs)):
191              outputs[i] = torch.sigmoid(outputs[i])
192          return outputs, mus, logvars
193