# run_utils.py
1 from __future__ import print_function 2 from builtins import next 3 from builtins import range 4 import torch 5 import os 6 from utils.directory_tools import mkdir, list_files 7 from utils.batcher import Batcher 8 9 def save(model,optimizer,iteration,output_dir): 10 state = {} 11 state['iteration']=iteration+1 12 state['state_dict']=model.state_dict() 13 state['optimizer']=optimizer.state_dict() 14 with open(os.path.join(output_dir,'model_{0:08d}.mdl'.format(iteration)),'wb') as f: 15 torch.save(state,f) 16 17 def load(output_dir,model,iteration,optimizer=None): 18 # list model files and find the latest_model 19 all_models = list_files(output_dir,ext_filter='.mdl') 20 if not all_models: 21 print('No previous checkpoints found!') 22 return iteration 23 24 all_models_indexed = [(m,int(m.split('.mdl')[0].split('_')[-1])) for m in all_models] 25 all_models_indexed.sort(key=lambda x: x[1],reverse=True) 26 print('Loading model from disk: {0}'.format(all_models_indexed[0][0])) 27 checkpoint = torch.load(all_models_indexed[0][0]) 28 model.load_state_dict(checkpoint['state_dict']) 29 if optimizer is not None: 30 optimizer.load_state_dict(checkpoint['optimizer']) 31 return checkpoint['iteration'] 32 33 def load_samples(loader,cuda,batch_size): 34 sample_array = [next(loader) for i in range(0,batch_size)] 35 batched_data, batched_targets = Batcher.batch_samples(sample_array) 36 if cuda: 37 batched_data = [x.cuda() for x in batched_data] 38 batched_targets = [x.cuda() for x in batched_targets] 39 return batched_data,batched_targets,sample_array