/ pytlib / loss_functions / multi_object_detector_loss.py
multi_object_detector_loss.py
  1  from __future__ import division
  2  from builtins import range
  3  from past.utils import old_div
  4  import torch
  5  import scipy.optimize
  6  import torch.nn.functional as F
  7  from loss_functions.box_loss import box_loss
  8  from utils.logger import Logger
  9  from utils.batch_box_utils import rescale_boxes, euc_distance_cost, generate_region_meshgrid
 10  import numpy as np
 11  
 12  def preprocess_targets_and_preds(targets, box_preds, class_preds, original_image):
 13      # 1) prepare dummy masks
 14      # 2) separate box targets
 15      # 3) separate class targets, no need for 1-hot encoding, apply NLL loss directly
 16      batch_size = original_image.shape[0]
 17  
 18      dummy_mask = targets[:,:,0]>-1
 19      dummy_masks = [x.squeeze() for x in torch.chunk(dummy_mask,targets.shape[0])]
 20      box_targets = targets[:,:,1:]
 21      class_targets = targets[:,:,0]
 22  
 23      # scale predictions by visual regions
 24      # use a meshgrid, for example we have 2x2 feature map with 32pix strides
 25      # then the grid looks like
 26      # yy = [[16,16],[48,48]]
 27      # xx = [[16,48],[16,48]]
 28      # we use this meshgrid to create an offset for the original predictions
 29      # so that a 0,0 prediction is exactly at the center of that region
 30      # and -1,1 are the edges of this region
 31      # use original_image.shape/cnn feature map shape to approximate the vision region size
 32      # TODO, check this works for non-multiple sizes of cnn reduction factor
 33      region_size = old_div(np.array(original_image.shape[2:]),np.array(box_preds.shape[3:]))
 34      hh,ww = generate_region_meshgrid(box_preds.shape[3:], region_size, old_div(region_size,2))
 35      rescaled_box_preds = rescale_boxes(box_preds, region_size, [hh,ww])
 36      rescaled_box_preds_flat = rescaled_box_preds.flatten(start_dim=2).transpose(1,2)
 37  
 38      box_targets_flat = box_targets.reshape(batch_size,-1,4)
 39  
 40      num_classes = class_preds.shape[1]
 41      class_preds_flatten_hw = class_preds.flatten(start_dim=2).transpose(1,2)
 42      logsoftmax_preds = F.log_softmax(class_preds_flatten_hw,dim=2)
 43      return rescaled_box_preds_flat, box_targets_flat, logsoftmax_preds, class_targets, dummy_masks
 44  
 45  def assign_targets(box_preds, box_targets, dummy_target_masks=None):
 46      if dummy_target_masks is None:
 47          dummy_target_masks = [torch.ones_like(box_targets[0,:,0],dtype=torch.bool)]*box_targets.shape[0]
 48  
 49      assert len(box_preds.shape)==3 and len(box_targets.shape)==3, 'boxes must be BxNx4'
 50      assert len(dummy_target_masks)==box_preds.shape[0], 'dummy_masks must match batch size'
 51      # explicit loop over batches
 52      pred_indices,target_indices = [[],[]],[[],[]]
 53      for i in range(0,box_targets.shape[0]):
 54          # Note, not using IOU cost here because most boxes would never get matched
 55          cost = euc_distance_cost(box_preds[i,:,:],box_targets[i,dummy_target_masks[i],:])
 56          original_target_indices = (dummy_target_masks[i]!=0).nonzero().squeeze(1).cpu().numpy()
 57          # next use scipy's hungarian to create the assignment
 58          row_inds, col_inds = scipy.optimize.linear_sum_assignment(cost.detach().cpu().numpy())
 59          pred_indices[0].extend([i]*len(row_inds)) # batch index
 60          pred_indices[1].extend(row_inds)
 61          target_indices[0].extend([i]*len(col_inds)) # batch index
 62          original_col_indices = [original_target_indices[k] for k in col_inds]
 63          target_indices[1].extend(original_col_indices) #original target indices        
 64      return pred_indices,target_indices
 65  
 66  def multi_object_detector_loss(original_image, 
 67                                 box_preds, 
 68                                 class_preds, 
 69                                 targets, 
 70                                 pos_to_neg_class_weight_ratio=0.25,
 71                                 class_loss_weight=2.0,
 72                                 box_loss_weight=1.0):
 73      # 1) preprocess targets
 74      # p_box_targets: BxNx4
 75      # p_box_preds: BxNx4
 76      # p_class_preds: BxNxC
 77      # p_class_targets: BxNx1
 78      # dummy_masks: list, batch items of N
 79      p_box_preds, p_box_targets, p_class_preds, p_class_targets, dummy_target_masks = \
 80          preprocess_targets_and_preds(targets, box_preds, class_preds, original_image)
 81      # total number of classes including bg
 82      num_classes = p_class_preds.shape[2]
 83  
 84      # 2) globally assign targets against predictions
 85      pred_indices, target_indices = assign_targets(p_box_preds, p_box_targets, dummy_target_masks)
 86  
 87      targets_exist = p_box_targets[target_indices].numel()
 88      # 3) only targets that have been assigned gets a box loss
 89      total_box_loss = 0
 90      if targets_exist:
 91          total_box_loss = box_loss(p_box_preds[pred_indices], p_box_targets[target_indices])
 92          Logger().set('loss_component.total_box_loss',total_box_loss.mean().item())
 93  
 94      # 4) all targets get classification loss
 95      # TODO, move this into a function with a unit test
 96      positive_class_loss = 0
 97      Logger().set('loss_component.positve_class_targets_size',p_class_targets[target_indices].shape[0])
 98      if targets_exist:
 99          positive_class_loss += F.nll_loss(p_class_preds[pred_indices], p_class_targets[target_indices].long())
100          Logger().set('loss_component.positive_class_loss',positive_class_loss.mean().item())
101      
102      mask = torch.ones_like(p_class_preds,dtype=torch.bool)
103      mask[pred_indices] = 0
104      neg_preds = torch.masked_select(p_class_preds,mask).reshape(-1,num_classes)
105      neg_targets = neg_preds.new_ones(neg_preds.shape[0],dtype=torch.long)*(p_class_preds.shape[2]-1)
106      Logger().set('loss_component.negative_class_targets_size',neg_targets.flatten().shape[0])
107      negative_class_loss = F.nll_loss(neg_preds,neg_targets)
108      Logger().set('loss_component.negative_class_loss',negative_class_loss.mean().item())
109      total_class_loss = old_div(pos_to_neg_class_weight_ratio,(1.+pos_to_neg_class_weight_ratio))*positive_class_loss \
110          + old_div(1,(1.+pos_to_neg_class_weight_ratio))*negative_class_loss       
111      
112      # 5) total loss = w0*class_loss + w1*box_loss
113      Logger().set('loss_component.total_class_loss',total_class_loss.mean().item())
114      total_loss = class_loss_weight*total_class_loss + box_loss_weight*total_box_loss
115      
116      total_positive_targets = torch.sum(F.softmax(class_preds.flatten(start_dim=2).transpose(1,2),dim=2)[:,:,0]>0.5)
117      total_negative_targets = torch.sum(F.softmax(class_preds.flatten(start_dim=2).transpose(1,2),dim=2)[:,:,1]>0.5)
118      Logger().set('loss_component.total_negative_targets',total_negative_targets.item())
119      Logger().set('loss_component.total_positive_targets',total_positive_targets.item())
120      return total_loss