triplet_correlation_loss.py
import torch
import torch.nn.functional as F
import numpy as np
from utils.logger import Logger
from loss_functions.response_map_loss import response_map_loss
from loss_functions.vae_loss import vae_loss


# TODO: a lot of these ops could be done in place to reduce memory use and improve compute.
def pearson_correlation_loss(x1, x2, eps=1e-6):
    """Pearson correlation between two one-dimensional tensors of equal length."""
    assert x1.dim() == 1 and x2.dim() == 1 and x1.size() == x2.size(), \
        'Inputs must be one-dimensional and the same size'
    m1 = torch.mean(x1, 0, keepdim=True)
    m2 = torch.mean(x2, 0, keepdim=True)
    # Center each vector and scale it by its mean; the scale largely cancels in the ratio below.
    c1 = (x1 - m1).div(m1 + eps)
    c2 = (x2 - m2).div(m2 + eps)
    numerator = torch.sum(c1.mul(c2), 0, keepdim=True)
    denom1 = torch.sqrt(torch.sum(c1.mul(c1), 0, keepdim=True))
    denom2 = torch.sqrt(torch.sum(c2.mul(c2), 0, keepdim=True))
    cor = numerator.div(denom1.mul(denom2) + eps)
    return cor


def triplet_correlation_loss(anchor, pos, neg, dummy_target, margin=1.0, eps=1e-6):
    """Hinge loss on the correlation gap: push corr(anchor, pos) above corr(anchor, neg) by `margin`.

    dummy_target is unused.
    """
    # For now this is computed per sample; batching it properly would require
    # working out the broadcasting in torch.
    batch_size = anchor.size(0)
    pcps, pcns = [], []
    ne = anchor.nelement() // batch_size  # elements per sample
    for i in range(batch_size):
        pcps.append(pearson_correlation_loss(anchor[i, :].view(ne), pos[i, :].view(ne), eps))
        pcns.append(pearson_correlation_loss(anchor[i, :].view(ne), neg[i, :].view(ne), eps))
    pcp = torch.stack(pcps, 0)
    pcn = torch.stack(pcns, 0)
    dist_hinge = torch.clamp(margin + pcn - pcp, min=0.0)
    loss = torch.mean(dist_hinge)
    return loss


def triplet_correlation_loss2(anchor, pcp, pcn, recon_output, mu, logvar, pos_map, neg_map, recon_target):
    """VAE loss plus binary cross-entropy between correlation responses and down-sampled target maps."""
    # Down-sample the target map to the response resolution (or, equivalently, up-sample the response):
    # average-pool with a kernel given by the ratio of the spatial extents, then round back to {0, 1}.
    vloss = vae_loss(recon_output, mu, logvar, recon_target)
    pool_kernel = (np.array(pos_map.size()[1:]) // np.array(pcp.size()[-2:])).astype(int)
    pool_kernel = tuple(int(k) for k in pool_kernel)
    pos_map_resized = torch.round(F.avg_pool2d(pos_map, pool_kernel))
    neg_map_resized = torch.round(F.avg_pool2d(neg_map, pool_kernel))

    nloss = F.binary_cross_entropy_with_logits(pcn.squeeze(), neg_map_resized.squeeze())
    ploss = F.binary_cross_entropy_with_logits(pcp.squeeze(), pos_map_resized.squeeze())
    # nloss = response_map_loss(pcn.squeeze())
    # ploss = response_map_loss(pcp.squeeze(), boxes)
    Logger().set('loss_component.anchor_mean', anchor.detach().mean().item())
    Logger().set('loss_component.anchor_std', anchor.detach().std().item())
    Logger().set('loss_component.ploss2', ploss.item())
    Logger().set('loss_component.nloss2', nloss.item())
    Logger().set('loss_component.vloss', vloss.item())
    return ploss + nloss + vloss
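

if __name__ == '__main__':
    # Minimal smoke test for the correlation-based triplet loss. This is an illustrative
    # sketch only: the batch size and feature shapes below are assumptions, not values from
    # the training pipeline, and triplet_correlation_loss2 is not exercised here because it
    # also needs the VAE reconstruction tensors and target response maps.
    torch.manual_seed(0)
    anchor = torch.randn(4, 16, 8, 8)
    pos = anchor + 0.1 * torch.randn_like(anchor)  # positives: noisy copies of the anchor
    neg = torch.randn(4, 16, 8, 8)                 # negatives: unrelated random features
    loss = triplet_correlation_loss(anchor, pos, neg, dummy_target=None)
    print('triplet_correlation_loss:', loss.item())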