/ pytlib / utils / batch_box_utils_test.py
batch_box_utils_test.py
  1  import unittest
  2  import torch
  3  from utils.batch_box_utils import batch_box_intersection, batch_box_area, batch_box_IOU, rescale_boxes, euc_distance_cost, batch_nms
  4  from utils.test_utils import near_tensor_equality
  5  
  6  class TestBatchBoxUtils(unittest.TestCase):
  7  
  8      def test_rescale_boxes_no_offset(self):
  9          boxes = torch.Tensor([[0,0,1,1],
 10                                [2,2,4,4]])
 11          scale = [1,2]
 12          scaled_boxes = rescale_boxes(boxes, scale)
 13          expected_boxes = torch.Tensor([[0.,0.,2.,1.],
 14                                         [4.,2.,8.,4.]])
 15          # print expected_boxes
 16          self.assertTrue(near_tensor_equality(scaled_boxes,expected_boxes))
 17  
 18      def test_rescale_boxes_with_offset(self):
 19          boxes = torch.Tensor([[0,0,1,1],
 20                                [2,2,4,4]])
 21          scale = [1,2]
 22          offset = [40,30]
 23          scaled_boxes = rescale_boxes(boxes, scale, offset)
 24          expected_boxes = torch.Tensor([[30.,40.,32.,41.],
 25                                         [34.,42.,38.,44.]])
 26          # print scaled_boxes
 27          self.assertTrue(near_tensor_equality(scaled_boxes,expected_boxes))
 28  
 29      def test_batch_box_IOU_edge_cases(self):
 30          t1 = torch.Tensor([[0,0,0,0]])
 31          t2 = torch.Tensor([[0,0,0,0]])
 32          ious = batch_box_IOU(t1,t2)
 33          expected_matrix = torch.Tensor([[0.]])
 34          self.assertTrue(near_tensor_equality(ious,expected_matrix))
 35  
 36      def test_euc_distance_cost(self):
 37          t1 = torch.Tensor([[0,0,1,1],
 38                             [2,2,8,8],
 39                             [-1,-1,0,0]])
 40          t2 = torch.Tensor([[0,0,1,1],
 41                             [3,3,9,9]])
 42          distmat = euc_distance_cost(t1,t2)
 43          expected_matrix = torch.Tensor([[ 0.0000, 60.5000],
 44                                          [40.5000,  2.0000],
 45                                          [ 2.0000, 84.5000]])
 46          self.assertTrue(near_tensor_equality(distmat,expected_matrix))
 47  
 48      def test_batch_box_IOU(self):
 49          t1 = torch.Tensor([[0,0,1,1],
 50                             [2,2,8,8],
 51                             [-1,-1,0,0]])
 52          t2 = torch.Tensor([[0,0,1,1],
 53                             [3,3,9,9]])
 54          ious = batch_box_IOU(t1,t2)
 55          expected_matrix = torch.Tensor([[1.0000, 0.0000],
 56                                          [0.0000, 0.5319],
 57                                          [0.0000, 0.0000]])
 58          self.assertTrue(near_tensor_equality(ious,expected_matrix))
 59  
 60      def test_batch_box_intersection_edge_cases(self):
 61          # no intersection
 62          t1 = torch.Tensor([[0,0,1,1]])
 63          t2 = torch.Tensor([[3,4,9,10],
 64                             [3,3,9,9]])
 65          intersection_matrix = batch_box_intersection(t1,t2)
 66          expected_matrix = torch.Tensor([[ 0.],
 67                                          [ 0.]])
 68          self.assertTrue(torch.all(torch.eq(intersection_matrix,expected_matrix)))
 69  
 70          # negative intersection
 71          t1 = torch.Tensor([[0,0,-1,-1]])
 72          t2 = torch.Tensor([[-3,-4,9,10],
 73                             [-3,-3,9,9]])
 74          intersection_matrix = batch_box_intersection(t1,t2)
 75          expected_matrix = torch.Tensor([[ 0.],
 76                                          [ 0.]])
 77          self.assertTrue(torch.all(torch.eq(intersection_matrix,expected_matrix)))
 78  
 79      def test_batch_box_intersection(self):
 80          t1 = torch.Tensor([[0,0,1,1],
 81                             [2,2,8,8]])
 82          t2 = torch.Tensor([[0,0,1,1],
 83                             [3,3,9,9]])
 84          intersection_matrix = batch_box_intersection(t1,t2)
 85          expected_matrix = torch.Tensor([[ 1.,  0.],
 86                                          [ 0., 25.]])
 87          self.assertTrue(torch.all(torch.eq(intersection_matrix,expected_matrix)))
 88  
 89      def test_batch_box_area(self):
 90          boxes = torch.Tensor([[0,0,1,1],
 91                                [2,2,8,8],
 92                                [0,0,0,0]])
 93          areas = batch_box_area(boxes)
 94          expected_areas = torch.tensor([[1.,36.,0.]]).transpose(0,1)
 95          self.assertTrue(near_tensor_equality(areas,expected_areas))
 96  
 97      def test_batch_nms_a(self):
 98          boxes = torch.Tensor([[0,0,1,1],
 99                                [0,0,1,1],
100                                [0,0,5,5],
101                                [1,1,6,6],
102                                [3,3,10,10],
103                                [3,3,6,6],
104                                [0,0,1,1]])
105  
106          expected_result = torch.Tensor([[0,0,1,1],
107                                       [0,0,5,5],
108                                       [1,1,6,6],
109                                       [3,3,10,10],
110                                       [3,3,6,6]])
111          # thresh=0.5        
112          results, _ = batch_nms(boxes)
113          self.assertTrue(torch.all(torch.eq(results,expected_result)))
114  
115      def test_batch_nms_b(self):
116          boxes = torch.Tensor([[0,0,1,1],
117                                [0,0,1,1],
118                                [0,0,5,5],
119                                [1,1,6,6],
120                                [3,3,10,10],
121                                [3,3,6,6],
122                                [0,0,1,1]])
123  
124          expected_result = torch.Tensor([[0,0,1,1],
125                                          [0,0,5,5]])
126          results, _ = batch_nms(boxes,thresh=0.1)
127          self.assertTrue(torch.all(torch.eq(results,expected_result)))
128  
129  if __name__ == '__main__':
130      unittest.main()