# batcher.py
from builtins import map
from builtins import zip
from builtins import range
from builtins import object
import torch
import numpy as np
from data_loading.sample import Sample


class Batcher(object):
    """Converts between per-example tensors and batched tensors.

    Assumes cudnn (BCHW) ordering: the batch dimension is dimension 0
    for every stacked / split tensor.
    """

    @staticmethod
    def batch(inputs):
        """Stack each inner list of tensors along a new batch dimension.

        inputs: list of lists of tensors, e.g. [[x1, x2, x3], [y1, y2, y3]].
        Returns a list of batched tensors, one per inner list.
        Raises AssertionError if inputs is not a list of lists.
        """
        assert isinstance(inputs, list)
        if not inputs:
            # nothing to stack; torch.stack would fail on an empty group anyway
            return []
        assert isinstance(inputs[0], list)
        return [torch.stack(group, 0) for group in inputs]

    @staticmethod
    def debatch(outputs):
        """Split batched outputs back into per-example tensors.

        Expected input is either:
          * list of batched tensors (output_type size)
              -> list (output_type size) of lists (batch size) of tensors
          * list of lists of batched tensors (output_type size x sequence size)
              -> output_type size x batch size x sequence size
        """
        assert isinstance(outputs, list)

        def split_batch(batched):
            # .data is the legacy autograd.Variable accessor — presumably the
            # outputs may still be Variables; TODO confirm and drop if tensors.
            # torch.unbind removes the batch dim from every element in one call
            # (equivalent to the old chunk-then-squeeze pattern).
            return list(torch.unbind(batched.data, 0))

        result = []
        for out in outputs:  # loop over individual output types
            if isinstance(out, list):
                # sequence data: sequence size x batch size; transpose to
                # batch size x sequence size
                per_step = [split_batch(step) for step in out]
                result.append([list(example) for example in zip(*per_step)])
            else:
                result.append(split_batch(out))
        return result

    @staticmethod
    def batch_samples(sample_array):
        """Turn an array of samples into batched (inputs, targets).

        Each sample s in sample_array provides s.get_data() -> [d0, ..., dn]
        and s.get_target() -> [t0, ..., tn].  The per-sample lists are
        transposed to per-type lists before stacking.
        Returns a tuple (batched_data_list, batched_target_list).
        """
        data_list = [s.get_data() for s in sample_array]
        target_list = [s.get_target() for s in sample_array]
        return (Batcher.batch([list(group) for group in zip(*data_list)]),
                Batcher.batch([list(group) for group in zip(*target_list)]))

    @staticmethod
    def debatch_outputs(sample_array, batched_outputs):
        """Debatch outputs and store each example's outputs on its sample.

        batched_outputs has the form [out0*batch_size, out1*batch_size, ...];
        after debatching it is transposed from output_type size x batch size
        to batch size x output_type size, then each row is stored via
        sample.set_output().  Raises AssertionError on a size mismatch.
        """
        output_array = Batcher.debatch(batched_outputs)
        # output_type size x batch size -> batch size x output_type size
        output_array = [list(per_sample) for per_sample in zip(*output_array)]
        assert len(output_array) == len(sample_array), \
            'sample array size {} is not the same as output_array {}!'.format(
                len(sample_array), len(output_array))
        for sample, sample_outputs in zip(sample_array, output_array):
            sample.set_output(sample_outputs)