vae_loss.py
"""Loss functions for variational autoencoders (plain VAE and DRAW-style sequences)."""
from builtins import range
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.logger import Logger


# KLD for two gaussians
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
def KLD_gaussian(mu, logvar):
    """Return the KL divergence between N(mu, exp(logvar)) and a unit gaussian.

    Args:
        mu: tensor of posterior means.
        logvar: tensor of posterior log-variances, broadcastable with ``mu``.

    Returns:
        A 0-dim tensor: the divergence summed over every element (no
        averaging over batch or latent dimensions is done here).
    """
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


def _log_loss_components(reconstruction, KLD, BCE):
    """Record reconstruction statistics and both loss terms with the Logger."""
    # detach() keeps these statistics out of the autograd graph; item()
    # already converts a one-element tensor (on any device) to a Python number.
    rec = reconstruction.detach()
    Logger().set('loss_component.reconstruction_mean', rec.mean().item())
    Logger().set('loss_component.reconstruction_std', rec.std().item())
    Logger().set('loss_component.KLD', KLD.item())
    Logger().set('loss_component.BCE', BCE.item())


def vae_loss(reconstruction, mu, logvar, targets):
    """Return the VAE training loss: reconstruction BCE plus normalised KLD.

    Args:
        reconstruction: tensor passed to ``F.binary_cross_entropy`` as the
            prediction, so its values must be probabilities in [0, 1].
        mu: posterior means.
        logvar: posterior log-variances.
        targets: tensor with the same shape as ``reconstruction``.

    Returns:
        A 0-dim tensor ``BCE + KLD / reconstruction.numel()``.

    Side effects:
        Writes several ``loss_component.*`` scalars to ``Logger()``.
    """
    # TODO, this should be the mse loss for a gaussian likelihood,
    # (as opposed to BCE for Bernoulli, eg for MNIST)
    BCE = F.binary_cross_entropy(reconstruction, targets)

    # binary_cross_entropy (default reduction) averages over every element of
    # the reconstruction, so scale the summed KLD by the same element count to
    # keep the two terms comparable.
    KLD = KLD_gaussian(mu, logvar) / reconstruction.numel()

    Logger().set('loss_component.variance_mean', logvar.detach().exp().mean().item())
    Logger().set('loss_component.mu_mean', mu.detach().mean().item())
    _log_loss_components(reconstruction, KLD, BCE)
    return BCE + KLD


# for DRAW model https://arxiv.org/pdf/1502.04623.pdf
def sequence_vae_loss(recs, mus, logvars, target):
    """Return the loss for a sequential VAE: final-step BCE plus averaged KLD.

    Args:
        recs: non-empty sequence of reconstructions; only the last one is
            scored against ``target``.
        mus: non-empty sequence of posterior means, one per step.
        logvars: sequence of posterior log-variances, same length as ``mus``.
        target: tensor with the same shape as ``recs[-1]``.

    Returns:
        A 0-dim tensor ``BCE + sum_t(KLD_t) / (recs[-1].numel() * len(mus))``.

    Raises:
        AssertionError: if ``recs`` or ``mus`` is empty, or ``mus`` and
            ``logvars`` differ in length.

    Side effects:
        Writes several ``loss_component.*`` scalars to ``Logger()``.
    """
    # An empty `mus` is rejected here too, so it fails with this message
    # rather than with a bare IndexError/TypeError further down.
    assert len(recs) > 0 and len(mus) > 0 and len(mus) == len(logvars), \
        "sequence_vae_loss: dimensions don't match"

    final_rec = recs[-1]
    # TODO, this should be the mse loss for a gaussian likelihood,
    # (as opposed to BCE for Bernoulli, eg for MNIST)
    BCE = F.binary_cross_entropy(final_rec, target)

    # Left fold: accumulate the per-step divergences in step order.
    KLD = reduce(
        torch.add,
        (KLD_gaussian(mu_t, logvar_t) for mu_t, logvar_t in zip(mus, logvars)),
    )

    # Normalise by the element count of the reconstruction times the number
    # of steps, i.e. a per-element, per-step average of the divergence.
    KLD = KLD / (final_rec.numel() * len(mus))

    _log_loss_components(final_rec, KLD, BCE)
    return BCE + KLD