# semantic_segmentation_loader.py
1 from __future__ import print_function 2 from builtins import range 3 from image.frame import Frame 4 from image.box import Box 5 from image.object import Object 6 from data_loading.loaders.loader import Loader 7 from data_loading.sample import Sample 8 from image.affine import Affine 9 from excepts.general_exceptions import NoFramesException 10 from utils.dict_utils import get_deep 11 from image.ptimage import PTImage,Ordering,ValueClass 12 from image.random_perturber import RandomPerturber 13 from image.polygon import Polygon 14 from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame 15 import numpy as np 16 import random 17 import torch 18 from interface import implements 19 from visualization.image_visualizer import ImageVisualizer 20 21 class SegmentationSample(implements(Sample)): 22 def __init__(self,data,target,class_lookup=dict()): 23 self.data = data 24 self.target = target 25 self.output = None 26 # a dictionary of value to name for decoding the class 27 self.class_lookup = dict(class_lookup) 28 29 def visualize(self,parameters={}): 30 image_original = PTImage.from_cwh_torch(self.data[0]) 31 ImageVisualizer().set_image(image_original,parameters.get('title','') + ' : Input') 32 # need to draw the mask layers ontop of the data with transparency 33 target_mask_chw = self.target[0] 34 output_mask_chw = self.output[0][-1] 35 # draw a separate image for each channel for now 36 for i in range(target_mask_chw.size(0)): 37 imt = PTImage.from_cwh_torch(target_mask_chw[i,:,:].unsqueeze(0)) 38 imo = PTImage.from_cwh_torch(output_mask_chw[i,:,:].unsqueeze(0)) 39 ImageVisualizer().set_image(imt,parameters.get('title','') + ' : Target-{}'.format(self.class_lookup[i])) 40 ImageVisualizer().set_image(imo,parameters.get('title','') + ' : LOutput-{}'.format(self.class_lookup[i])) 41 42 def set_output(self,output): 43 self.output = output 44 45 def get_data(self): 46 return self.data 47 48 def get_target(self): 49 return self.target 50 51 # This 
# This is for semantic segmentation tasks, not instance segmentation
class SegmentationLoader(implements(Loader)):
    """Randomly samples frames from a source, perturbs and center-crops them,
    and produces SegmentationSamples with one binary mask channel per class.
    """

    def __init__(self, source, crop_size, max_frames=1e8, obj_types=None):
        """
        Args:
            source: indexable/iterable collection of Frame objects.
            crop_size: target (w, h) resolution for the center crop.
            max_frames: cap on how many valid frames to index from the source.
            obj_types: optional set of object type names to keep; None or an
                empty set means accept every type. (Default changed from the
                mutable ``set()`` to ``None``; semantics are identical since
                the original treated an empty set as "accept all".)

        Raises:
            NoFramesException: if the source contains no frame with at least
                one accepted object.
        """
        if obj_types is None:
            obj_types = set()
        self.source = source
        self.crop_size = crop_size
        # Bidirectional mapping between object type name and dense class id.
        self.obj_types_to_ids = dict()
        self.ids_to_obj_types = dict()
        self.frame_ids = []
        self.max_frames = max_frames

        print('SegmentationLoader: finding valid frames')
        # TODO: parallelize this, this is too slow
        for i, frame in enumerate(self.source):
            if len(self.frame_ids) >= self.max_frames:
                break
            valid_obj_count = 0
            for obj in frame.get_objects():
                if obj.obj_type in obj_types or not obj_types:
                    valid_obj_count += 1
                    # Assign the next dense id on first sight of this type.
                    self.obj_types_to_ids[obj.obj_type] = self.obj_types_to_ids.get(obj.obj_type, len(self.obj_types_to_ids))
                    oid = self.obj_types_to_ids[obj.obj_type]
                    self.ids_to_obj_types[oid] = obj.obj_type
            if valid_obj_count > 0:
                self.frame_ids.append(i)
        print('The source has {0} items'.format(len(self.source)))
        if len(self.frame_ids) == 0:
            raise NoFramesException('No Valid Frames Found!')

        print('{0} frames and {1} classes found'.format(len(self.frame_ids), len(self.obj_types_to_ids)))

    def __next__(self):
        """Draw one random perturbed, cropped frame and return it as a
        SegmentationSample (CHW float image, KHW binary mask stack).
        """
        # 1) pick a random frame
        frame = self.source[random.choice(self.frame_ids)]
        # 2) generate a random perturbation and perturb the frame, this also
        #    perturbs the objects including segmentation polygons
        perturbed_frame = RandomPerturber.perturb_frame(frame, {})
        # 3) scale the perturbed frame to the desired input resolution
        crop_affine = resize_image_center_crop(perturbed_frame.image, self.crop_size)
        perturbed_frame = apply_affine_to_frame(perturbed_frame, crop_affine, self.crop_size)
        # visualize the perturbed_frame along with its perturbed objects and masks here
        # perturbed_frame.visualize(display=True)

        # 4) for each object type, produce a merged binary mask over the frame;
        #    this results in an h x w x k target map where k is the number of
        #    classes in consideration.
        # for now we will use the pycocotool's merge and polygon mapping
        # functions since they are implemented in c, although I prefer to not
        # have this dependency
        masks = np.zeros(perturbed_frame.image.get_hw().tolist() + [len(self.obj_types_to_ids)])
        for k, v in self.obj_types_to_ids.items():
            # a) for all objs in the frame that belong to this type, create a merged mask
            polygons = []
            for obj in perturbed_frame.get_objects():
                if obj.obj_type == k:
                    polygons.extend(obj.polygons)
            masks[:, :, v] = Polygon.create_mask(polygons,
                                                 perturbed_frame.image.get_wh()[0],
                                                 perturbed_frame.image.get_wh()[1])

        # 5) create the segmentation sample
        chw_image = perturbed_frame.image.to_order_and_class(Ordering.CHW, ValueClass.FLOAT01)
        # transpose the mask from HWK to KHW to match the CHW image layout
        chw_mask = np.transpose(masks, axes=(2, 0, 1))

        # chw_image.visualize(title='chw_image')
        sample = SegmentationSample([torch.Tensor(chw_image.get_data().astype(float))],
                                    [torch.Tensor(chw_mask)],
                                    self.ids_to_obj_types)

        return sample