/ pytlib / image / ptimage.py
ptimage.py
  1  from __future__ import absolute_import
  2  from builtins import object
  3  import torch
  4  import copy
  5  import os
  6  import numpy as np
  7  from PIL import Image
  8  import matplotlib.pyplot as plt
  9  from .box import Box
# This is a general representation of images
# that acts as a mediator between different types and storage orders.
# Here is the main use case:
# 1) load an image with PIL (PIL image format)
# 2) store the image as an HWC numpy array
# 3) apply perturbations/affine transforms
# 4) scale and transpose to CHW (cuDNN format)
# 5) convert to a PyTorch tensor for NN compute
#
# To take the network output and convert it back:
# 1) unscale, transpose to HWC and convert to numpy
# Note: the cuDNN storage order is BCHW, while PIL uses HWC arrays.
#
# For NumPy 'C'-style (row-major) arrays, the first dimension is the
# slowest-changing dimension (the last is fastest-changing), so contiguous slices of memory run along the last dim.
# Therefore, for C-style numpy ndarrays we should prefer BCHW when accessing single elements from a batch;
# the PyTorch tensor should also have this memory layout.
 27  
 28  class Ordering(object):
 29      CHW = 'CHW'
 30      HWC = 'HWC'
 31  
 32  class ValueClass(object):
 33      FLOAT01 = {'dtype':'float','range':[0,1]}
 34      BYTE0255 = {'dtype':'uint8','range':[0,255]}
 35  
 36      @staticmethod
 37      def custom_value_class(d,r):
 38          return {'dtype':d,'range':r}
 39  
 40  class PTImage(object):
 41      def __init__(self,data=None,pil_image_path='',ordering=Ordering.HWC,vc=ValueClass.BYTE0255,persist=True):
 42          self.image_path = pil_image_path
 43          self.ordering = ordering
 44          self.vc = vc
 45          self.persist = persist
 46          self.__data = data # numpy array
 47  
 48      def copy(self):
 49          copy = PTImage(pil_image_path=self.image_path,ordering=self.ordering,vc=self.vc)
 50          if self.__data is not None:
 51              copy.__data = np.copy(self.__data)
 52          return copy
 53  
 54      def get_data(self):
 55          if self.__data is None:
 56              assert os.path.isfile(self.image_path), "cant open file: %s" % self.image_path
 57              tmp_data = np.asarray(Image.open(self.image_path, 'r'))
 58              if self.persist:
 59                  self.__data = tmp_data
 60              return tmp_data
 61          else:
 62              return self.__data
 63  
 64      def get_pil_image(self):
 65          transform_image = self.to_order_and_class(Ordering.HWC,ValueClass.BYTE0255)
 66          return Image.fromarray(transform_image.get_data())
 67  
 68      @staticmethod
 69      def scale_np_img(image,input_range,output_range,output_type=float):
 70          assert len(input_range)==2 and len(output_range)==2
 71          scale = float(output_range[1] - output_range[0])/(input_range[1] - input_range[0])
 72          offset = output_range[0] - input_range[0]*scale;
 73          return (image*scale+offset).astype(output_type);
 74  
 75      def visualize(self,axes=None,display=False,title='PTImage Visualization'):
 76          # TODO if already in the right order, don't both converting
 77          display_img = self.to_order_and_class(Ordering.HWC,ValueClass.BYTE0255)
 78          fig,cur_ax = None,None
 79          if axes is None:
 80              fig,cur_ax = plt.subplots(1,figsize=(15, 8))
 81              fig.canvas.set_window_title(title)
 82          else:
 83              cur_ax = axes
 84          # cur_ax.imshow(display_img.get_data())
 85          cur_ax.imshow(display_img.get_data().squeeze(),vmin=0,vmax=255)
 86          if display:
 87              plt.show(block=True)
 88              plt.close()
 89          return cur_ax
 90          
 91      # makes a copy
 92      def to_order_and_class(self,new_ordering,new_value_class):
 93          new_data = None
 94  
 95          if self.ordering == new_ordering:
 96              new_data = self.get_data()
 97          elif self.ordering == Ordering.CHW and new_ordering == Ordering.HWC:
 98              new_data = np.transpose(self.get_data(),axes=(1,2,0))
 99          elif self.ordering == Ordering.HWC and new_ordering == Ordering.CHW:
100              new_data = np.transpose(self.get_data(),axes=(2,0,1))
101          else:
102              assert False, 'Dont know how to convert to this ordering'
103  
104          if self.vc != new_value_class:
105              new_data = PTImage.scale_np_img(new_data,self.vc['range'],new_value_class['range'],new_value_class['dtype'])
106  
107          new_img = PTImage(data=new_data,ordering=new_ordering,vc=new_value_class)
108          return new_img
109  
110      def get_dims(self):
111          return np.array(self.get_data().shape)
112  
113      def get_bounding_box(self):
114          return Box.from_single_array(np.array([0,0,self.get_wh()[0],self.get_wh()[1]]))
115  
116      # get height and width, in that order
117      def get_wh(self):
118          shape = self.get_data().shape
119          if self.ordering == Ordering.CHW:
120              return np.array([shape[2],shape[1]])
121          else:
122              return np.array([shape[1],shape[0]])
123  
124      # get height and width, in that order
125      def get_hw(self):
126          shape = self.get_data().shape
127          if self.ordering == Ordering.CHW:
128              return np.array([shape[1],shape[2]])
129          else:
130              return np.array([shape[0],shape[1]])
131  
132      @classmethod
133      def from_numpy_array(cls,np_array,ordering=Ordering.HWC,vc=ValueClass.BYTE0255):
134          return cls(data=np_array,ordering=ordering,vc=vc)
135  
136      @classmethod
137      def from_pil_image(cls,pil_img):
138          return cls(data=np.asarray(pil_img),ordering=Ordering.HWC,vc=ValueClass.BYTE0255)
139  
140      @classmethod
141      def from_cwh_torch(cls,torch_img):
142          return cls(data=torch_img.detach().cpu().numpy(),ordering=Ordering.CHW,vc=ValueClass.FLOAT01)
143  
144      @classmethod
145      def from_2d_numpy(cls,map2d):
146          # assumes img2d has 2 dimensions
147          assert len(map2d.shape)==2, 'img2d must have only 2 dimenions, found {}'.format(map2d.shape)
148          map3d = np.expand_dims(map2d, axis=0)
149          map3d = np.repeat(map3d,3,axis=0)
150          # import ipdb;ipdb.set_trace()
151          return cls(data=map3d,ordering=Ordering.CHW,vc=ValueClass.FLOAT01)
152  
153      @classmethod
154      def from_2d_wh_torch(cls,img2d,log_scale=False):
155          # assumes img2d has 2 dimensions
156          map2d = img2d.detach().cpu().numpy().squeeze()
157          assert len(map2d.shape)==2, 'img2d must have only 2 dimenions, found {}'.format(map2d.shape)
158          map3d = np.expand_dims(map2d, axis=0)
159          map3d = np.repeat(map3d,3,axis=0)
160          if log_scale:
161              mmin,mmax = np.min(map3d),np.max(map3d)
162              map3d = np.log(map3d-mmin+1)
163          vc=ValueClass.custom_value_class('float',[np.min(map3d),np.max(map3d)])
164          return cls(data=map3d,ordering=Ordering.CHW,vc=vc)