/ pytlib / utils / batch_box_utils.py
batch_box_utils.py
 1  from __future__ import division
 2  from builtins import range
 3  from past.utils import old_div
 4  import torch
 5  import torch.nn.functional as F
 6  import numpy as np
 7  
 8  def rescale_boxes(boxes, scale, offset=[0,0]):
 9      assert boxes.shape[1]==4, 'dim[1] of boxes must the box dimensions, instead got {}'.format(boxes.shape)
10      height,width = scale
11      new_boxes = boxes.clone()
12      new_boxes[:,[1,3]] = boxes[:,[1,3]]*height + offset[0]
13      new_boxes[:,[0,2]] = boxes[:,[0,2]]*width + offset[1]
14      return new_boxes
15  
def batch_box_intersection(t1, t2):
    """Return the pairwise intersection areas of two sets of boxes.

    Both inputs are N x 4 / M x 4 tensors whose columns are unpacked as
    (x_min, y_min, x_max, y_max). The result is an N x M tensor where entry
    [i, j] is the overlap area of t1[i] and t2[j]; non-overlapping pairs
    give 0 because each side length is floored at zero.
    """
    shapes_ok = (len(t1.shape) == 2 and len(t2.shape) == 2
                 and t1.shape[1] == 4 and t2.shape[1] == 4)
    assert shapes_ok, 'boxes must be Nx4 tensors'

    # t1 columns stay (N, 1); t2 columns become (1, M) so they broadcast to (N, M).
    a_left, a_top, a_right, a_bottom = torch.chunk(t1, 4, dim=1)
    b_left, b_top, b_right, b_bottom = [
        col.transpose(0, 1) for col in torch.chunk(t2, 4, dim=1)
    ]

    overlap_h = torch.min(a_bottom, b_bottom) - torch.max(a_top, b_top)
    overlap_h = torch.max(torch.zeros_like(overlap_h), overlap_h)

    overlap_w = torch.min(a_right, b_right) - torch.max(a_left, b_left)
    overlap_w = torch.max(torch.zeros_like(overlap_w), overlap_w)

    return overlap_h * overlap_w
29  
def batch_box_area(boxes):
    """Return the area of every box as an (N, 1) tensor.

    ``boxes`` is an N x 4 tensor with columns (x_min, y_min, x_max, y_max).
    Side lengths are not clamped, so inverted boxes produce signed products.
    """
    assert len(boxes.shape)==2 and boxes.shape[1]==4, 'boxes must be Nx4 tensors'
    # Width-1 slices keep the trailing dimension, giving (N, 1) columns.
    widths = boxes[:, 2:3] - boxes[:, 0:1]
    heights = boxes[:, 3:4] - boxes[:, 1:2]
    return widths * heights
35  
def batch_box_IOU(t1, t2):
    """Return the pairwise intersection-over-union of two sets of boxes.

    Args:
        t1: N x 4 tensor of boxes.
        t2: M x 4 tensor of boxes.

    Returns:
        N x M tensor where entry [i, j] is IoU(t1[i], t2[j]). Pairs whose
        intersection is exactly 0 get an IoU of 0.
    """
    intersections = batch_box_intersection(t1, t2)
    num1, num2 = intersections.shape
    # (N, 1) -> (N, M) and (M, 1) -> (1, M) -> (N, M).
    areas1_expanded = batch_box_area(t1).expand(-1, num2)
    areas2_expanded = batch_box_area(t2).transpose(0, 1).expand(num1, -1)
    unions = areas1_expanded + areas2_expanded - intersections
    no_overlap = torch.eq(intersections, 0.0)
    # torch.where evaluates both branches. Replace the denominator wherever
    # the output is forced to 0 so 0/0 (e.g. two zero-area boxes with
    # union == 0) is never computed; otherwise the NaN would leak into the
    # gradients even though the forward value is masked out.
    safe_unions = torch.where(no_overlap, torch.ones_like(unions), unions)
    iou = torch.where(
        no_overlap,
        torch.zeros_like(intersections),
        torch.div(intersections, safe_unions))
    return iou
48  
def batch_nms(boxes, thresh=0.5):
    """Greedy non-maximum suppression over an N x 4 tensor of boxes.

    Boxes are visited in row order. A box is suppressed when its IoU with an
    earlier box that was itself *kept* is >= ``thresh``. No scores are
    taken here; presumably callers pass boxes already sorted by descending
    confidence -- TODO confirm at call sites.

    Args:
        boxes: N x 4 tensor of boxes.
        thresh: IoU at or above which a later box is suppressed.

    Returns:
        (kept_boxes, mask): the surviving rows of ``boxes`` and the boolean
        keep-mask of length N (on the same device as ``boxes``).
    """
    assert len(boxes.shape)==2 and boxes.shape[1]==4, 'boxes must be Nx4 tensors'
    # N x N grid of pairwise IoUs.
    ious = batch_box_IOU(boxes, boxes)
    mask = torch.ones_like(boxes[:, 0], dtype=torch.bool)
    for i in range(boxes.shape[0]):
        # Only a surviving box may suppress later ones; letting suppressed
        # boxes suppress too would chain (A~B, B~C drops C even if A!~C).
        if not mask[i]:
            continue
        # ~(>=) rather than (<) keeps NaN IoUs non-suppressing, as before.
        mask[i + 1:] &= ~(ious[i, i + 1:] >= thresh)
    return boxes[mask], mask
64  
def euc_distance_cost(boxes1, boxes2):
    """Return pairwise squared Euclidean distances between box centres.

    Args:
        boxes1: N x 4 tensor with columns (x_min, y_min, x_max, y_max).
        boxes2: M x 4 tensor with the same column layout.

    Returns:
        N x M tensor; entry [i, j] is the squared distance between the
        centre of boxes1[i] and the centre of boxes2[j].
    """
    assert len(boxes1.shape)==2 and boxes1.shape[1]==4, 'boxes must be Nx4 tensors'
    assert len(boxes2.shape)==2 and boxes2.shape[1]==4, 'boxes must be Nx4 tensors'
    x_min1, y_min1, x_max1, y_max1 = torch.chunk(boxes1, 4, dim=1)
    x_min2, y_min2, x_max2, y_max2 = torch.chunk(boxes2, 4, dim=1)
    # Centres as (N, 1) and (M, 1) columns.
    cx1 = (x_max1 + x_min1) / 2
    cy1 = (y_max1 + y_min1) / 2
    cx2 = (x_max2 + x_min2) / 2
    cy2 = (y_max2 + y_min2) / 2
    # Broadcasting (1, M) against (N, 1) yields the (N, M) grid directly.
    # This replaces torch.meshgrid, which warns when called without
    # `indexing` on recent PyTorch, and the .squeeze() calls that collapsed
    # to 0-d tensors when N or M was 1.
    dx = cx2.transpose(0, 1) - cx1
    dy = cy2.transpose(0, 1) - cy1
    return dx * dx + dy * dy
80  
def generate_region_meshgrid(num_regions, region_size, region_offsets, device='cuda'):
    """Build row/column coordinate grids for a regular layout of regions.

    Args:
        num_regions: (rows, cols) number of regions along height and width.
        region_size: (height, width) step between consecutive regions.
        region_offsets: (y, x) offset added to every coordinate.
        device: device for the returned tensors. Defaults to 'cuda', matching
            the previous hard-coded ``.cuda()``; pass 'cpu' (or any
            ``torch.device``) to build the grids without a GPU.

    Returns:
        (hh, ww) float32 tensors of shape (rows, cols) with
        hh[i, j] = i * region_size[0] + region_offsets[0] and
        ww[i, j] = j * region_size[1] + region_offsets[1].
    """
    hh, ww = torch.meshgrid(torch.arange(0, num_regions[0]), torch.arange(0, num_regions[1]))
    hh = (hh * region_size[0] + region_offsets[0]).to(device).float()
    ww = (ww * region_size[1] + region_offsets[1]).to(device).float()
    return hh, ww