# pytlib/data_loading/loaders/autoencoder_loader.py
  1  from __future__ import division
  2  from __future__ import print_function
  3  from past.utils import old_div
  4  from image.frame import Frame
  5  from image.box import Box
  6  from image.object import Object
  7  from data_loading.loaders.loader import Loader
  8  from data_loading.sample import Sample
  9  from image.affine import Affine
 10  from excepts.general_exceptions import NoFramesException
 11  from utils.dict_utils import get_deep
 12  from image.ptimage import PTImage,Ordering,ValueClass
 13  from image.random_perturber import RandomPerturber
 14  import numpy as np
 15  import random
 16  import torch
 17  from interface import implements
 18  from visualization.image_visualizer import ImageVisualizer
 19  
 20  # This is a sample where the image and the target are both images
 21  class AutoEncoderSample(implements(Sample)):
 22      def __init__(self,data,target):
 23          self.data = data
 24          self.target = target
 25          self.output = None
 26  
 27      def visualize(self,parameters={}):
 28          # here output[0] could either be a single image or a sequence of images
 29  
 30          if isinstance(self.output[0],list):
 31              image_target = PTImage.from_cwh_torch(self.target[0])
 32              ImageVisualizer().set_image(image_target,parameters.get('title','') + ' : Target')
 33              for i,o in enumerate(self.output[0]):
 34                  image_output = PTImage.from_cwh_torch(o)
 35                  ImageVisualizer().set_image(image_output,parameters.get('title','') + ' : Output{:02d}'.format(i))                
 36          else:
 37              image_target = PTImage.from_cwh_torch(self.target[0])
 38              image_output = PTImage.from_cwh_torch(self.output[0])
 39              ImageVisualizer().set_image(image_target,parameters.get('title','') + ' : Target')
 40              ImageVisualizer().set_image(image_output,parameters.get('title','') + ' : Output')
 41  
 42      # specific to the AE sample, the first element of the output has the same shape as the target
 43      def set_output(self,output):
 44          # assert output[0].size() == self.target[0].size()
 45          self.output = output
 46  
 47      def get_data(self):
 48          return self.data
 49  
 50      def get_target(self):
 51          return self.target
 52  
 53  # This loader provides images that contains a single item of something
 54  # we want to encode, the target tensor includes both the crop itself and the coordinates
 55  # random perturbations of the crop is an optional parameter
 56  class AutoEncoderLoader(implements(Loader)):
 57  
 58      def __init__(self,source,crop_size,obj_types=None):
 59          self.source = source
 60          self.crop_size = crop_size
 61          self.obj_types = obj_types
 62          self.frame_ids = []
 63  
 64          #index all the frames that have at least one item we want
 65          # TODO turn this into a re-usable filter module
 66          for i,frame in enumerate(self.source):
 67              crop_objs = [x for x in frame.get_objects() if not self.obj_types or x.obj_type in self.obj_types]
 68              if(len(crop_objs)>0):
 69                  self.frame_ids.append(i)
 70  
 71          print('The source has {0} items'.format(len(self.source)))
 72          if len(self.frame_ids)==0:
 73              raise NoFramesException('No Valid Frames Found!')
 74  
 75          print('{0} frames found'.format(len(self.frame_ids)))
 76  
 77      def __next__(self):
 78          # just grab the next random frame
 79          frame = self.source[random.choice(self.frame_ids)]
 80          # frame.show_image_with_labels()
 81          # get a random crop object
 82          crop_objs = [x for x in frame.get_objects() if not self.obj_types or x.obj_type in self.obj_types]
 83          # print 'Num crop objs in sample: {0}'.format(len(crop_objs))
 84          crop = random.choice(crop_objs)
 85          # print 'crop_box: ' + str(crop.box)
 86  
 87          # frame.show_image_with_labels()
 88  
 89          # 1) Randomly perturb crop box (scale and translation)
 90          transformed_box = RandomPerturber.perturb_crop_box(crop.box,{})
 91  
 92          # 2) Take crop, todo, change to using center crop to preserve aspect ratio
 93          # check if the affine is identity within some toleranc, then don't bother applying
 94          affine = Affine()
 95          scalex = old_div(float(self.crop_size[0]),transformed_box.edges()[0])
 96          scaley = old_div(float(self.crop_size[1]),transformed_box.edges()[1])
 97          affine.append(Affine.translation(-transformed_box.xy_min()))
 98          affine.append(Affine.scaling((scalex,scaley)))
 99  
100          transformed_image = affine.apply_to_image(frame.image,self.crop_size) 
101          # transformed_image.visualize(title='transformed image')
102  
103          # 3) Randomly perturb cropped image (rotation only)
104  
105          chw_image = transformed_image.to_order_and_class(Ordering.CHW,ValueClass.FLOAT01)
106          # chw_image.visualize(title='chw_image')
107          sample = AutoEncoderSample([torch.Tensor(chw_image.get_data().astype(float))],
108                                     [torch.Tensor(chw_image.get_data().astype(float))])
109          return sample