"""Unit tests for ``assign_targets`` (multi_object_detector_loss_test.py)."""

import unittest

import torch

from loss_functions.multi_object_detector_loss import assign_targets


class TestMultiObjectDetectorLoss(unittest.TestCase):
    """Tests for the prediction/target matching helper of the detector loss."""

    def test_assign_targets(self):
        """Check that each prediction is paired with the expected target.

        The input is a single batch element holding two predicted boxes and
        two target boxes. Boxes are presumably ``[x1, y1, x2, y2]`` corners
        (TODO confirm against ``assign_targets``):

        * prediction 0 ``[0, 0, 1, 1]`` is identical to target 1;
        * prediction 1 ``[2, 2, 4, 4]`` overlaps target 0 ``[3, 3, 4, 4]``
          (under the corner-format assumption).

        The expected results appear to be laid out as
        ``[batch_indices, box_indices]`` (NOTE(review): confirm with the
        implementation). Under that reading, prediction 0 pairs with
        target 1 and prediction 1 pairs with target 0.
        """
        # Float literals keep the default float dtype that the legacy
        # ``torch.Tensor(...)`` constructor produced; integer literals passed
        # to ``torch.tensor`` would yield int64 instead.
        # ``unsqueeze(0)`` adds a leading dimension of size 1 (a batch of one).
        preds = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                              [2.0, 2.0, 4.0, 4.0]]).unsqueeze(0)
        targets = torch.tensor([[3.0, 3.0, 4.0, 4.0],
                                [0.0, 0.0, 1.0, 1.0]]).unsqueeze(0)

        pred_inds, target_inds = assign_targets(preds, targets)

        expected_pred_inds = [[0, 0], [0, 1]]
        expected_target_inds = [[0, 0], [1, 0]]
        self.assertEqual(pred_inds, expected_pred_inds)
        self.assertEqual(target_inds, expected_target_inds)


if __name__ == '__main__':
    unittest.main()