# pytlib/loss_functions/mono_depth_loss_test.py
import unittest

import numpy as np
import torch

from loss_functions.mono_depth_loss import process_single_batch
from utils.test_utils import near_tensor_equality

 7  class TestMonoDepthLoss(unittest.TestCase):
 8  
 9      # if disp_map are 0s and ego transform is identity
10      # then we should expect to get the same image back
11      def test_process_single_batch_no_motion(self):
12          image_0 = torch.Tensor([[[0,0],[2,2]],
13                                  [[1,1],[3,3]],
14                                  [[2,2],[4,4]]])
15          ego_motion_vectors = [torch.Tensor([0,0,0,0,0,0])] 
16          disp_map = torch.Tensor([[0,0],[0,0]])
17          calib_0 = torch.Tensor([[50,0,20],
18                                  [0,60,10],
19                                  [0,0,1]])
20          image_input = torch.stack([image_0,image_0],dim=0)
21          disp_input = torch.stack([disp_map,disp_map],dim=0)
22          calib_input = torch.stack([calib_0,calib_0],dim=0)
23          loss,out_images = process_single_batch(image_input,ego_motion_vectors,disp_input,calib_input)
24          self.assertTrue(near_tensor_equality(image_input,out_images[0]))
25          self.assertTrue(near_tensor_equality(loss,torch.Tensor([0])))
26  
27      # when depth is large, then small motions should be irrelevant
28      # and therefore reprojection error is small
29      def test_process_single_batch_motion_large_depth(self):
30          image_0 = torch.Tensor([[[0,0],[2,2]],
31                                  [[1,1],[3,3]],
32                                  [[2,2],[4,4]]])
33          ego_motion_vectors = [torch.Tensor([10,10,10,0,0,0])] 
34          disp_map_0 = torch.Tensor([[1e-8,1e-8],[1e-8,1e-8]])
35          calib_0 = torch.Tensor([[50,0,20],
36                                  [0,60,10],
37                                  [0,0,1]])
38          image_input = torch.stack([image_0,image_0],dim=0)
39          disp_input = torch.stack([disp_map_0,disp_map_0],dim=0)
40          calib_input = torch.stack([calib_0,calib_0],dim=0)
41          loss,out_images = process_single_batch(image_input,ego_motion_vectors,disp_input,calib_input)
42          self.assertTrue(near_tensor_equality(image_0,out_images[0],tol=1e-2))
43          self.assertTrue(near_tensor_equality(loss,torch.Tensor([0])))
44  
if __name__ == "__main__":
    # Run the TestCase classes defined in this module when the file is
    # executed directly; `module="__main__"` is unittest.main's default.
    unittest.main(module="__main__")