/ pytlib / loss_functions / vae_loss.py
vae_loss.py
 1  from builtins import range
 2  import torch.nn as nn
 3  import torch.nn.functional as F
 4  import torch
 5  from utils.logger import Logger
 6  from functools import reduce
 7  
# KL divergence between a diagonal Gaussian N(mu, sigma^2) and the standard normal prior N(0, I)
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
def KLD_gaussian(mu, logvar):
    """Return the summed KL term -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).

    Args:
        mu: tensor of means.
        logvar: tensor of log-variances, combined elementwise with ``mu``.

    Returns:
        A 0-dim tensor holding the sum over every element.
    """
    squared_mean = mu.pow(2)
    variance = logvar.exp()
    # Keep the left-to-right evaluation order so results stay bit-identical.
    per_element = 1 + logvar - squared_mean - variance
    return -0.5 * torch.sum(per_element)
14  
def vae_loss(reconstruction, mu, logvar, targets):
    """Return the VAE loss: reconstruction BCE plus a normalised KL term.

    Args:
        reconstruction: decoder output, used as the input of
            ``F.binary_cross_entropy``.
        mu: means of the approximate posterior.
        logvar: log-variances of the approximate posterior.
        targets: reconstruction targets for ``F.binary_cross_entropy``.

    Returns:
        Scalar tensor ``BCE + KLD``.

    Side effects:
        Records summary statistics and both loss components in ``Logger``
        under the ``loss_component.*`` keys.
    """
    # TODO, this should be the mse loss for a gaussian likelihood, (as opposed to BCE for bernouli, eg for MNIST)
    BCE = F.binary_cross_entropy(reconstruction, targets)

    # BCE is averaged over every element of the reconstruction (default
    # reduction), while KLD_gaussian returns a sum. Divide by the same element
    # count so both terms are on a comparable scale.
    KLD = KLD_gaussian(mu, logvar) / reconstruction.nelement()

    # Logging only: detach first so no autograd graph is built for statistics.
    detached_rec = reconstruction.detach()
    Logger().set('loss_component.variance_mean', logvar.detach().exp().mean().item())
    Logger().set('loss_component.mu_mean', mu.detach().mean().item())
    Logger().set('loss_component.reconstruction_mean', detached_rec.mean().item())
    Logger().set('loss_component.reconstruction_std', detached_rec.std().item())
    # .item() already returns a plain Python number regardless of device.
    Logger().set('loss_component.KLD', KLD.item())
    Logger().set('loss_component.BCE', BCE.item())
    return BCE + KLD
30  
31  # for DRAW model https://arxiv.org/pdf/1502.04623.pdf
def sequence_vae_loss(recs, mus, logvars, target):
    """Return the sequence VAE loss: final-step BCE plus a normalised KL term.

    Args:
        recs: non-empty sequence of reconstructions; only ``recs[-1]`` is
            scored against ``target``.
        mus: non-empty sequence of per-step posterior means.
        logvars: per-step posterior log-variances, same length as ``mus``.
        target: reconstruction target for ``F.binary_cross_entropy``.

    Returns:
        Scalar tensor ``BCE + KLD``.

    Side effects:
        Records reconstruction statistics and both loss components in
        ``Logger`` under the ``loss_component.*`` keys.
    """
    # Also reject empty mus/logvars here; otherwise the failure would surface
    # later as an opaque IndexError / empty-reduce error.
    assert len(recs) > 0 and len(mus) > 0 and len(mus) == len(logvars), "sequence_vae_loss: dimensions don't match"
    final_rec = recs[-1]
    # TODO, this should be the mse loss for a gaussian likelihood, (as opposed to BCE for bernouli, eg for MNIST)
    BCE = F.binary_cross_entropy(final_rec, target)

    # Sum the per-step KL terms left to right (same order as a running total).
    step_klds = (KLD_gaussian(mu_t, logvar_t) for mu_t, logvar_t in zip(mus, logvars))
    KLD = reduce(torch.add, step_klds)

    # Normalise by the element count of the final reconstruction times the
    # number of steps that contributed a KL term.
    KLD = KLD / (final_rec.nelement() * len(mus))

    # Logging only: detach first so no autograd graph is built for statistics.
    detached_rec = final_rec.detach()
    Logger().set('loss_component.reconstruction_mean', detached_rec.mean().item())
    Logger().set('loss_component.reconstruction_std', detached_rec.std().item())
    # .item() already returns a plain Python number regardless of device.
    Logger().set('loss_component.KLD', KLD.item())
    Logger().set('loss_component.BCE', BCE.item())
    return BCE + KLD
49