# tester.py
"""Inference driver: load a test configuration, restore weights and run it.

The file passed with ``-t`` must define a ``test_config`` object; ``main``
reads its ``cuda`` attribute and ``Tester`` calls its ``get_model()`` and
``get_loader()``. Because of the relative ``.run_utils`` import, this module
has to be run as part of its package (e.g. with ``python -m``).
"""
from __future__ import print_function
from __future__ import absolute_import
from builtins import map
from builtins import str
from builtins import range
from builtins import object
import matplotlib
# this is needed to make matplotlib work without explicitly connect to X
matplotlib.use('svg')

import torch
from image.ptimage import PTImage
import argparse
import os
import sys
import time
import random
from datetime import datetime
from utils.logger import Logger
from utils.random_utils import random_str
from utils.batcher import Batcher
from visualization.image_visualizer import ImageVisualizer
from .run_utils import load, save, load_samples

try:
    import imp  # deprecated since Python 3.4 and removed in 3.12
except ImportError:
    imp = None


class Tester(object):
    """Run a model's ``infer`` over batches from its loader, logging and visualizing.

    ``model`` is the test configuration object: it must provide
    ``get_model()``, ``get_loader()`` and a ``cuda`` flag. ``args`` is the
    parsed command line: it must provide ``output_dir``, ``batch_size``,
    ``iterations`` and ``visualize_iter``.
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args
        self.iteration = 0

        # Initialize logging: write to infer_log.json inside output_dir when
        # one is given, otherwise fall back to Logger's no-path constructor.
        if self.args.output_dir is not None:
            self.logger = Logger(os.path.join(self.args.output_dir, 'infer_log.json'))
        else:
            self.logger = Logger()

    def evaluate_model(self, inputs):
        """Call ``infer(*inputs)`` on the model and return its outputs as a list.

        A tuple result is converted to a list; any other result is wrapped in
        a one-element list, so callers can always iterate over the outputs.
        """
        output = self.model.get_model().infer(*inputs)
        return list(output) if isinstance(output, tuple) else [output]

    def test(self):
        """Restore weights from ``output_dir`` and run ``args.iterations`` inference steps.

        Each step loads a batch, runs the forward pass, writes one log line,
        attaches the outputs to the samples and visualizes them. Every
        ``visualize_iter`` steps an SVG is dumped (see ``_visualize``).
        """
        # Run one forward pass before loading weights: dynamic models may only
        # create their parameters on the first call.
        batched_data, _, _ = load_samples(self.model.get_loader(), self.model.cuda, self.args.batch_size)
        self.evaluate_model(batched_data)
        # The value returned by load() becomes the starting iteration number.
        self.iteration = load(self.args.output_dir, self.model.get_model(), self.iteration)

        for _ in range(self.args.iterations):
            # ---- load inputs ----
            t0 = time.time()
            batched_data, _, sample_array = load_samples(
                self.model.get_loader(), self.model.cuda, self.args.batch_size)
            self.logger.set('timing.input_loading_time', time.time() - t0)

            # ---- forward ----
            t1 = time.time()
            outputs = self.evaluate_model(batched_data)
            # NOTE: the key keeps its original spelling ('foward') so existing
            # log consumers still find it.
            self.logger.set('timing.foward_pass_time', time.time() - t1)

            # ---- logging ----
            print('iteration: {0}'.format(self.iteration))
            self.logger.set('time', time.time())
            self.logger.set('date', str(datetime.now()))
            self.logger.set('iteration', self.iteration)
            self.logger.dump_line()
            self.iteration += 1

            # ---- visualization ----
            Batcher.debatch_outputs(sample_array, outputs)
            self._visualize(sample_array)

    def _visualize(self, sample_array):
        """Visualize every sample; every ``visualize_iter`` steps dump an SVG to ``output_dir``."""
        for sample in sample_array:
            sample.visualize({'title': random_str(5), 'mode': 'test'})

        if self.args.visualize_iter <= 0 or self.iteration % self.args.visualize_iter != 0:
            return
        if self.args.output_dir is None:
            # __init__ accepts output_dir=None, but os.path.join(None, ...)
            # would raise TypeError; there is nowhere to dump, so skip it.
            return
        filename = 'testviz_{0:08d}.svg'.format(self.iteration)
        print('dumping {}'.format(filename))
        ImageVisualizer().dump_image(os.path.join(self.args.output_dir, filename))


def _load_source(name, path):
    """Execute the Python file at ``path`` and return it as a module named ``name``.

    Uses ``imp.load_source`` wherever ``imp`` exists. On Python 3.12+, where
    ``imp`` was removed, an importlib ``SourceFileLoader`` is used instead. That
    loader accepts any file extension, as ``imp.load_source`` did, and the
    module is registered in ``sys.modules`` under ``name`` in the same way.
    """
    if imp is not None:
        return imp.load_source(name, path)
    import importlib.machinery
    import importlib.util
    loader = importlib.machinery.SourceFileLoader(name, path)
    spec = importlib.util.spec_from_file_location(name, path, loader=loader)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)
    return module


def main():
    """Parse the command line, seed the RNGs, build a Tester and run inference."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--test_config', required=True, type=str, help='the test configuration')
    parser.add_argument('-b', '--batch_size', default=1, required=False, type=int, help='the batch_size')
    parser.add_argument('-i', '--iterations', required=False, type=int, help='the number of iterations', default=1)
    parser.add_argument('-v', '--visualize_iter', required=False, default=1, type=int, help='save visualizations every this many iterations')
    parser.add_argument('-o', '--output_dir', required=False, type=str, default='tmp', help='the directory where the model weights are')
    parser.add_argument('-e', '--seed', type=int, help='the random seed for torch', default=123)
    args = parser.parse_args()

    print("Loading Model ...")
    config_file = _load_source('test_config', args.test_config)
    args.cuda = config_file.test_config.cuda
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    tester = Tester(config_file.test_config, args)

    print("Starting Inference ...")
    tester.test()


if __name__ == '__main__':
    main()