# multi_object_detector_loss.py
from __future__ import division
from builtins import range
from past.utils import old_div
import torch
import scipy.optimize
import torch.nn.functional as F
from loss_functions.box_loss import box_loss
from utils.logger import Logger
from utils.batch_box_utils import rescale_boxes, euc_distance_cost, generate_region_meshgrid
import numpy as np


def preprocess_targets_and_preds(targets, box_preds, class_preds, original_image):
    """Bring raw targets and network outputs into flat, per-batch-item form.

    Steps:
      1) build one boolean "valid target" mask per batch item
      2) separate the box targets
      3) separate the class targets (no 1-hot encoding, NLL loss is applied
         directly on the log-softmax output)

    Args:
        targets: 3-D tensor indexed as ``targets[b, n, :]`` where column 0 is
            the class id and the remaining columns are reshaped to 4 box
            values. Rows whose class id is not ``> -1`` are treated as
            dummy (padding) rows.
        box_preds: raw box predictions; axes 3 onwards are used as the
            feature-map size (assumes a 5-D layout -- TODO confirm against
            the network head).
        class_preds: raw class scores with the class axis at dim 1.
        original_image: input batch; ``shape[0]`` is the batch size and
            ``shape[2:]`` the image size.

    Returns:
        Tuple ``(box_preds_flat, box_targets_flat, log_softmax_class_preds,
        class_targets, dummy_masks)`` where ``dummy_masks`` is a list with one
        1-D boolean mask per batch item.
    """
    batch_size = original_image.shape[0]

    dummy_mask = targets[:, :, 0] > -1
    # unbind (rather than chunk + squeeze()) keeps every mask 1-D even when
    # there is only a single target slot per image.
    dummy_masks = list(dummy_mask.unbind(0))
    box_targets = targets[:, :, 1:]
    class_targets = targets[:, :, 0]

    # scale predictions by visual regions
    # use a meshgrid, for example we have 2x2 feature map with 32pix strides
    # then the grid looks like
    # yy = [[16,16],[48,48]]
    # xx = [[16,48],[16,48]]
    # we use this meshgrid to create an offset for the original predictions
    # so that a 0,0 prediction is exactly at the center of that region
    # and -1,1 are the edges of this region
    # use original_image.shape/cnn feature map shape to approximate the vision region size
    # TODO, check this works for non-multiple sizes of cnn reduction factor
    region_size = old_div(np.array(original_image.shape[2:]), np.array(box_preds.shape[3:]))
    hh, ww = generate_region_meshgrid(box_preds.shape[3:], region_size, old_div(region_size, 2))
    rescaled_box_preds = rescale_boxes(box_preds, region_size, [hh, ww])
    rescaled_box_preds_flat = rescaled_box_preds.flatten(start_dim=2).transpose(1, 2)

    box_targets_flat = box_targets.reshape(batch_size, -1, 4)

    # move the class axis last and flatten all remaining axes into one
    class_preds_flatten_hw = class_preds.flatten(start_dim=2).transpose(1, 2)
    logsoftmax_preds = F.log_softmax(class_preds_flatten_hw, dim=2)
    return rescaled_box_preds_flat, box_targets_flat, logsoftmax_preds, class_targets, dummy_masks


def assign_targets(box_preds, box_targets, dummy_target_masks=None):
    """Match predictions to valid targets, one batch item at a time.

    The assignment is solved with scipy's ``linear_sum_assignment`` on the
    cost returned by ``euc_distance_cost``.

    Args:
        box_preds: BxNx4 predicted boxes.
        box_targets: BxMx4 target boxes.
        dummy_target_masks: optional list (one entry per batch item) of 1-D
            boolean masks selecting the real targets; defaults to "all
            targets are real".

    Returns:
        ``(pred_indices, target_indices)``; each is a pair of equally long
        lists ``[batch_indices, item_indices]``. Target item indices refer to
        positions in the unmasked ``box_targets``.
    """
    assert len(box_preds.shape) == 3 and len(box_targets.shape) == 3, 'boxes must be BxNx4'
    if dummy_target_masks is None:
        dummy_target_masks = [torch.ones_like(box_targets[0, :, 0], dtype=torch.bool)] * box_targets.shape[0]
    assert len(dummy_target_masks) == box_preds.shape[0], 'dummy_masks must match batch size'

    # explicit loop over batches
    pred_indices, target_indices = [[], []], [[], []]
    for i in range(0, box_targets.shape[0]):
        valid = dummy_target_masks[i]
        # Note, not using IOU cost here because most boxes would never get matched
        cost = euc_distance_cost(box_preds[i, :, :], box_targets[i, valid, :])
        # positions of the valid targets inside the unmasked target tensor
        original_target_indices = (valid != 0).nonzero().squeeze(1).cpu().numpy()
        # next use scipy's hungarian to create the assignment
        row_inds, col_inds = scipy.optimize.linear_sum_assignment(cost.detach().cpu().numpy())
        pred_indices[0].extend([i] * len(row_inds))  # batch index
        pred_indices[1].extend(int(r) for r in row_inds)
        target_indices[0].extend([i] * len(col_inds))  # batch index
        # map columns of the masked cost matrix back to original target indices
        target_indices[1].extend(int(original_target_indices[c]) for c in col_inds)
    return pred_indices, target_indices


def _as_index(indices, device):
    """Turn ``[batch_indices, item_indices]`` lists into a tuple of long tensors for indexing."""
    return tuple(torch.as_tensor(ix, dtype=torch.long, device=device) for ix in indices)


def multi_object_detector_loss(original_image,
                               box_preds,
                               class_preds,
                               targets,
                               pos_to_neg_class_weight_ratio=0.25,
                               class_loss_weight=2.0,
                               box_loss_weight=1.0):
    """Compute the combined classification + box loss.

    Matched predictions get a box loss and a classification loss against their
    target class; every unmatched prediction gets a classification loss
    against the last class index (used as background).

    Args:
        original_image: input image batch (only its shape is used).
        box_preds: raw box predictions from the network.
        class_preds: raw class scores from the network, class axis at dim 1.
        targets: target tensor, see ``preprocess_targets_and_preds``.
        pos_to_neg_class_weight_ratio: ratio ``r``; the positive class loss is
            weighted by ``r / (1 + r)`` and the negative one by ``1 / (1 + r)``.
        class_loss_weight: weight of the total class loss.
        box_loss_weight: weight of the box loss.

    Returns:
        ``class_loss_weight * class_loss + box_loss_weight * box_loss``.
        Individual components are also reported through ``Logger().set``.
    """
    # 1) preprocess targets
    # p_box_targets: BxNx4
    # p_box_preds: BxNx4
    # p_class_preds: BxNxC
    # p_class_targets: BxNx1
    # dummy_masks: list, batch items of N
    p_box_preds, p_box_targets, p_class_preds, p_class_targets, dummy_target_masks = \
        preprocess_targets_and_preds(targets, box_preds, class_preds, original_image)
    # total number of classes including bg
    num_classes = p_class_preds.shape[2]

    # 2) globally assign targets against predictions
    pred_indices, target_indices = assign_targets(p_box_preds, p_box_targets, dummy_target_masks)
    # Index with an explicit tuple of tensors: indexing a tensor with a plain
    # list of lists relies on legacy "sequence as tuple" behaviour.
    pred_idx = _as_index(pred_indices, p_box_preds.device)
    target_idx = _as_index(target_indices, p_box_targets.device)

    matched_box_targets = p_box_targets[target_idx]
    matched_class_targets = p_class_targets[target_idx]
    targets_exist = matched_box_targets.numel() > 0

    # 3) only targets that have been assigned gets a box loss
    total_box_loss = 0
    if targets_exist:
        total_box_loss = box_loss(p_box_preds[pred_idx], matched_box_targets)
        Logger().set('loss_component.total_box_loss', total_box_loss.mean().item())

    # 4) all targets get classification loss
    # TODO, move this into a function with a unit test
    positive_class_loss = 0
    # NOTE: the key below is misspelled ("positve"); kept as-is because it is
    # emitted at runtime and consumers may depend on it.
    Logger().set('loss_component.positve_class_targets_size', matched_class_targets.shape[0])
    if targets_exist:
        positive_class_loss = F.nll_loss(p_class_preds[pred_idx], matched_class_targets.long())
        Logger().set('loss_component.positive_class_loss', positive_class_loss.mean().item())

    # every prediction that was not matched is trained towards the last class
    unmatched = torch.ones(p_class_preds.shape[:2], dtype=torch.bool, device=p_class_preds.device)
    unmatched[pred_idx] = False
    neg_preds = p_class_preds[unmatched]
    neg_targets = torch.full((neg_preds.shape[0],), num_classes - 1,
                             dtype=torch.long, device=neg_preds.device)
    Logger().set('loss_component.negative_class_targets_size', neg_targets.shape[0])
    if neg_targets.shape[0] > 0:
        negative_class_loss = F.nll_loss(neg_preds, neg_targets)
    else:
        # nll_loss with mean reduction over an empty batch is NaN, which would
        # poison the total loss; an empty sum is an exact, graph-attached zero.
        negative_class_loss = neg_preds.sum()
    Logger().set('loss_component.negative_class_loss', negative_class_loss.mean().item())

    positive_weight = old_div(pos_to_neg_class_weight_ratio, (1. + pos_to_neg_class_weight_ratio))
    negative_weight = old_div(1, (1. + pos_to_neg_class_weight_ratio))
    total_class_loss = positive_weight * positive_class_loss + negative_weight * negative_class_loss

    # 5) total loss = w0*class_loss + w1*box_loss
    Logger().set('loss_component.total_class_loss', total_class_loss.mean().item())
    total_loss = class_loss_weight * total_class_loss + box_loss_weight * total_box_loss

    # Prediction statistics for logging only; the softmax is computed once.
    # NOTE(review): indices 0 and 1 look like a two-class (object, background)
    # setup, while the loss above uses the last index as background -- confirm
    # before using more than two classes.
    class_probs = F.softmax(class_preds.flatten(start_dim=2).transpose(1, 2), dim=2)
    total_positive_targets = torch.sum(class_probs[:, :, 0] > 0.5)
    total_negative_targets = torch.sum(class_probs[:, :, 1] > 0.5)
    Logger().set('loss_component.total_negative_targets', total_negative_targets.item())
    Logger().set('loss_component.total_positive_targets', total_positive_targets.item())
    return total_loss