/ pytlib / visualization / image_visualizer.py
image_visualizer.py
 1  from __future__ import division
 2  from __future__ import print_function
 3  from builtins import range
 4  from builtins import object
 5  from past.utils import old_div
 6  import sys
 7  import numpy as np
 8  import math
 9  import matplotlib.pyplot as plt
10  
# Singleton for collecting images during training and rendering them together
# in one grid figure, similar in spirit to the logger.
class ImageVisualizer(object):
    """Shared registry of images that are rendered together in one figure.

    Every ``ImageVisualizer()`` call shares a single hidden implementation
    object stored on the class attribute ``instance``; attribute access on the
    wrapper (``set_image``, ``dump_image``, ``clear``, ``cur_images``) is
    forwarded to that shared object through ``__getattr__``.
    """

    # The shared implementation object, created by the first constructor call.
    instance = None

    class __ImageVisualizer(object):
        """Holds pending images keyed by the title of their subplot."""

        def __init__(self):
            # key (used as subplot title) -> image object. Each image must
            # provide visualize(axes=..., display=...), as dump_image calls it.
            self.cur_images = dict()

        @staticmethod
        def _grid_shape(count, max_cols):
            """Return (rows, cols) for a row-major grid of *count* cells with
            at most *max_cols* columns."""
            cols = min(count, max_cols)
            rows = -(-count // max_cols)  # ceiling division without floats
            return rows, cols

        @staticmethod
        def _set_window_title(fig, title):
            """Set the figure window title when the backend has a manager.

            ``FigureCanvas.set_window_title`` was removed in matplotlib 3.6;
            the manager-level method is the supported replacement.
            """
            manager = getattr(fig.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(title)

        def dump_image(self, output_file, display=False, save=True, max_cols=7):
            """Render all pending images into one grid figure, then clear them.

            Images are placed row-major in order of their sorted keys, and each
            subplot is titled with its key. If nothing is pending, a message is
            printed and no figure is created.

            Args:
                output_file: destination passed to ``Figure.savefig`` when
                    *save* is true.
                display: if true, also show the figure with a blocking
                    ``plt.show``.
                save: if true, write the figure to *output_file*.
                max_cols: maximum number of subplot columns; must be >= 1.

            Raises:
                ValueError: if *max_cols* is less than 1.

            Pending images are cleared only after a successful render; all
            figures are closed even when rendering or saving fails.
            """
            if max_cols < 1:
                raise ValueError("max_cols must be >= 1, got %r" % (max_cols,))
            if not self.cur_images:
                print("Nothing to visualize...")
                return

            rows, cols = self._grid_shape(len(self.cur_images), max_cols)
            # squeeze=False always yields a 2-D array of Axes, so axes[r][c]
            # works for every layout, including a single column.
            fig, axes = plt.subplots(rows, cols, squeeze=False,
                                     figsize=(cols * 3, rows * 3))
            try:
                self._set_window_title(fig, 'Visualizations')
                # Hide every cell first so unused trailing cells stay blank.
                for ax in axes.flat:
                    ax.axis('off')

                for i, (key, image) in enumerate(sorted(self.cur_images.items())):
                    r, c = divmod(i, cols)
                    ax = axes[r][c]
                    ax.set_title(key, fontsize=16)
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    image.visualize(axes=ax, display=False)

                if save:
                    fig.savefig(output_file)
                if display:
                    plt.show(block=True)
            finally:
                # Release figures even on failure so repeated calls during
                # training do not accumulate open figures.
                plt.close("all")

            self.cur_images.clear()

        def set_image(self, pt_image, key):
            """Register *pt_image* under *key*, replacing any image already
            stored under that key."""
            self.cur_images[key] = pt_image

        def clear(self):
            """Drop all pending images without rendering them."""
            self.cur_images.clear()

    def __init__(self):
        # Create the shared implementation only once; later constructions
        # reuse it.
        if ImageVisualizer.instance is None:
            ImageVisualizer.instance = ImageVisualizer.__ImageVisualizer()

    def __getattr__(self, name):
        # Invoked only for names not found on the wrapper itself; forwards
        # them to the shared implementation object.
        return getattr(self.instance, name)