pytlib/loss_functions/triplet_correlation_loss.py
from __future__ import division
from builtins import range
from past.utils import old_div
import torch
import torch.nn.functional as F
from utils.logger import Logger
from loss_functions.response_map_loss import response_map_loss
from loss_functions.vae_loss import vae_loss
import numpy as np

# TODO, a lot of these ops could be done in-place to reduce memory use and improve compute
def pearson_correlation_loss(x1, x2, eps=1e-6):
    # Pearson correlation between two 1-D tensors of equal length
    assert len(x1.size()) == 1 and len(x2.size()) == 1 and x1.size(0) == x2.size(0), \
        'Inputs must be one-dimensional and the same size'
    m1 = torch.mean(x1, 0, keepdim=True)
    m2 = torch.mean(x2, 0, keepdim=True)
    # center each input; the division by the (eps-stabilized) mean is a scale
    # normalization that cancels out of the correlation ratio below
    c1 = (x1 - m1).div(m1 + eps)
    c2 = (x2 - m2).div(m2 + eps)
    numerator = torch.sum(c1.mul(c2), 0, keepdim=True)
    denom1 = torch.sqrt(torch.sum(c1.mul(c1), 0, keepdim=True))
    denom2 = torch.sqrt(torch.sum(c2.mul(c2), 0, keepdim=True))
    cor = numerator.div(denom1.mul(denom2) + eps)
    return cor
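
# Hedged usage sketch (illustrative, not part of the original module): demonstrates the
# expected 1-D inputs for pearson_correlation_loss; the feature length of 256 is an
# arbitrary assumption.
def _example_pearson_correlation():
    x1 = torch.randn(256)
    x2 = torch.randn(256)
    # returns a 1-element tensor whose value lies roughly in [-1, 1]
    return pearson_correlation_loss(x1, x2)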

def triplet_correlation_loss(anchor, pos, neg, dummy_target, margin=1.0, eps=1e-6):
    # for now, debatch this computation; to batch properly we need to figure out how to broadcast in torch
    batch_size = anchor.size(0)
    pcps, pcns = [], []
    # number of elements per batch item, used to flatten each sample to 1-D
    ne = old_div(anchor.nelement(), batch_size)
    for i in range(0, batch_size):
        pcps.append(pearson_correlation_loss(anchor[i, :].view(ne), pos[i, :].view(ne), eps))
        pcns.append(pearson_correlation_loss(anchor[i, :].view(ne), neg[i, :].view(ne), eps))
    pcp = torch.stack(pcps, 0)
    pcn = torch.stack(pcns, 0)
    # hinge: corr(anchor, pos) should exceed corr(anchor, neg) by at least `margin`
    dist_hinge = torch.clamp(margin + pcn - pcp, min=0.0)
    loss = torch.mean(dist_hinge)
    return loss
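
# Hedged usage sketch (illustrative, not part of the original module): a batch of 4
# embeddings with 128 features each; the shapes and dummy_target=None are assumptions.
def _example_triplet_correlation():
    anchor = torch.randn(4, 128)
    pos = anchor + 0.1 * torch.randn(4, 128)  # positives correlate strongly with the anchor
    neg = torch.randn(4, 128)                 # negatives are uncorrelated
    return triplet_correlation_loss(anchor, pos, neg, dummy_target=None, margin=1.0)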

def triplet_correlation_loss2(anchor, pcp, pcn, recon_output, mu, logvar, pos_map, neg_map, recon_target):
    # down-sample the target map (or up-sample the response) using avgpool + rounding;
    # the pooling kernel is the ratio of the two spatial extents

    vloss = vae_loss(recon_output, mu, logvar, recon_target)
    pool_kernel = old_div(np.array(pos_map.size()[1:]), np.array(pcp.size()[-2:])).astype(int)
    pos_map_resized = torch.round(F.avg_pool2d(pos_map, tuple(pool_kernel)))
    neg_map_resized = torch.round(F.avg_pool2d(neg_map, tuple(pool_kernel)))

    nloss = F.binary_cross_entropy_with_logits(pcn.squeeze(), neg_map_resized.squeeze())
    ploss = F.binary_cross_entropy_with_logits(pcp.squeeze(), pos_map_resized.squeeze())
    # nloss = response_map_loss(pcn.squeeze())
    # ploss = response_map_loss(pcp.squeeze(),boxes)
    Logger().set('loss_component.anchor_mean', anchor.data.mean().item())
    Logger().set('loss_component.anchor_std', anchor.data.std().item())
    Logger().set('loss_component.ploss2', ploss.data.cpu().item())
    Logger().set('loss_component.nloss2', nloss.data.cpu().item())
    Logger().set('loss_component.vloss', vloss.data.cpu().item())
    return ploss + nloss + vloss
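
# Hedged sketch (illustrative) of the down-sampling step inside triplet_correlation_loss2:
# a binary target map is average-pooled to the response-map resolution and rounded back to
# {0, 1}. The 64x64 target map and 8x8 response sizes are assumptions for this example.
def _example_map_downsample():
    pos_map = (torch.rand(2, 64, 64) > 0.5).float()  # batch of binary target maps
    pcp = torch.randn(2, 1, 8, 8)                    # correlation/response logits
    pool_kernel = old_div(np.array(pos_map.size()[1:]), np.array(pcp.size()[-2:])).astype(int)
    pos_map_resized = torch.round(F.avg_pool2d(pos_map, tuple(pool_kernel)))
    return pos_map_resized  # shape (2, 8, 8) with values in {0, 1}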