# batch_box_utils.py
"""Tensor utilities for batches of axis-aligned bounding boxes.

Boxes are ``N x 4`` tensors whose columns are unpacked throughout this module
as ``(x_min, y_min, x_max, y_max)``.
"""
from __future__ import division

import numbers
from builtins import range

import numpy as np
import torch
import torch.nn.functional as F

try:
    from past.utils import old_div
except ImportError:  # the ``future`` compatibility package is optional on Python 3
    def old_div(a, b):
        """Equivalent of ``past.utils.old_div``: floor-divide two integers, true-divide anything else."""
        if isinstance(a, numbers.Integral) and isinstance(b, numbers.Integral):
            return a // b
        return a / b


def rescale_boxes(boxes, scale, offset=(0, 0)):
    """Scale and optionally shift boxes, returning a new tensor.

    Args:
        boxes: tensor whose dim 1 holds the 4 box coordinates
            ``(x_min, y_min, x_max, y_max)``.
        scale: ``(height, width)`` multipliers. The y columns (1, 3) are scaled
            by ``height`` and the x columns (0, 2) by ``width``.
        offset: ``(y_offset, x_offset)`` added after scaling. Defaults to no
            shift. An immutable tuple default avoids the shared-mutable-default
            pitfall; any indexable pair is accepted.

    Returns:
        A rescaled copy of ``boxes``. The input tensor is not modified.
    """
    assert boxes.shape[1] == 4, 'dim[1] of boxes must be the box dimensions, instead got {}'.format(boxes.shape)
    height, width = scale
    y_offset, x_offset = offset[0], offset[1]
    new_boxes = boxes.clone()
    new_boxes[:, [1, 3]] = boxes[:, [1, 3]] * height + y_offset
    new_boxes[:, [0, 2]] = boxes[:, [0, 2]] * width + x_offset
    return new_boxes


def batch_box_intersection(t1, t2):
    """Compute pairwise intersection areas between two sets of boxes.

    Args:
        t1: ``N x 4`` tensor of ``(x_min, y_min, x_max, y_max)`` boxes.
        t2: ``M x 4`` tensor of boxes in the same layout.

    Returns:
        ``N x M`` tensor. Entry ``[i, j]`` is the overlap area of ``t1[i]``
        and ``t2[j]``, or 0 when they do not overlap.
    """
    assert len(t1.shape) == 2 and len(t2.shape) == 2 and \
        t1.shape[1] == 4 and t2.shape[1] == 4, 'boxes must be Nx4 tensors'
    x_min1, y_min1, x_max1, y_max1 = torch.chunk(t1, 4, dim=1)  # each (N, 1)
    x_min2, y_min2, x_max2, y_max2 = torch.chunk(t2, 4, dim=1)  # each (M, 1)
    # Transposing the second set to (1, M) broadcasts every comparison to (N, M).
    y_mins = torch.max(y_min1, y_min2.transpose(0, 1))
    y_maxs = torch.min(y_max1, y_max2.transpose(0, 1))
    intersect_heights = torch.max(torch.zeros_like(y_maxs), y_maxs - y_mins)
    x_mins = torch.max(x_min1, x_min2.transpose(0, 1))
    x_maxs = torch.min(x_max1, x_max2.transpose(0, 1))
    intersect_widths = torch.max(torch.zeros_like(x_maxs), x_maxs - x_mins)
    return intersect_heights * intersect_widths


def batch_box_area(boxes):
    """Return the ``N x 1`` tensor of box areas ``(x_max - x_min) * (y_max - y_min)``.

    The result is not clamped, so boxes with ``max < min`` on exactly one axis
    yield a negative area.
    """
    assert len(boxes.shape) == 2 and boxes.shape[1] == 4, 'boxes must be Nx4 tensors'
    x_min, y_min, x_max, y_max = torch.chunk(boxes, 4, dim=1)
    return (x_max - x_min) * (y_max - y_min)


def batch_box_IOU(t1, t2):
    """Compute pairwise intersection-over-union between two sets of boxes.

    Args:
        t1: ``N x 4`` tensor of boxes.
        t2: ``M x 4`` tensor of boxes.

    Returns:
        ``N x M`` tensor whose ``[i, j]`` entry is IoU(``t1[i]``, ``t2[j]``).
        Pairs with zero intersection are reported as exactly 0, even when
        their union is also 0.
    """
    intersections = batch_box_intersection(t1, t2)  # (N, M)
    areas1 = batch_box_area(t1)                     # (N, 1)
    areas2 = batch_box_area(t2).transpose(0, 1)     # (1, M)
    # Broadcasting expands both area vectors to the (N, M) grid.
    unions = areas1 + areas2 - intersections
    no_overlap = torch.eq(intersections, 0.0)
    # Where the result is masked to 0 anyway, divide by 1 instead of the real
    # union. For degenerate boxes the union can be 0. The forward pass would
    # still be correct, but in backward the division turns the zero gradient
    # routed to the discarded torch.where branch into 0/0 = NaN.
    safe_unions = torch.where(no_overlap, torch.ones_like(unions), unions)
    return torch.where(no_overlap, torch.zeros_like(intersections),
                       torch.div(intersections, safe_unions))


def batch_nms(boxes, thresh=0.5):
    """Suppress boxes that overlap an earlier box (a "Fast NMS" variant).

    Box ``j`` is dropped when any box ``i < j`` has IoU(i, j) >= ``thresh``.
    Unlike greedy NMS, a box that is itself suppressed can still suppress
    later boxes. Earlier rows take precedence, so callers presumably sort
    ``boxes`` by descending score first -- TODO confirm at call sites.

    Args:
        boxes: ``N x 4`` tensor of boxes.
        thresh: IoU at or above which the later box of a pair is suppressed.

    Returns:
        ``(kept_boxes, mask)``. ``mask`` is a length-``N`` bool tensor on
        ``boxes.device`` (True = kept), and ``kept_boxes`` is ``boxes[mask]``.
    """
    assert len(boxes.shape) == 2 and boxes.shape[1] == 4, 'boxes must be Nx4 tensors'
    ious = batch_box_IOU(boxes, boxes)  # (N, N)
    order = torch.arange(boxes.shape[0], device=boxes.device)
    # strictly_later[i, j] is True iff j > i, i.e. the strict upper triangle.
    strictly_later = order.unsqueeze(1) < order.unsqueeze(0)
    # A single vectorised reduction over rows i for each column j. This avoids
    # a Python loop with one nonzero() call (and device sync) per row.
    suppressed = ((ious >= thresh) & strictly_later).any(dim=0)
    mask = ~suppressed
    return boxes[mask], mask


def euc_distance_cost(boxes1, boxes2):
    """Compute pairwise squared Euclidean distances between box centres.

    Args:
        boxes1: ``N x 4`` tensor of boxes.
        boxes2: ``M x 4`` tensor of boxes.

    Returns:
        ``N x M`` tensor with entry ``[i, j] = (cx2[j] - cx1[i])**2 +
        (cy2[j] - cy1[i])**2``, where ``(cx, cy)`` are the box centres.
    """
    assert len(boxes1.shape) == 2 and boxes1.shape[1] == 4, 'boxes must be Nx4 tensors'
    assert len(boxes2.shape) == 2 and boxes2.shape[1] == 4, 'boxes must be Nx4 tensors'
    x_min1, y_min1, x_max1, y_max1 = torch.chunk(boxes1, 4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = torch.chunk(boxes2, 4, dim=1)
    # Centres as column vectors: (N, 1) for boxes1 and (M, 1) for boxes2.
    # True division matches old_div's behaviour on tensors.
    cx1 = (x_max1 + x_min1) / 2
    cx2 = (x_max2 + x_min2) / 2
    cy1 = (y_max1 + y_min1) / 2
    cy2 = (y_max2 + y_min2) / 2
    # Broadcasting (1, M) against (N, 1) yields the (N, M) grid directly,
    # without torch.meshgrid and its missing-``indexing`` deprecation warning.
    dx = cx2.transpose(0, 1) - cx1
    dy = cy2.transpose(0, 1) - cy1
    return dx * dx + dy * dy


def generate_region_meshgrid(num_regions, region_size, region_offsets, device='cuda'):
    """Build float32 grids of region origin coordinates.

    Args:
        num_regions: ``(rows, cols)`` number of regions along each axis.
        region_size: ``(row_step, col_step)`` spacing between regions.
        region_offsets: ``(row_offset, col_offset)`` added to every coordinate.
        device: target device for the outputs. The default ``'cuda'`` keeps
            the previous behaviour (the current CUDA device). Pass ``'cpu'``
            on machines without a GPU.

    Returns:
        ``(hh, ww)``, each of shape ``(rows, cols)``, where
        ``hh[i, j] = i * row_step + row_offset`` and
        ``ww[i, j] = j * col_step + col_offset``.
    """
    row_ids = torch.arange(0, num_regions[0])
    col_ids = torch.arange(0, num_regions[1])
    # Same (rows, cols) grid as torch.meshgrid with 'ij' indexing.
    hh, ww = torch.broadcast_tensors(row_ids.unsqueeze(1), col_ids.unsqueeze(0))
    # The arithmetic materialises fresh tensors, so the outputs are not views.
    hh = (hh * region_size[0] + region_offsets[0]).to(device=device, dtype=torch.float32)
    ww = (ww * region_size[1] + region_offsets[1]).to(device=device, dtype=torch.float32)
    return hh, ww