sequence_video_loader.py
from builtins import range
from data_loading.loaders.loader import Loader
from data_loading.sample import Sample
from interface import implements
from image.random_perturber import RandomPerturber
from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame
from image.ptimage import PTImage,Ordering,ValueClass
import numpy as np
import random
import torch
from visualization.image_visualizer import ImageVisualizer
from image.ptimage import ValueClass


class SequenceVideoSample(implements(Sample)):
    """A sample holding a short run of consecutive video frames.

    ``data`` and ``target`` are stored exactly as given. ``SequenceVideoLoader``
    builds them as one-element lists: ``data[0]`` is the stacked frame tensor
    and ``target[0]`` is the stacked calibration-matrix tensor. ``output``
    starts as ``None`` and is filled in later through ``set_output``.
    """

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.output = None

    def visualize(self, parameters=None):
        """Display every input frame and, once an output is set, every depth map.

        Args:
            parameters: optional dict. Only the ``'title'`` key is read; it is
                used as a prefix for each displayed image's title.
        """
        # None sentinel instead of a shared mutable ``{}`` default.
        if parameters is None:
            parameters = {}
        title = parameters.get('title', '')

        # data[0] is indexed along dim 0, one entry per frame.
        frames = self.data[0]
        for i in range(frames.shape[0]):
            img = PTImage.from_cwh_torch(frames[i])
            ImageVisualizer().set_image(img, title + ' : Image {}'.format(i))

        # Depth maps come from the model output (index 2), which only exists
        # after set_output(); previously this crashed with a TypeError on None.
        if self.output is None:
            return
        depth_maps = self.output[2]
        for i in range(depth_maps.shape[0]):
            depth_map = PTImage.from_2d_wh_torch(depth_maps[i])
            ImageVisualizer().set_image(depth_map, title + ' : DepthMap {}'.format(i))

    def set_output(self, output):
        self.output = output

    def get_data(self):
        return self.data

    def get_target(self):
        return self.target


class SequenceVideoLoader(implements(Loader)):
    """Loader that yields random runs of consecutive frames from a frame source.

    Args:
        source: sized, indexable frame container (``len(source)`` and
            ``source[i]`` are used). Each frame must expose ``image`` and
            ``calib_mat``.
        crop_size: size passed to ``resize_image_center_crop``.
            NOTE(review): the output size handed to ``apply_affine_to_frame``
            is ``[crop_size[1], crop_size[0]]`` (entries swapped) — confirm
            which ordering each helper expects.
        num_frames: number of consecutive frames per sample (default 3).
    """

    def __init__(self, source, crop_size, num_frames=3):
        self.source = source
        self.crop_size = crop_size
        self.num_frames = num_frames

    def __next__(self):
        """Build one ``SequenceVideoSample`` from ``num_frames`` consecutive frames.

        Returns:
            SequenceVideoSample whose data is ``[stacked frame images]`` and
            whose target is ``[stacked calibration matrices]``.

        Raises:
            ValueError: if ``num_frames`` is less than 1, or the source holds
                fewer than ``num_frames`` frames.
        """
        # 1) pick a random run of consecutive frames
        frames = self._sample_consecutive_frames()

        # 2) perturb and center-crop each frame. Zero translation and unit
        #    scaling presumably make the perturbation an identity, which is
        #    what keeps all frames geometrically consistent.
        #    NOTE(review): perturb_frame is invoked once per frame, so widening
        #    these ranges would presumably give each frame a different random
        #    perturbation — share one perturbation across the run before doing so.
        perturb_params = {'translation_range': [0.0, 0.0],
                          'scaling_range': [1.0, 1.0]}
        perturbed_frames = [self._perturb_and_crop(f, perturb_params) for f in frames]

        # 3) stack images and calibration matrices into tensors
        return self._build_sample(perturbed_frames)

    def _sample_consecutive_frames(self):
        """Return ``num_frames`` consecutive frames starting at a random index."""
        num_frames_in_src = len(self.source)
        if self.num_frames < 1:
            raise ValueError(
                'num_frames must be at least 1, got {}'.format(self.num_frames))
        if num_frames_in_src < self.num_frames:
            raise ValueError(
                'source has {} frames but {} consecutive frames were requested'
                .format(num_frames_in_src, self.num_frames))
        # randint is inclusive on both ends: the last valid start is
        # N - num_frames, so the last index used is N - 1.
        first_frame = random.randint(0, num_frames_in_src - self.num_frames)
        return [self.source[first_frame + i] for i in range(self.num_frames)]

    def _perturb_and_crop(self, frame, perturb_params):
        """Perturb one frame, then resize/center-crop it to ``crop_size``."""
        perturbed_frame = RandomPerturber.perturb_frame(frame, perturb_params)
        crop_affine = resize_image_center_crop(perturbed_frame.image, self.crop_size)
        output_size = [self.crop_size[1], self.crop_size[0]]
        return apply_affine_to_frame(perturbed_frame, crop_affine, output_size)

    @staticmethod
    def _build_sample(frames):
        """Convert frames to tensors and wrap them in a ``SequenceVideoSample``."""
        input_tensors = []
        calib_mats = []
        for f in frames:
            img = f.image.to_order_and_class(Ordering.CHW, ValueClass.FLOAT01)
            input_tensors.append(torch.Tensor(img.get_data().astype(float)))
            calib_mats.append(torch.Tensor(f.calib_mat))
        # Frames are stacked along a new leading dim ahead of the per-frame
        # image dims (requested as Ordering.CHW) — i.e. num_frames x <image>.
        return SequenceVideoSample([torch.stack(input_tensors, dim=0)],
                                   [torch.stack(calib_mats)])