/ pytlib / run / trainer.py
trainer.py
  1  from __future__ import print_function
  2  from __future__ import absolute_import
  3  from builtins import str
  4  from builtins import map
  5  from builtins import range
  6  from builtins import object
# this is needed to make matplotlib work without explicitly connecting to X
  8  import matplotlib 
  9  matplotlib.use('svg')
 10  
 11  import torch
 12  from image.ptimage import PTImage
 13  import argparse
 14  import imp
 15  import os
 16  import time
 17  import random
 18  from datetime import datetime
 19  from utils.logger import Logger
 20  from utils.random_utils import random_str
 21  from visualization.graph_visualizer import compute_graph
 22  from visualization.image_visualizer import ImageVisualizer
 23  from utils.batcher import Batcher
 24  from run_utils import load,save,load_samples
 25  from utils.directory_tools import mkdir
 26  from utils.memory import Memory
 27  
 28  class Trainer(object):
 29      def __init__(self,model,args):
 30          self.model = model
 31          self.args = args
 32          self.iteration = 0
 33          self.memory = Memory()
 34  
 35          if self.args.override or not os.path.isdir(self.args.output_dir) or self.args.output_dir=='tmp':
 36              mkdir(self.args.output_dir,wipe=True)     
 37  
 38          # initialize logging and model saving
 39          if self.args.output_dir is not None:
 40              self.logger = Logger(os.path.join(self.args.output_dir,'train_log.json'))
 41          else:
 42              self.logger = Logger()
 43  
 44      # a wrapper for model.forward to feed inputs as list and get outputs as a list
 45      def evaluate_model(self,inputs):
 46          output = self.model.get_model().forward(*inputs)
 47          return list(output) if isinstance(output,tuple) else [output] 
 48  
 49      def train(self):
 50          # load after a forward call for dynamic models
 51          batched_data,_,_ = load_samples(self.model.get_loader(),self.model.cuda,self.args.batch_size)
 52          self.evaluate_model(batched_data)
 53          self.iteration = load(self.args.output_dir,self.model.get_model(),self.iteration,self.model.get_optimizer())
 54  
 55          for i in range(self.iteration,self.iteration+self.args.iterations):
 56              #################### LOAD INPUTS ############################
 57              # TODO, make separate timer class if more complex timings arise
 58              t0 = time.time()
 59              batched_data,batched_targets,sample_array = load_samples(self.model.get_loader(),self.model.cuda,self.args.batch_size)
 60              self.logger.set('timing.input_loading_time',time.time() - t0)
 61              #############################################################
 62  
 63              #################### FORWARD ################################
 64              t1 = time.time()
 65              outputs = self.evaluate_model(batched_data)
 66              self.logger.set('timing.foward_pass_time',time.time() - t1)
 67              #############################################################
 68  
 69              #################### BACKWARD AND SGD  #####################
 70              t2 = time.time()
 71              loss = self.model.get_lossfn()(*(outputs + batched_targets))
 72              self.model.get_optimizer().zero_grad()
 73              loss.backward()
 74              self.model.get_optimizer().step()
 75              self.logger.set('timing.loss_backward_update_time',time.time() - t2)
 76              #############################################################
 77  
 78              #################### LOGGING, VIZ and SAVE ###################
 79              print('iteration: {0} loss: {1}'.format(self.iteration,loss.data.item()))
 80  
 81              if self.args.compute_graph and i==self.iteration:
 82                  compute_graph(loss,output_file=os.path.join(self.args.output_dir,self.args.compute_graph))
 83  
 84              if self.iteration%self.args.save_iter==0:
 85                  save(self.model.get_model(),self.model.get_optimizer(),self.iteration,self.args.output_dir)
 86  
 87              self.logger.set('time',time.time())
 88              self.logger.set('date',str(datetime.now()))
 89              self.logger.set('loss',loss.data.item())
 90              self.logger.set('iteration',self.iteration)
 91              self.logger.set('resident_memory',str(self.memory.resident(scale='mB'))+'mB')
 92              self.logger.dump_line()
 93              self.iteration+=1
 94  
 95              if self.args.visualize_iter>0 and self.iteration%self.args.visualize_iter==0:
 96                  Batcher.debatch_outputs(sample_array,outputs)
 97                  list([x.visualize({'title':random_str(5)}) for x in sample_array])
 98                  ImageVisualizer().dump_image(os.path.join(self.args.output_dir,'visualizations_{0:08d}.svg'.format(self.iteration)))
 99              #############################################################
100  
101  
if __name__ == '__main__':
    # Script entry point: parse the command line, load the train
    # configuration module, seed the random generators, then train.

    # command line options as (flags, add_argument keywords), in help order
    cli_options = [
        (('-t', '--train_config'),
         dict(required=True, type=str, help='the train configuration')),
        (('-b', '--batch_size'),
         dict(default=1, required=False, type=int, help='the batch_size')),
        (('-i', '--iterations'),
         dict(required=False, type=int, help='the number of iterations', default=1)),
        (('-v', '--visualize_iter'),
         dict(required=False, default=1000, type=int,
              help='save visualizations every this many iterations')),
        (('-o', '--output_dir'),
         dict(required=False, type=str, default='tmp',
              help='the directory to output the model params and logs')),
        (('-s', '--save_iter'),
         dict(type=int, help='save params every this many iterations', default=5000)),
        (('-r', '--override'),
         dict(action='store_true',
              help='if override, the directory will be wiped, otherwise resume from the current dir')),
        (('-e', '--seed'),
         dict(type=int, help='the random seed for torch', default=123)),
        (('-g', '--compute_graph'),
         dict(default='cgraph', type=str,
              help='generate the computational graph on the first iteration and write to this file')),
    ]

    arg_parser = argparse.ArgumentParser()
    for flags, options in cli_options:
        arg_parser.add_argument(*flags, **options)
    args = arg_parser.parse_args()

    print("Loading Configuration ...")
    # the configuration is a python source file exposing a `train_config` object
    config_module = imp.load_source('train_config', args.train_config)
    train_config = config_module.train_config
    args.cuda = train_config.cuda

    # seed python and torch (plus the CUDA generator when the config asks for cuda)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    trainer = Trainer(train_config, args)

    print("Starting Training ...")
    trainer.train()