image_visualizer.py
1 from __future__ import division 2 from __future__ import print_function 3 from builtins import range 4 from builtins import object 5 from past.utils import old_div 6 import sys 7 import numpy as np 8 import math 9 import matplotlib.pyplot as plt 10 11 # Singleton class for visualization images during training, similiar to the logger 12 class ImageVisualizer(object): 13 instance = None 14 15 class __ImageVisualizer(object): 16 def __init__(self): 17 # raise exception here 18 self.cur_images = dict() 19 20 def dump_image(self,output_file,display=False,save=True,max_cols=7): 21 rows = int(math.ceil(old_div(float(len(self.cur_images)), max_cols))) 22 cols = max_cols if len(self.cur_images) > max_cols else len(self.cur_images) 23 if rows==0 or cols==0: 24 print("Nothing to visualize...") 25 return 26 fig,axes = plt.subplots(rows,cols,figsize=(cols*3,rows*3)) 27 # fig.subplots_adjust(hspace=0.5, wspace=0.5) 28 fig.canvas.set_window_title('Visualizations') 29 # print rows, cols 30 if rows==1 and cols==1: 31 axes = [[axes]] 32 elif rows==1 and cols>1: 33 axes = [axes] 34 35 for i in range(0,rows): 36 for j in range(0,cols): 37 axes[i][j].axis('off') 38 39 for i,(key,image) in enumerate(sorted(self.cur_images.items())): 40 (r,c) = np.unravel_index(i, (rows,cols)) 41 ax = axes[r][c] 42 ax.set_title(key,fontsize=16) 43 ax.set_xticklabels([]) 44 ax.set_yticklabels([]) 45 image.visualize(axes=ax,display=False) 46 # plt.tight_layout() 47 48 if save: 49 fig.savefig(output_file) 50 51 if display: 52 plt.show(block=True) 53 plt.close("all") 54 55 self.cur_images.clear() 56 57 58 def set_image(self,pt_image,key): 59 self.cur_images[key]=pt_image 60 61 def clear(self): 62 self.cur_images.clear() 63 64 def __init__(self): 65 if not ImageVisualizer.instance: 66 ImageVisualizer.instance = ImageVisualizer.__ImageVisualizer() 67 68 def __getattr__(self,name): 69 return getattr(self.instance,name)