/ pytlib / loss_functions / multi_object_detector_loss_test.py
multi_object_detector_loss_test.py
 1  import unittest
 2  import torch
 3  from loss_functions.multi_object_detector_loss import assign_targets
 4  
 5  class TestMultiObjectDetectorLoss(unittest.TestCase):
 6  
 7      def test_assign_targets(self):
 8          preds = torch.Tensor([[0,0,1,1],
 9                                [2,2,4,4]]).unsqueeze(0)
10          targets = torch.Tensor([[3,3,4,4],
11                                  [0,0,1,1]]).unsqueeze(0)
12          pred_inds,target_inds = assign_targets(preds,targets)
13          # print pred_inds, target_inds
14          expected_pred_inds = [[0, 0], [0, 1]]
15          expected_target_inds = [[0, 0], [1, 0]]
16          self.assertEqual(pred_inds,expected_pred_inds)
17          self.assertEqual(target_inds,expected_target_inds)
18  
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()