# pytlib/data_loading/loaders/multiobject_detection_loader.py
  1  from __future__ import print_function
  2  from builtins import zip
  3  from builtins import range
  4  from image.frame import Frame
  5  from image.box import Box
  6  from image.object import Object
  7  from data_loading.loaders.loader import Loader
  8  from data_loading.sample import Sample
  9  from image.affine import Affine
 10  from excepts.general_exceptions import NoFramesException
 11  from utils.dict_utils import get_deep
 12  from image.ptimage import PTImage,Ordering,ValueClass
 13  from image.random_perturber import RandomPerturber
 14  from image.image_utils import draw_objects_on_np_image
 15  from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame
 16  import numpy as np
 17  import random
 18  import torch
 19  from interface import implements
 20  from visualization.image_visualizer import ImageVisualizer
 21  from networks.multi_object_detector import MultiObjectDetector
 22  
 23  class MultiObjectDetectionSample(implements(Sample)):
 24      def __init__(self,data,target,class_lookup=dict()):
 25          self.data = data
 26          self.target = target
 27          self.output = None
 28          # a dictionary of value to name for decoding the class
 29          self.class_lookup = class_lookup
 30  
 31      def __convert_to_objects(self,boxes,classes):
 32          boxlist = Box.tensor_to_boxes(boxes.cpu())
 33          objects = []
 34          for x,y in zip(*(boxlist,classes.cpu().numpy())):
 35              objects.append(Object(x,0,self.class_lookup[y]))
 36          return objects
 37  
 38      def visualize(self,parameters={}):
 39          image_original = PTImage.from_cwh_torch(self.data[0])
 40          drawing_image = image_original.to_order_and_class(Ordering.HWC,ValueClass.BYTE0255).get_data().copy()
 41  
 42          boxes,classes = self.output[1:]
 43          # Nx4 boxes and N class tensor 
 44          valid_boxes, valid_classes = MultiObjectDetector.post_process_boxes(self.data[0],boxes,classes,len(self.class_lookup))
 45          # convert targets
 46          real_targets = self.target[0][:,0]>-1
 47          filtered_targets = self.target[0][real_targets].reshape(-1,self.target[0].shape[1])
 48          target_boxes = filtered_targets[:,1:]
 49          target_classes = filtered_targets[:,0]
 50  
 51          if target_boxes.shape[0]>0:
 52              draw_objects_on_np_image(drawing_image,self.__convert_to_objects(target_boxes,target_classes),color=(255,0,0))
 53          if valid_boxes.shape[0]>0:
 54              draw_objects_on_np_image(drawing_image,self.__convert_to_objects(valid_boxes,valid_classes),color=None)   
 55          ImageVisualizer().set_image(PTImage(drawing_image),parameters.get('title','') + ' : Output')
 56  
 57      def set_output(self,output):
 58          self.output = output
 59  
 60      def get_data(self):
 61          return self.data
 62  
 63      def get_target(self):
 64          return self.target
 65  
 66  # loads frames with multiple objects as targets
 67  class MultiObjectDetectionLoader(implements(Loader)):
 68      def __init__(self,source,crop_size,max_objects=100,obj_types=None):
 69          self.source = source
 70          self.crop_size = crop_size
 71          self.obj_types = set(obj_types)
 72          self.max_objects = max_objects
 73          self.frame_ids = []
 74  
 75          #index all the frames that have at least one item we want
 76          # TODO turn this into a re-usable filter module
 77          for i,frame in enumerate(self.source):
 78              crop_objs = [x for x in frame.get_objects() if not self.obj_types or x.obj_type in self.obj_types]
 79              if(len(crop_objs)>0):
 80                  self.frame_ids.append(i)
 81  
 82          print('The source has {0} items'.format(len(self.source)))
 83          if len(self.frame_ids)==0:
 84              raise NoFramesException('No Valid Frames Found!')
 85  
 86          print('{0} frames found'.format(len(self.frame_ids)))
 87  
 88      def __next__(self):
 89          # 1) pick a random frame
 90          frame = self.source[random.choice(self.frame_ids)]
 91  
 92          # 2) generate a random perturbation and perturb the frame
 93          perturb_params = {'translation_range':[-0.1,0.1],
 94                            'scaling_range':[0.9,1.1]}
 95          perturbed_frame = RandomPerturber.perturb_frame(frame,perturb_params)
 96          crop_affine = resize_image_center_crop(perturbed_frame.image,self.crop_size)
 97          output_size = [self.crop_size[1],self.crop_size[0]]
 98          perturbed_frame = apply_affine_to_frame(perturbed_frame,crop_affine,output_size)
 99          # perturbed_frame.visualize(title='chw_image',display=True)
100  
101          # 3) encode the objects into targets with size that does not exceed max_objects
102          # if there are more objects than max_objects, the remaining ones are dropped.
103          # vector of length max_objects
104          # each comp vector has the form [class(1),bbox(4)]
105          # a padding target is used to represent a non-existent object, this has the form [-1,-1,-1,-1,-1]
106          
107          # create the padding vector
108          class_encoding,class_decoding = dict(),dict()
109          padvec = [np.array([-1]*5) for i in range(self.max_objects)]
110          for i,obj in enumerate(perturbed_frame.objects[0:min(self.max_objects,len(perturbed_frame.objects))]):
111              if obj.obj_type not in self.obj_types:
112                  continue
113              if obj.obj_type not in class_encoding:
114                  code = len(class_encoding)
115                  class_encoding[obj.obj_type] = code
116                  class_decoding[code] = obj.obj_type
117              box_coords = obj.box.to_single_array()
118              padvec[i] = np.concatenate((np.array([class_encoding[obj.obj_type]]),box_coords),axis=0)
119          chw_image = perturbed_frame.image.to_order_and_class(Ordering.CHW,ValueClass.FLOAT01)
120          sample = MultiObjectDetectionSample([torch.Tensor(chw_image.get_data().astype(float))],
121                                              [torch.Tensor(padvec)],
122                                              class_decoding)
123          return sample
124  
125  
126  
127  
128