pytlib/run/run_utils.py
from __future__ import print_function
from builtins import next
from builtins import range
import torch
import os
from utils.directory_tools import mkdir, list_files
from utils.batcher import Batcher

def save(model, optimizer, iteration, output_dir):
    """Write a checkpoint with the model and optimizer state.

    The file is named 'model_{iteration:08d}.mdl' and records iteration + 1
    as the iteration to resume from.
    """
    state = {}
    state['iteration'] = iteration + 1
    state['state_dict'] = model.state_dict()
    state['optimizer'] = optimizer.state_dict()
    with open(os.path.join(output_dir, 'model_{0:08d}.mdl'.format(iteration)), 'wb') as f:
        torch.save(state, f)

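# Checkpoints are written as 'model_<iteration, zero-padded to 8 digits>.mdl';
# for example, iteration 25 produces 'model_00000025.mdl'. load() below parses
# this index back out of the filename to pick the newest checkpoint on disk.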
def load(output_dir, model, iteration, optimizer=None):
    """Load the newest checkpoint in output_dir into model (and optimizer, if given).

    Returns the iteration to resume from, or the passed-in iteration when no
    checkpoint exists.
    """
    # list model files and find the latest model
    all_models = list_files(output_dir, ext_filter='.mdl')
    if not all_models:
        print('No previous checkpoints found!')
        return iteration

    # sort by the iteration index encoded in each filename, newest first
    all_models_indexed = [(m, int(m.split('.mdl')[0].split('_')[-1])) for m in all_models]
    all_models_indexed.sort(key=lambda x: x[1], reverse=True)
    print('Loading model from disk: {0}'.format(all_models_indexed[0][0]))
    checkpoint = torch.load(all_models_indexed[0][0])
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['iteration']

def load_samples(loader, cuda, batch_size):
    """Draw batch_size samples from loader and collate them with Batcher.

    Returns the batched inputs, the batched targets, and the raw sample list;
    tensors are moved to the GPU when cuda is True.
    """
    sample_array = [next(loader) for _ in range(0, batch_size)]
    batched_data, batched_targets = Batcher.batch_samples(sample_array)
    if cuda:
        batched_data = [x.cuda() for x in batched_data]
        batched_targets = [x.cuda() for x in batched_targets]
    return batched_data, batched_targets, sample_array
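
# Minimal usage sketch, not part of the training pipeline itself: it exercises
# the save()/load() round trip with a throwaway model, optimizer, and temporary
# directory. The names demo_model, demo_optimizer, and demo_dir are introduced
# here purely for illustration. load_samples() is not demonstrated because it
# needs a project-specific sample loader that Batcher.batch_samples can collate.
if __name__ == '__main__':
    import tempfile
    import torch.nn as nn
    import torch.optim as optim

    demo_model = nn.Linear(4, 2)
    demo_optimizer = optim.SGD(demo_model.parameters(), lr=0.01)
    demo_dir = tempfile.mkdtemp()

    # Writes model_00000000.mdl, then resumes from the newest checkpoint found.
    save(demo_model, demo_optimizer, 0, demo_dir)
    resumed_iteration = load(demo_dir, demo_model, 0, optimizer=demo_optimizer)
    print('Resuming from iteration {0}'.format(resumed_iteration))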