# ptimage.py
from __future__ import absolute_import
from builtins import object

import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from .box import Box

# This is a general representation of images that acts as a mediator between
# different types and storage orders. The main use case is:
#   1) load an image using PIL into the PIL image format
#   2) store the image as an HWC numpy array
#   3) apply perturbations / affine transforms
#   4) scale and transpose to CHW (cudnn format)
#   5) convert to a pytorch tensor for NN compute
#
# To take the network output and convert back:
#   1) unscale, transpose to HWC and convert to numpy
#
# Note on storage order: cudnn uses BCHW and PIL uses HWC arrays.
#
# For numpy 'C'-style row-major arrays the first dimension is the slowest
# changing dimension (the last is the fastest changing), so contiguous slices
# of memory run across the last dim. For a numpy C-style ndarray we should
# therefore prefer BCHW for accessing single elements from a batch; the
# pytorch tensor should also have this memory layout.


class Ordering(object):
    """String tags naming the axis order of an image array."""
    CHW = 'CHW'
    HWC = 'HWC'


class ValueClass(object):
    """Descriptors pairing a dtype name with the [min, max] range of values.

    Each descriptor is a dict with the keys 'dtype' and 'range'.
    """
    FLOAT01 = {'dtype': 'float', 'range': [0, 1]}
    BYTE0255 = {'dtype': 'uint8', 'range': [0, 255]}

    @staticmethod
    def custom_value_class(d, r):
        """Build a descriptor from dtype `d` and two-element range `r`."""
        return {'dtype': d, 'range': r}


class PTImage(object):
    """An image held as a numpy array, tagged with its ordering and value class.

    The pixel data is either supplied directly or loaded on demand from
    `pil_image_path` the first time `get_data` is called.
    """

    def __init__(self, data=None, pil_image_path='', ordering=Ordering.HWC,
                 vc=ValueClass.BYTE0255, persist=True):
        """
        Args:
            data: numpy array with the pixel data, or None to load lazily.
            pil_image_path: file path read by `get_data` when `data` is None.
            ordering: one of the `Ordering` tags describing `data`.
            vc: a `ValueClass` descriptor describing `data`.
            persist: when True, data loaded from disk is cached on the instance.
        """
        self.image_path = pil_image_path
        self.ordering = ordering
        self.vc = vc
        self.persist = persist
        self.__data = data  # numpy array

    def copy(self):
        """Return a new PTImage with the same settings and a copied array.

        The `persist` flag is carried over so the duplicate caches (or does
        not cache) lazily loaded data exactly like the source image.
        """
        duplicate = PTImage(pil_image_path=self.image_path,
                            ordering=self.ordering,
                            vc=self.vc,
                            persist=self.persist)
        if self.__data is not None:
            duplicate.__data = np.copy(self.__data)
        return duplicate

    def get_data(self):
        """Return the pixel array, loading it from `image_path` if needed.

        Raises:
            AssertionError: if no data is held and `image_path` is not a file.
        """
        if self.__data is not None:
            return self.__data
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if not os.path.isfile(self.image_path):
            raise AssertionError("cant open file: %s" % self.image_path)
        # The context manager releases the file handle once the pixels are read.
        with Image.open(self.image_path, 'r') as pil_img:
            loaded = np.asarray(pil_img)
        if self.persist:
            self.__data = loaded
        return loaded

    def get_pil_image(self):
        """Return the image as a PIL image (converted to HWC / BYTE0255)."""
        transform_image = self.to_order_and_class(Ordering.HWC, ValueClass.BYTE0255)
        return Image.fromarray(transform_image.get_data())

    @staticmethod
    def scale_np_img(image, input_range, output_range, output_type=float):
        """Linearly map `image` from `input_range` to `output_range`.

        The result is cast with `astype(output_type)`; no rounding or clipping
        is applied before the cast.

        Args:
            image: numpy array of values.
            input_range: two-element [min, max] of the incoming values.
            output_range: two-element [min, max] of the outgoing values.
            output_type: dtype (or dtype name) passed to `astype`.
        """
        assert len(input_range) == 2 and len(output_range) == 2
        scale = float(output_range[1] - output_range[0]) / (input_range[1] - input_range[0])
        offset = output_range[0] - input_range[0] * scale
        return (image * scale + offset).astype(output_type)

    def visualize(self, axes=None, display=False, title='PTImage Visualization'):
        """Draw the image with matplotlib and return the axes used.

        Args:
            axes: existing axes to draw on; a new figure is created when None.
            display: when True, block on `plt.show` and then close the figure.
            title: window title, applied only to a newly created figure.
        """
        # TODO if already in the right order, don't bother converting
        display_img = self.to_order_and_class(Ordering.HWC, ValueClass.BYTE0255)
        if axes is None:
            fig, cur_ax = plt.subplots(1, figsize=(15, 8))
            # The canvas-level set_window_title was removed from Matplotlib;
            # the figure manager owns the window title. The manager can be
            # absent on some backends, so guard the lookup.
            manager = getattr(fig.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(title)
        else:
            cur_ax = axes
        cur_ax.imshow(display_img.get_data().squeeze(), vmin=0, vmax=255)
        if display:
            plt.show(block=True)
            plt.close()
        return cur_ax

    def to_order_and_class(self, new_ordering, new_value_class):
        """Return a new PTImage in the requested ordering and value class.

        Note: the returned image's array is not necessarily an independent
        copy. When only the ordering changes it is a transposed view, and
        when nothing changes it is the very same array as this image's data.

        Raises:
            AssertionError: if the ordering conversion is not supported.
        """
        if self.ordering == new_ordering:
            new_data = self.get_data()
        elif self.ordering == Ordering.CHW and new_ordering == Ordering.HWC:
            new_data = np.transpose(self.get_data(), axes=(1, 2, 0))
        elif self.ordering == Ordering.HWC and new_ordering == Ordering.CHW:
            new_data = np.transpose(self.get_data(), axes=(2, 0, 1))
        else:
            # Raised explicitly (not via `assert False`) so that `python -O`
            # cannot let an unsupported ordering fall through with no data.
            raise AssertionError('Dont know how to convert to this ordering')

        if self.vc != new_value_class:
            new_data = PTImage.scale_np_img(new_data,
                                            self.vc['range'],
                                            new_value_class['range'],
                                            new_value_class['dtype'])

        return PTImage(data=new_data, ordering=new_ordering, vc=new_value_class)

    def get_dims(self):
        """Return the shape of the data array as a numpy array."""
        return np.array(self.get_data().shape)

    def get_bounding_box(self):
        """Return a Box built from [0, 0, width, height] of this image."""
        # Query the size once: with persist=False every get_data call re-reads
        # the file from disk.
        wh = self.get_wh()
        return Box.from_single_array(np.array([0, 0, wh[0], wh[1]]))

    # get width and height, in that order
    def get_wh(self):
        """Return np.array([width, height]) according to the ordering."""
        shape = self.get_data().shape
        if self.ordering == Ordering.CHW:
            return np.array([shape[2], shape[1]])
        else:
            return np.array([shape[1], shape[0]])

    # get height and width, in that order
    def get_hw(self):
        """Return np.array([height, width]) according to the ordering."""
        shape = self.get_data().shape
        if self.ordering == Ordering.CHW:
            return np.array([shape[1], shape[2]])
        else:
            return np.array([shape[0], shape[1]])

    @staticmethod
    def _stack_to_3_channels(map2d):
        """Replicate a 2D array along a new leading axis of length 3."""
        if len(map2d.shape) != 2:
            raise AssertionError(
                'img2d must have only 2 dimenions, found {}'.format(map2d.shape))
        return np.repeat(np.expand_dims(map2d, axis=0), 3, axis=0)

    @classmethod
    def from_numpy_array(cls, np_array, ordering=Ordering.HWC, vc=ValueClass.BYTE0255):
        """Wrap an existing numpy array with the given ordering and value class."""
        return cls(data=np_array, ordering=ordering, vc=vc)

    @classmethod
    def from_pil_image(cls, pil_img):
        """Build an HWC / BYTE0255 image from a PIL image."""
        return cls(data=np.asarray(pil_img), ordering=Ordering.HWC, vc=ValueClass.BYTE0255)

    @classmethod
    def from_cwh_torch(cls, torch_img):
        """Build a FLOAT01 image from a torch tensor.

        The tensor is detached, moved to the CPU and tagged as Ordering.CHW.
        NOTE(review): the method name says "cwh" but the data is tagged CHW;
        confirm which layout callers actually pass.
        """
        return cls(data=torch_img.detach().cpu().numpy(),
                   ordering=Ordering.CHW, vc=ValueClass.FLOAT01)

    @classmethod
    def from_2d_numpy(cls, map2d):
        """Build a 3-channel CHW / FLOAT01 image from a 2D numpy array.

        Raises:
            AssertionError: if `map2d` does not have exactly 2 dimensions.
        """
        map3d = cls._stack_to_3_channels(map2d)
        return cls(data=map3d, ordering=Ordering.CHW, vc=ValueClass.FLOAT01)

    @classmethod
    def from_2d_wh_torch(cls, img2d, log_scale=False):
        """Build a 3-channel CHW image from a 2D torch tensor.

        The value class is a custom 'float' class whose range is the min/max
        of the (optionally log-scaled) data.

        Args:
            img2d: torch tensor that squeezes down to 2 dimensions.
            log_scale: when True, apply log(x - min + 1) to the values.

        Raises:
            AssertionError: if the squeezed tensor is not 2-dimensional.
        """
        map2d = img2d.detach().cpu().numpy().squeeze()
        map3d = cls._stack_to_3_channels(map2d)
        if log_scale:
            # Shift so the smallest value maps to log(1) == 0.
            mmin = np.min(map3d)
            map3d = np.log(map3d - mmin + 1)
        vc = ValueClass.custom_value_class('float', [np.min(map3d), np.max(map3d)])
        return cls(data=map3d, ordering=Ordering.CHW, vc=vc)