# batch_box_utils_test.py
"""Unit tests for the batched bounding-box helpers in ``utils.batch_box_utils``.

Boxes are given as rows of four coordinates; the expected values below are
consistent with corner format ``[x1, y1, x2, y2]`` (e.g. ``[2, 2, 8, 8]`` has
area 36).  The pairwise helpers are expected to return an ``(N1, N2)`` matrix
whose entry ``[i, j]`` compares ``t1[i]`` with ``t2[j]``.
"""
import unittest

import torch

from utils.batch_box_utils import (
    batch_box_intersection,
    batch_box_area,
    batch_box_IOU,
    rescale_boxes,
    euc_distance_cost,
    batch_nms,
)
from utils.test_utils import near_tensor_equality


def _pairwise_boxes():
    """Return fresh ``(t1, t2)`` fixtures shared by the IoU and distance tests.

    Built per call so a helper that mutates its inputs cannot leak state
    between tests.
    """
    t1 = torch.Tensor([[0, 0, 1, 1],
                       [2, 2, 8, 8],
                       [-1, -1, 0, 0]])
    t2 = torch.Tensor([[0, 0, 1, 1],
                       [3, 3, 9, 9]])
    return t1, t2


def _nms_boxes():
    """Return a fresh box set for the NMS tests.

    Three exact duplicates of ``[0, 0, 1, 1]`` plus a cluster of partially
    overlapping larger boxes.
    """
    return torch.Tensor([[0, 0, 1, 1],
                         [0, 0, 1, 1],
                         [0, 0, 5, 5],
                         [1, 1, 6, 6],
                         [3, 3, 10, 10],
                         [3, 3, 6, 6],
                         [0, 0, 1, 1]])


class TestBatchBoxUtils(unittest.TestCase):

    def assertTensorEqual(self, actual, expected):
        """Assert ``actual`` has exactly the shape and values of ``expected``.

        ``torch.eq`` broadcasts, so comparing values alone would let a
        transposed or squeezed result pass; the shape is checked first.
        """
        self.assertEqual(tuple(actual.shape), tuple(expected.shape))
        self.assertTrue(torch.all(torch.eq(actual, expected)))

    def assertTensorClose(self, actual, expected):
        """Assert matching shape and approximately equal values."""
        self.assertEqual(tuple(actual.shape), tuple(expected.shape))
        self.assertTrue(near_tensor_equality(actual, expected))

    def test_rescale_boxes_no_offset(self):
        boxes = torch.Tensor([[0, 0, 1, 1],
                              [2, 2, 4, 4]])
        # scale[1] multiplies columns 0 and 2; scale[0] multiplies columns 1 and 3.
        scale = [1, 2]
        scaled_boxes = rescale_boxes(boxes, scale)
        expected_boxes = torch.Tensor([[0., 0., 2., 1.],
                                       [4., 2., 8., 4.]])
        self.assertTensorClose(scaled_boxes, expected_boxes)

    def test_rescale_boxes_with_offset(self):
        boxes = torch.Tensor([[0, 0, 1, 1],
                              [2, 2, 4, 4]])
        scale = [1, 2]
        # Applied after scaling: offset[1] is added to columns 0 and 2,
        # offset[0] to columns 1 and 3.
        offset = [40, 30]
        scaled_boxes = rescale_boxes(boxes, scale, offset)
        expected_boxes = torch.Tensor([[30., 40., 32., 41.],
                                       [34., 42., 38., 44.]])
        self.assertTensorClose(scaled_boxes, expected_boxes)

    def test_batch_box_IOU_edge_cases(self):
        # Two zero-area boxes: the union is 0, so IoU must come out as 0
        # rather than the undefined 0 / 0.
        t1 = torch.Tensor([[0, 0, 0, 0]])
        t2 = torch.Tensor([[0, 0, 0, 0]])
        ious = batch_box_IOU(t1, t2)
        expected_matrix = torch.Tensor([[0.]])
        self.assertTensorClose(ious, expected_matrix)

    def test_euc_distance_cost(self):
        t1, t2 = _pairwise_boxes()
        distmat = euc_distance_cost(t1, t2)
        # Entries equal the *squared* distance between box centres, e.g.
        # centre (0.5, 0.5) vs (6, 6): 5.5**2 + 5.5**2 = 60.5.
        expected_matrix = torch.Tensor([[0.0, 60.5],
                                        [40.5, 2.0],
                                        [2.0, 84.5]])
        self.assertTensorClose(distmat, expected_matrix)

    def test_batch_box_IOU(self):
        t1, t2 = _pairwise_boxes()
        ious = batch_box_IOU(t1, t2)
        # [2, 2, 8, 8] vs [3, 3, 9, 9]: intersection 25, union 36 + 36 - 25 = 47.
        expected_matrix = torch.Tensor([[1.0, 0.0],
                                        [0.0, 25. / 47.],
                                        [0.0, 0.0]])
        self.assertTensorClose(ious, expected_matrix)

    def test_batch_box_intersection_edge_cases(self):
        # No intersection: one box in t1 against two boxes in t2 gives a
        # (1, 2) matrix, following the (N1, N2) convention of the other tests.
        t1 = torch.Tensor([[0, 0, 1, 1]])
        t2 = torch.Tensor([[3, 4, 9, 10],
                           [3, 3, 9, 9]])
        intersection_matrix = batch_box_intersection(t1, t2)
        expected_matrix = torch.Tensor([[0., 0.]])
        self.assertTensorEqual(intersection_matrix, expected_matrix)

        # Degenerate t1 box (second corner before the first): the raw overlap
        # extent would be negative, and the result must be clamped to 0.
        t1 = torch.Tensor([[0, 0, -1, -1]])
        t2 = torch.Tensor([[-3, -4, 9, 10],
                           [-3, -3, 9, 9]])
        intersection_matrix = batch_box_intersection(t1, t2)
        expected_matrix = torch.Tensor([[0., 0.]])
        self.assertTensorEqual(intersection_matrix, expected_matrix)

    def test_batch_box_intersection(self):
        t1 = torch.Tensor([[0, 0, 1, 1],
                           [2, 2, 8, 8]])
        t2 = torch.Tensor([[0, 0, 1, 1],
                           [3, 3, 9, 9]])
        intersection_matrix = batch_box_intersection(t1, t2)
        expected_matrix = torch.Tensor([[1., 0.],
                                        [0., 25.]])
        self.assertTensorEqual(intersection_matrix, expected_matrix)

    def test_batch_box_area(self):
        boxes = torch.Tensor([[0, 0, 1, 1],
                              [2, 2, 8, 8],
                              [0, 0, 0, 0]])
        areas = batch_box_area(boxes)
        # Areas are expected as an (N, 1) column.
        expected_areas = torch.Tensor([[1.], [36.], [0.]])
        self.assertTensorClose(areas, expected_areas)

    def test_batch_nms_a(self):
        # Uses batch_nms's default threshold (assumed 0.5 -- TODO confirm
        # against batch_nms). Only the exact duplicates of [0, 0, 1, 1] are
        # removed, and surviving boxes keep their input order.
        results, _ = batch_nms(_nms_boxes())
        expected_result = torch.Tensor([[0, 0, 1, 1],
                                        [0, 0, 5, 5],
                                        [1, 1, 6, 6],
                                        [3, 3, 10, 10],
                                        [3, 3, 6, 6]])
        self.assertTensorEqual(results, expected_result)

    def test_batch_nms_b(self):
        # NOTE(review): [3, 3, 10, 10] has IoU ~0.057 with the kept
        # [0, 0, 5, 5] yet is expected to be dropped. Its only IoU above 0.1
        # is ~0.14 with [1, 1, 6, 6], which is itself suppressed. So this
        # expectation is consistent with suppression against *all* earlier
        # boxes, not only kept ones -- confirm that is the intended semantics.
        results, _ = batch_nms(_nms_boxes(), thresh=0.1)
        expected_result = torch.Tensor([[0, 0, 1, 1],
                                        [0, 0, 5, 5]])
        self.assertTensorEqual(results, expected_result)


if __name__ == '__main__':
    unittest.main()