# multiobject_detection_loader.py
1 from __future__ import print_function 2 from builtins import zip 3 from builtins import range 4 from image.frame import Frame 5 from image.box import Box 6 from image.object import Object 7 from data_loading.loaders.loader import Loader 8 from data_loading.sample import Sample 9 from image.affine import Affine 10 from excepts.general_exceptions import NoFramesException 11 from utils.dict_utils import get_deep 12 from image.ptimage import PTImage,Ordering,ValueClass 13 from image.random_perturber import RandomPerturber 14 from image.image_utils import draw_objects_on_np_image 15 from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame 16 import numpy as np 17 import random 18 import torch 19 from interface import implements 20 from visualization.image_visualizer import ImageVisualizer 21 from networks.multi_object_detector import MultiObjectDetector 22 23 class MultiObjectDetectionSample(implements(Sample)): 24 def __init__(self,data,target,class_lookup=dict()): 25 self.data = data 26 self.target = target 27 self.output = None 28 # a dictionary of value to name for decoding the class 29 self.class_lookup = class_lookup 30 31 def __convert_to_objects(self,boxes,classes): 32 boxlist = Box.tensor_to_boxes(boxes.cpu()) 33 objects = [] 34 for x,y in zip(*(boxlist,classes.cpu().numpy())): 35 objects.append(Object(x,0,self.class_lookup[y])) 36 return objects 37 38 def visualize(self,parameters={}): 39 image_original = PTImage.from_cwh_torch(self.data[0]) 40 drawing_image = image_original.to_order_and_class(Ordering.HWC,ValueClass.BYTE0255).get_data().copy() 41 42 boxes,classes = self.output[1:] 43 # Nx4 boxes and N class tensor 44 valid_boxes, valid_classes = MultiObjectDetector.post_process_boxes(self.data[0],boxes,classes,len(self.class_lookup)) 45 # convert targets 46 real_targets = self.target[0][:,0]>-1 47 filtered_targets = self.target[0][real_targets].reshape(-1,self.target[0].shape[1]) 48 target_boxes = filtered_targets[:,1:] 49 
target_classes = filtered_targets[:,0] 50 51 if target_boxes.shape[0]>0: 52 draw_objects_on_np_image(drawing_image,self.__convert_to_objects(target_boxes,target_classes),color=(255,0,0)) 53 if valid_boxes.shape[0]>0: 54 draw_objects_on_np_image(drawing_image,self.__convert_to_objects(valid_boxes,valid_classes),color=None) 55 ImageVisualizer().set_image(PTImage(drawing_image),parameters.get('title','') + ' : Output') 56 57 def set_output(self,output): 58 self.output = output 59 60 def get_data(self): 61 return self.data 62 63 def get_target(self): 64 return self.target 65 66 # loads frames with multiple objects as targets 67 class MultiObjectDetectionLoader(implements(Loader)): 68 def __init__(self,source,crop_size,max_objects=100,obj_types=None): 69 self.source = source 70 self.crop_size = crop_size 71 self.obj_types = set(obj_types) 72 self.max_objects = max_objects 73 self.frame_ids = [] 74 75 #index all the frames that have at least one item we want 76 # TODO turn this into a re-usable filter module 77 for i,frame in enumerate(self.source): 78 crop_objs = [x for x in frame.get_objects() if not self.obj_types or x.obj_type in self.obj_types] 79 if(len(crop_objs)>0): 80 self.frame_ids.append(i) 81 82 print('The source has {0} items'.format(len(self.source))) 83 if len(self.frame_ids)==0: 84 raise NoFramesException('No Valid Frames Found!') 85 86 print('{0} frames found'.format(len(self.frame_ids))) 87 88 def __next__(self): 89 # 1) pick a random frame 90 frame = self.source[random.choice(self.frame_ids)] 91 92 # 2) generate a random perturbation and perturb the frame 93 perturb_params = {'translation_range':[-0.1,0.1], 94 'scaling_range':[0.9,1.1]} 95 perturbed_frame = RandomPerturber.perturb_frame(frame,perturb_params) 96 crop_affine = resize_image_center_crop(perturbed_frame.image,self.crop_size) 97 output_size = [self.crop_size[1],self.crop_size[0]] 98 perturbed_frame = apply_affine_to_frame(perturbed_frame,crop_affine,output_size) 99 # 
perturbed_frame.visualize(title='chw_image',display=True) 100 101 # 3) encode the objects into targets with size that does not exceed max_objects 102 # if there are more objects than max_objects, the remaining ones are dropped. 103 # vector of length max_objects 104 # each comp vector has the form [class(1),bbox(4)] 105 # a padding target is used to represent a non-existent object, this has the form [-1,-1,-1,-1,-1] 106 107 # create the padding vector 108 class_encoding,class_decoding = dict(),dict() 109 padvec = [np.array([-1]*5) for i in range(self.max_objects)] 110 for i,obj in enumerate(perturbed_frame.objects[0:min(self.max_objects,len(perturbed_frame.objects))]): 111 if obj.obj_type not in self.obj_types: 112 continue 113 if obj.obj_type not in class_encoding: 114 code = len(class_encoding) 115 class_encoding[obj.obj_type] = code 116 class_decoding[code] = obj.obj_type 117 box_coords = obj.box.to_single_array() 118 padvec[i] = np.concatenate((np.array([class_encoding[obj.obj_type]]),box_coords),axis=0) 119 chw_image = perturbed_frame.image.to_order_and_class(Ordering.CHW,ValueClass.FLOAT01) 120 sample = MultiObjectDetectionSample([torch.Tensor(chw_image.get_data().astype(float))], 121 [torch.Tensor(padvec)], 122 class_decoding) 123 return sample 124 125 126 127 128