# pytlib/utils/batcher.py
 1  from builtins import map
 2  from builtins import zip
 3  from builtins import range
 4  from builtins import object
 5  import torch
 6  import numpy as np
 7  from  data_loading.sample import Sample
 8  
 9  class Batcher(object):
10      # turns an array of inputs into a batched Variables
11      # assumes cudnn (BCHW) ordering
12      # input is a list of list -> [[x1,x2,x3],[y1,y2,y3]]
13      @staticmethod
14      def batch(inputs):
15          assert isinstance(inputs,list) and isinstance(inputs[0],list)
16          return [torch.stack(x,0) for x in inputs]
17  
18      # turns a batched output into an array of outputs
19      # the expected input dimensions to this function is:
20      # array(batched output tensor) output_type size -> array(array(output tensor)) output_type size x batch size 
21      # OR
22      # array(array(batched output tensor)) output_type size x sequence size -> array(array(array(output tensor))) output_type x batch size x sequence size
23      @staticmethod
24      def debatch(outputs):
25          assert isinstance(outputs,list)
26          # if outputs is a list of list, we recurse
27          result = []
28          def debatch_helper(batched_data):
29              return [torch.squeeze(y,0) for y in torch.chunk(batched_data.data,batched_data.size(0),0)]
30  
31          # we want: N results x n_sequence x n_batch
32          for x in outputs: # loop over individual outputs
33              if isinstance(x,list): # this is sequence data
34                  # this is an array(array(output tensor)) -> sequence size x batch size
35                  sequence_batches = [debatch_helper(y) for y in x]
36                  result.append(list(map(list,list(zip(*sequence_batches)))))
37              else:
38                  result.append(debatch_helper(x))
39          return result
40          # return [map(lambda y: torch.squeeze(y,0),torch.chunk(x.data,x.size(0),0)) for x in outputs]
41  
42      # turns an array of samples into a batch of inputs and targets
43      # for each s0 in sample_array -> [s0,s1,...,sn] 
44      # s0 -> [data -> [d0,d1,...,dn], target -> [t0,t1,...,tn]]
45      @staticmethod
46      def batch_samples(sample_array):
47          data_list = list([x.get_data() for x in sample_array])
48          target_list = list([x.get_target() for x in sample_array])
49          return Batcher.batch(list(map(list,list(zip(*data_list))))), Batcher.batch(list(map(list,list(zip(*target_list)))))
50  
51      # store the batched outputs in the corresponding sample array
52      # batched outputs have the form [out0*batch_size,out1*batch_size,out2*batch_size...]
53      @staticmethod
54      def debatch_outputs(sample_array,batched_outputs):
55          output_array = Batcher.debatch(batched_outputs)
56          # output_type size x batchsize -> batch size x outputtype size
57          output_array = list(map(list,list(zip(*output_array))))
58          assert len(output_array)==len(sample_array), 'sample array size {} is the not the same as output_array {}!'.format(len(sample_array),len(output_array))
59          for i in range(0,len(output_array)):
60              sample_array[i].set_output(output_array[i])