/ pytlib / run / tester.py
tester.py
 1  from __future__ import print_function
 2  from __future__ import absolute_import
 3  # this is needed to make matplotlib work without explicitly connect to X
 4  from builtins import map
 5  from builtins import str
 6  from builtins import range
 7  from builtins import object
 8  import matplotlib 
 9  matplotlib.use('svg')
10  
11  import torch
12  from image.ptimage import PTImage
13  import argparse
14  import imp
15  import os
16  import time
17  import random
18  from datetime import datetime
19  from utils.logger import Logger
20  from utils.random_utils import random_str
21  from utils.batcher import Batcher
22  from visualization.image_visualizer import ImageVisualizer
23  from .run_utils import load,save,load_samples
24  
class Tester(object):
    """Run inference with a configured model, logging timings and dumping visualizations.

    ``model`` is the test-configuration object: this class calls its
    ``get_model()`` and ``get_loader()`` methods and reads its ``cuda`` flag.
    ``args`` must provide ``output_dir``, ``batch_size``, ``iterations`` and
    ``visualize_iter`` (the command-line parser in this file builds such a
    namespace).
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args
        self.iteration = 0

        # Log to <output_dir>/infer_log.json when an output directory is given;
        # otherwise use the Logger's default destination.
        if self.args.output_dir is not None:
            self.logger = Logger(os.path.join(self.args.output_dir, 'infer_log.json'))
        else:
            self.logger = Logger()

    def evaluate_model(self, inputs):
        """Run the model's ``infer`` on ``inputs`` and return the outputs as a list.

        :param inputs: sequence of positional arguments unpacked into ``infer``.
        :returns: a list; a tuple result is converted to a list, any other
            result is wrapped in a single-element list.
        """
        # NOTE(review): no torch.no_grad()/eval() here -- confirm whether
        # ``infer`` disables autograd / switches to eval mode itself.
        output = self.model.get_model().infer(*inputs)
        return list(output) if isinstance(output, tuple) else [output]

    def _visualization_path(self):
        """Return the SVG path used for the current iteration's visualization dump.

        ``__init__`` allows ``args.output_dir`` to be None, and
        ``os.path.join(None, ...)`` raises TypeError. In that case the file
        goes to the current directory instead.
        """
        filename = 'testviz_{0:08d}.svg'.format(self.iteration)
        output_dir = self.args.output_dir if self.args.output_dir is not None else os.curdir
        return os.path.join(output_dir, filename)

    def test(self):
        """Run ``args.iterations`` inference steps.

        Each step does the following:
        - loads one batch;
        - runs a timed forward pass;
        - writes one log line (timings, wall time, date, iteration);
        - lets every sample visualize its outputs.
        Every ``args.visualize_iter`` iterations (when positive) the
        accumulated visualization is dumped to an SVG file.
        """
        # Warm-up forward pass: weights are loaded only after a forward call,
        # for dynamic models.
        batched_data, _, _ = load_samples(self.model.get_loader(), self.model.cuda, self.args.batch_size)
        self.evaluate_model(batched_data)
        self.iteration = load(self.args.output_dir, self.model.get_model(), self.iteration)

        for _ in range(self.args.iterations):
            # ---- load inputs ----
            t0 = time.time()
            batched_data, _targets, sample_array = load_samples(
                self.model.get_loader(), self.model.cuda, self.args.batch_size)
            self.logger.set('timing.input_loading_time', time.time() - t0)

            # ---- forward ----
            t1 = time.time()
            outputs = self.evaluate_model(batched_data)
            # NOTE: the misspelled key 'foward' is kept; existing log readers may rely on it.
            self.logger.set('timing.foward_pass_time', time.time() - t1)

            # ---- logging ----
            print('iteration: {0}'.format(self.iteration))
            self.logger.set('time', time.time())
            self.logger.set('date', str(datetime.now()))
            self.logger.set('iteration', self.iteration)
            self.logger.dump_line()
            self.iteration += 1

            # ---- visualization ----
            Batcher.debatch_outputs(sample_array, outputs)
            for sample in sample_array:
                sample.visualize({'title': random_str(5), 'mode': 'test'})
            if self.args.visualize_iter > 0 and self.iteration % self.args.visualize_iter == 0:
                path = self._visualization_path()
                print('dumping {}'.format(os.path.basename(path)))
                ImageVisualizer().dump_image(path)
75              #############################################################
76  
77  
if __name__ == '__main__':
    # Command-line entry point: parse options, load the test-configuration
    # module, seed the RNGs and run inference.
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--test_config', required=True, type=str,
                        help='the test configuration')
    parser.add_argument('-b', '--batch_size', default=1, required=False, type=int,
                        help='the batch_size')
    parser.add_argument('-i', '--iterations', required=False, type=int,
                        help='the number of iterations', default=1)
    parser.add_argument('-v', '--visualize_iter', required=False, default=1, type=int,
                        help='save visualizations every this many iterations')
    parser.add_argument('-o', '--output_dir', required=False, type=str, default='tmp',
                        help='the directory where the model weights are; the inference log '
                             'and visualizations are also written here')
    parser.add_argument('-e', '--seed', type=int,
                        help='the random seed for torch and python random', default=123)
    args = parser.parse_args()

    print("Loading Model ...")
    # The config file is executed as a Python module and must define
    # ``test_config``. NOTE: ``imp`` was removed in Python 3.12; migrating to
    # importlib.util.spec_from_file_location would be needed there.
    config_file = imp.load_source('test_config', args.test_config)
    args.cuda = config_file.test_config.cuda

    # Seed every RNG used here so inference runs are reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    tester = Tester(config_file.test_config, args)

    print("Starting Inference ...")
    tester.test()