# autoencoder_loader.py
1 from __future__ import division 2 from __future__ import print_function 3 from past.utils import old_div 4 from image.frame import Frame 5 from image.box import Box 6 from image.object import Object 7 from data_loading.loaders.loader import Loader 8 from data_loading.sample import Sample 9 from image.affine import Affine 10 from excepts.general_exceptions import NoFramesException 11 from utils.dict_utils import get_deep 12 from image.ptimage import PTImage,Ordering,ValueClass 13 from image.random_perturber import RandomPerturber 14 import numpy as np 15 import random 16 import torch 17 from interface import implements 18 from visualization.image_visualizer import ImageVisualizer 19 20 # This is a sample where the image and the target are both images 21 class AutoEncoderSample(implements(Sample)): 22 def __init__(self,data,target): 23 self.data = data 24 self.target = target 25 self.output = None 26 27 def visualize(self,parameters={}): 28 # here output[0] could either be a single image or a sequence of images 29 30 if isinstance(self.output[0],list): 31 image_target = PTImage.from_cwh_torch(self.target[0]) 32 ImageVisualizer().set_image(image_target,parameters.get('title','') + ' : Target') 33 for i,o in enumerate(self.output[0]): 34 image_output = PTImage.from_cwh_torch(o) 35 ImageVisualizer().set_image(image_output,parameters.get('title','') + ' : Output{:02d}'.format(i)) 36 else: 37 image_target = PTImage.from_cwh_torch(self.target[0]) 38 image_output = PTImage.from_cwh_torch(self.output[0]) 39 ImageVisualizer().set_image(image_target,parameters.get('title','') + ' : Target') 40 ImageVisualizer().set_image(image_output,parameters.get('title','') + ' : Output') 41 42 # specific to the AE sample, the first element of the output has the same shape as the target 43 def set_output(self,output): 44 # assert output[0].size() == self.target[0].size() 45 self.output = output 46 47 def get_data(self): 48 return self.data 49 50 def get_target(self): 51 return self.target 
# This loader provides images that contains a single item of something
# we want to encode, the target tensor includes both the crop itself and the coordinates
# random perturbations of the crop is an optional parameter
class AutoEncoderLoader(implements(Loader)):
    """Loader yielding AutoEncoderSamples built from random object crops.

    Each call to ``__next__`` picks a random frame that contains at least
    one object of an accepted type, crops a randomly perturbed box around
    a random such object, and returns the crop as both data and target.
    """

    def __init__(self, source, crop_size, obj_types=None):
        """
        source: indexable sequence of Frames
        crop_size: (width, height) of the produced crops
        obj_types: optional collection of accepted object types;
                   None or empty accepts every object type

        Raises NoFramesException when no frame contains a matching object.
        """
        self.source = source
        self.crop_size = crop_size
        self.obj_types = obj_types
        self.frame_ids = []

        # index all the frames that have at least one item we want
        # TODO turn this into a re-usable filter module
        for i, frame in enumerate(self.source):
            if self._matching_objects(frame):
                self.frame_ids.append(i)

        print('The source has {0} items'.format(len(self.source)))
        if not self.frame_ids:
            raise NoFramesException('No Valid Frames Found!')

        print('{0} frames found'.format(len(self.frame_ids)))

    def _matching_objects(self, frame):
        # Objects in the frame that pass the obj_types filter.  Shared by
        # __init__ (frame indexing) and __next__ (crop selection) so the
        # two stay consistent.
        return [x for x in frame.get_objects()
                if not self.obj_types or x.obj_type in self.obj_types]

    def __next__(self):
        # just grab the next random frame
        frame = self.source[random.choice(self.frame_ids)]
        # get a random crop object (frame_ids only contains frames with
        # at least one match, so this list is non-empty)
        crop = random.choice(self._matching_objects(frame))

        # 1) Randomly perturb crop box (scale and translation)
        transformed_box = RandomPerturber.perturb_crop_box(crop.box, {})

        # 2) Take crop, todo, change to using center crop to preserve aspect ratio
        # check if the affine is identity within some tolerance, then don't bother applying
        affine = Affine()
        # from __future__ import division is in effect and the numerator is
        # float, so plain / is exactly what old_div computed here.
        scalex = float(self.crop_size[0]) / transformed_box.edges()[0]
        scaley = float(self.crop_size[1]) / transformed_box.edges()[1]
        affine.append(Affine.translation(-transformed_box.xy_min()))
        affine.append(Affine.scaling((scalex, scaley)))

        transformed_image = affine.apply_to_image(frame.image, self.crop_size)
        # transformed_image.visualize(title='transformed image')

        # 3) Randomly perturb cropped image (rotation only)

        chw_image = transformed_image.to_order_and_class(Ordering.CHW, ValueClass.FLOAT01)
        # Compute the float64 array once; torch.Tensor builds a new
        # FloatTensor from it, so data and target do not alias each other.
        pixel_data = chw_image.get_data().astype(float)
        return AutoEncoderSample([torch.Tensor(pixel_data)],
                                 [torch.Tensor(pixel_data)])