/ pytlib / data_loading / loaders / semantic_segmentation_loader.py
semantic_segmentation_loader.py
  1  from __future__ import print_function
  2  from builtins import range
  3  from image.frame import Frame
  4  from image.box import Box
  5  from image.object import Object
  6  from data_loading.loaders.loader import Loader
  7  from data_loading.sample import Sample
  8  from image.affine import Affine
  9  from excepts.general_exceptions import NoFramesException
 10  from utils.dict_utils import get_deep
 11  from image.ptimage import PTImage,Ordering,ValueClass
 12  from image.random_perturber import RandomPerturber
 13  from image.polygon import Polygon
 14  from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame
 15  import numpy as np
 16  import random
 17  import torch
 18  from interface import implements
 19  from visualization.image_visualizer import ImageVisualizer
 20  
 21  class SegmentationSample(implements(Sample)):
 22      def __init__(self,data,target,class_lookup=dict()):
 23          self.data = data
 24          self.target = target
 25          self.output = None
 26          # a dictionary of value to name for decoding the class
 27          self.class_lookup = dict(class_lookup)
 28  
 29      def visualize(self,parameters={}):
 30          image_original = PTImage.from_cwh_torch(self.data[0])
 31          ImageVisualizer().set_image(image_original,parameters.get('title','') + ' : Input')
 32          # need to draw the mask layers ontop of the data with transparency
 33          target_mask_chw = self.target[0]
 34          output_mask_chw = self.output[0][-1]
 35          # draw a separate image for each channel for now
 36          for i in range(target_mask_chw.size(0)):
 37              imt = PTImage.from_cwh_torch(target_mask_chw[i,:,:].unsqueeze(0))
 38              imo = PTImage.from_cwh_torch(output_mask_chw[i,:,:].unsqueeze(0))
 39              ImageVisualizer().set_image(imt,parameters.get('title','') + ' : Target-{}'.format(self.class_lookup[i]))
 40              ImageVisualizer().set_image(imo,parameters.get('title','') + ' : LOutput-{}'.format(self.class_lookup[i]))             
 41  
 42      def set_output(self,output):
 43          self.output = output
 44  
 45      def get_data(self):
 46          return self.data
 47  
 48      def get_target(self):
 49          return self.target
 50  
 51  # This is for semantic segmentation taskes, not instance segmentation
 52  class SegmentationLoader(implements(Loader)):
 53      def __init__(self,source,crop_size,max_frames=1e8,obj_types=set()):
 54          self.source = source
 55          self.crop_size = crop_size
 56          self.obj_types_to_ids = dict()
 57          self.ids_to_obj_types = dict()
 58          self.frame_ids = []
 59          self.max_frames = max_frames
 60  
 61          print('SegmentationLoader: finding valid frames')
 62          # parallelize this, this is too slow
 63          for i,frame in enumerate(self.source):
 64              if len(self.frame_ids)>=self.max_frames:
 65                  break
 66              valid_obj_count = 0
 67              for obj in frame.get_objects():
 68                  if obj.obj_type in obj_types or not obj_types:
 69                      valid_obj_count+=1
 70                      self.obj_types_to_ids[obj.obj_type]=self.obj_types_to_ids.get(obj.obj_type,len(self.obj_types_to_ids))
 71                      oid = self.obj_types_to_ids[obj.obj_type]
 72                      self.ids_to_obj_types[oid]=obj.obj_type
 73              if valid_obj_count>0:
 74                  self.frame_ids.append(i)
 75          print('The source has {0} items'.format(len(self.source)))
 76          if len(self.frame_ids)==0:
 77              raise NoFramesException('No Valid Frames Found!')
 78  
 79          print('{0} frames and {1} classes found'.format(len(self.frame_ids),len(self.obj_types_to_ids)))
 80  
 81      def __next__(self):
 82          # 1) pick a random frame
 83          frame = self.source[random.choice(self.frame_ids)]
 84          # 2) generate a random perturbation and perturb the frame, this also perturbs the objects including segementation polygons
 85          perturbed_frame = RandomPerturber.perturb_frame(frame,{})
 86          # 3) scale the perturbed frame to the desired input resolution
 87          crop_affine = resize_image_center_crop(perturbed_frame.image,self.crop_size)
 88          perturbed_frame = apply_affine_to_frame(perturbed_frame,crop_affine,self.crop_size)
 89          # visualize the perturbed_frame along with its perturbed objects and masks here
 90          # perturbed_frame.visualize(display=True)
 91  
 92          # 3) for each object type, produce a merged binary mask over the frame, 
 93          # this results in a w x h x k target map where k is the number of classes in consideration 
 94          # for now we will use the pycocotool's merge and polygon mapping functions since they are implemented in c
 95          # although I prefer to not have this dependency
 96          # loop over all object type and create a binary mask for each
 97          # declare a np array of whk
 98          masks = np.zeros(perturbed_frame.image.get_hw().tolist()+[len(self.obj_types_to_ids)])
 99          for k,v in list(self.obj_types_to_ids.items()):
100              # a) for all objs in the frame that belong to this type, create a merged mask
101              polygons = []
102              for obj in perturbed_frame.get_objects():
103                  if obj.obj_type==k:
104                      polygons.extend(obj.polygons)
105              masks[:,:,v] = Polygon.create_mask(polygons,perturbed_frame.image.get_wh()[0],perturbed_frame.image.get_wh()[1])
106  
107          # 4) create the segmentation sample
108          chw_image = perturbed_frame.image.to_order_and_class(Ordering.CHW,ValueClass.FLOAT01)
109          # transpose the mask
110  
111          chw_mask = np.transpose(masks,axes=(2,0,1))
112  
113          # chw_image.visualize(title='chw_image')
114          sample = SegmentationSample([torch.Tensor(chw_image.get_data().astype(float))],
115                                      [torch.Tensor(chw_mask)],
116                                      self.ids_to_obj_types)        
117  
118          return sample
119  
120  
121  
122  
123