# pytlib/data_loading/loaders/sequence_video_loader.py
 1  from builtins import range
 2  from data_loading.loaders.loader import Loader
 3  from data_loading.sample import Sample
 4  from interface import implements
 5  from image.random_perturber import RandomPerturber
 6  from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame
 7  from image.ptimage import PTImage,Ordering,ValueClass
 8  import numpy as np
 9  import random
10  import torch
11  from visualization.image_visualizer import ImageVisualizer
12  from image.ptimage import ValueClass
13  
class SequenceVideoSample(implements(Sample)):
    """Sample holding a short sequence of video frames (data) and the
    per-frame calibration matrices (target), plus an optional model output."""

    def __init__(self, data, target):
        # data: list whose first element is a stacked NxCxHxW frame tensor
        #       (produced by SequenceVideoLoader) -- TODO confirm with producer
        # target: list whose first element is the stacked calibration matrices
        self.data = data
        self.target = target
        # Model output; None until set_output() is called after inference.
        self.output = None

    def visualize(self, parameters=None):
        """Push the input frames and, if present, the predicted depth maps
        (output[2]) to the ImageVisualizer.

        Args:
            parameters: optional dict; only 'title' is read, used as a
                prefix for the visualizer window names.
        """
        # Avoid the shared mutable-default-argument pitfall.
        if parameters is None:
            parameters = {}
        title = parameters.get('title', '')
        # Show every frame of the input sequence.
        for i in range(self.data[0].shape[0]):
            img = PTImage.from_cwh_torch(self.data[0][i])
            ImageVisualizer().set_image(img, title + ' : Image {}'.format(i))
        # Guard: output is None before set_output() has been called; the
        # original code would raise TypeError on self.output[2] here.
        if self.output is not None:
            for i in range(self.output[2].shape[0]):
                dmap = self.output[2][i]
                depth_map = PTImage.from_2d_wh_torch(dmap)
                ImageVisualizer().set_image(depth_map, title + ' : DepthMap {}'.format(i))

    def set_output(self, output):
        # Attach the model's output (indexable; slot 2 holds depth maps).
        self.output = output

    def get_data(self):
        return self.data

    def get_target(self):
        return self.target
38  
class SequenceVideoLoader(implements(Loader)):
    """Loader that draws a random run of consecutive frames from a video
    source, center-crops each one to a fixed size, and packages the frames
    and their calibration matrices as a SequenceVideoSample."""

    def __init__(self, source, crop_size, num_frames=3):
        # source: indexable frame source supporting len() and []
        # crop_size: 2-element crop dimensions; indexed as
        #            [crop_size[1], crop_size[0]] for the affine output size
        #            -- presumably (width, height), TODO confirm against
        #            resize_image_center_crop
        # num_frames: number of consecutive frames per sample
        self.source = source
        self.crop_size = crop_size
        self.num_frames = num_frames

    def __iter__(self):
        # The class defines __next__ but lacked __iter__, so it could not be
        # used directly in a for-loop; returning self completes the iterator
        # protocol without changing any existing behavior.
        return self

    def __next__(self):
        """Build one sample from num_frames consecutive, randomly-positioned
        frames.

        Raises:
            ValueError: if the source holds fewer than num_frames frames.
        """
        num_frames_in_src = len(self.source)
        if num_frames_in_src < self.num_frames:
            # Fail with a clear message instead of random.randint's opaque
            # "empty range" ValueError.
            raise ValueError('source has {} frames; need at least {}'.format(
                num_frames_in_src, self.num_frames))

        # 1) choose the first frame so the whole run fits:
        #    first_frame in [0, N - num_frames]
        first_frame = random.randint(0, num_frames_in_src - self.num_frames)
        frames = [self.source[first_frame + i] for i in range(self.num_frames)]

        # 2) perturb and crop each frame; the same (here degenerate: zero
        #    translation, unit scale) perturbation parameters are used for
        #    every frame so the sequence stays geometrically consistent
        perturb_params = {'translation_range': [0.0, 0.0],
                          'scaling_range': [1.0, 1.0]}
        perturbed_frames = []
        for f in frames:
            perturbed_frame = RandomPerturber.perturb_frame(f, perturb_params)
            crop_affine = resize_image_center_crop(perturbed_frame.image, self.crop_size)
            output_size = [self.crop_size[1], self.crop_size[0]]
            perturbed_frame = apply_affine_to_frame(perturbed_frame, crop_affine, output_size)
            perturbed_frames.append(perturbed_frame)

        # 3) convert each frame to a CHW float tensor and collect the
        #    per-frame calibration matrices as the target
        input_tensors = []
        calib_mats = []
        for f in perturbed_frames:
            img = f.image.to_order_and_class(Ordering.CHW, ValueClass.FLOAT01)
            input_tensors.append(torch.Tensor(img.get_data().astype(float)))
            calib_mats.append(torch.Tensor(f.calib_mat))

        # Stacked input is num_frames x C x H x W (CHW ordering per the
        # conversion above; the old "3xCxWxH" comment was stale).
        return SequenceVideoSample([torch.stack(input_tensors, dim=0)],
                                   [torch.stack(calib_mats)])
83  
84  
85  
86  
87