# trainer.py
1 from __future__ import print_function 2 from __future__ import absolute_import 3 from builtins import str 4 from builtins import map 5 from builtins import range 6 from builtins import object 7 # this is needed to make matplotlib work without explicitly connect to X 8 import matplotlib 9 matplotlib.use('svg') 10 11 import torch 12 from image.ptimage import PTImage 13 import argparse 14 import imp 15 import os 16 import time 17 import random 18 from datetime import datetime 19 from utils.logger import Logger 20 from utils.random_utils import random_str 21 from visualization.graph_visualizer import compute_graph 22 from visualization.image_visualizer import ImageVisualizer 23 from utils.batcher import Batcher 24 from run_utils import load,save,load_samples 25 from utils.directory_tools import mkdir 26 from utils.memory import Memory 27 28 class Trainer(object): 29 def __init__(self,model,args): 30 self.model = model 31 self.args = args 32 self.iteration = 0 33 self.memory = Memory() 34 35 if self.args.override or not os.path.isdir(self.args.output_dir) or self.args.output_dir=='tmp': 36 mkdir(self.args.output_dir,wipe=True) 37 38 # initialize logging and model saving 39 if self.args.output_dir is not None: 40 self.logger = Logger(os.path.join(self.args.output_dir,'train_log.json')) 41 else: 42 self.logger = Logger() 43 44 # a wrapper for model.forward to feed inputs as list and get outputs as a list 45 def evaluate_model(self,inputs): 46 output = self.model.get_model().forward(*inputs) 47 return list(output) if isinstance(output,tuple) else [output] 48 49 def train(self): 50 # load after a forward call for dynamic models 51 batched_data,_,_ = load_samples(self.model.get_loader(),self.model.cuda,self.args.batch_size) 52 self.evaluate_model(batched_data) 53 self.iteration = load(self.args.output_dir,self.model.get_model(),self.iteration,self.model.get_optimizer()) 54 55 for i in range(self.iteration,self.iteration+self.args.iterations): 56 #################### LOAD 
INPUTS ############################ 57 # TODO, make separate timer class if more complex timings arise 58 t0 = time.time() 59 batched_data,batched_targets,sample_array = load_samples(self.model.get_loader(),self.model.cuda,self.args.batch_size) 60 self.logger.set('timing.input_loading_time',time.time() - t0) 61 ############################################################# 62 63 #################### FORWARD ################################ 64 t1 = time.time() 65 outputs = self.evaluate_model(batched_data) 66 self.logger.set('timing.foward_pass_time',time.time() - t1) 67 ############################################################# 68 69 #################### BACKWARD AND SGD ##################### 70 t2 = time.time() 71 loss = self.model.get_lossfn()(*(outputs + batched_targets)) 72 self.model.get_optimizer().zero_grad() 73 loss.backward() 74 self.model.get_optimizer().step() 75 self.logger.set('timing.loss_backward_update_time',time.time() - t2) 76 ############################################################# 77 78 #################### LOGGING, VIZ and SAVE ################### 79 print('iteration: {0} loss: {1}'.format(self.iteration,loss.data.item())) 80 81 if self.args.compute_graph and i==self.iteration: 82 compute_graph(loss,output_file=os.path.join(self.args.output_dir,self.args.compute_graph)) 83 84 if self.iteration%self.args.save_iter==0: 85 save(self.model.get_model(),self.model.get_optimizer(),self.iteration,self.args.output_dir) 86 87 self.logger.set('time',time.time()) 88 self.logger.set('date',str(datetime.now())) 89 self.logger.set('loss',loss.data.item()) 90 self.logger.set('iteration',self.iteration) 91 self.logger.set('resident_memory',str(self.memory.resident(scale='mB'))+'mB') 92 self.logger.dump_line() 93 self.iteration+=1 94 95 if self.args.visualize_iter>0 and self.iteration%self.args.visualize_iter==0: 96 Batcher.debatch_outputs(sample_array,outputs) 97 list([x.visualize({'title':random_str(5)}) for x in sample_array]) 98 
ImageVisualizer().dump_image(os.path.join(self.args.output_dir,'visualizations_{0:08d}.svg'.format(self.iteration))) 99 ############################################################# 100 101 102 if __name__ == '__main__': 103 parser=argparse.ArgumentParser() 104 parser.add_argument('-t','--train_config',required=True,type=str,help='the train configuration') 105 parser.add_argument('-b','--batch_size',default=1, required=False,type=int,help='the batch_size') 106 parser.add_argument('-i','--iterations',required=False, type=int, help='the number of iterations', default=1) 107 parser.add_argument('-v','--visualize_iter',required=False, default=1000,type=int, help='save visualizations every this many iterations') 108 parser.add_argument('-o','--output_dir',required=False,type=str,default='tmp',help='the directory to output the model params and logs') 109 parser.add_argument('-s','--save_iter',type=int,help='save params every this many iterations',default=5000) 110 parser.add_argument('-r','--override',action='store_true',help='if override, the directory will be wiped, otherwise resume from the current dir') 111 parser.add_argument('-e','--seed',type=int,help='the random seed for torch',default=123) 112 parser.add_argument('-g','--compute_graph',default='cgraph',type=str,help='generate the computational graph on the first iteration and write to this file') 113 args=parser.parse_args() 114 115 print("Loading Configuration ...") 116 config_file = imp.load_source('train_config', args.train_config) 117 args.cuda = config_file.train_config.cuda 118 random.seed(args.seed) 119 torch.manual_seed(args.seed) 120 if args.cuda: 121 torch.cuda.manual_seed(args.seed) 122 trainer = Trainer(config_file.train_config,args) 123 124 print("Starting Training ...") 125 trainer.train()