// backends/cpu/kernels_conv.cpp
  1  // #include "indexing.cpp" // imported though main -> ops.cpp -> indexing.cpp
  2  #include <stdio.h> // sprintf
  3  
  4  #define IS_DEBUG false
  5  #define IS_DEBUG_MP false
  6  
  7  #define STRIDE 2
  8  
  9  #define STRIDE_CONV 1
 10  #define STRIDE_MAXPOOL 2
 11  
 12  
 13  
 14  /*
 15  x: Input data of shape (C, H, W)
 16  w: Filter weights of shape (F, C, HH, WW)
 17  
 18  todo: implement pytorch's C optional as way to support tunable args
 19    - 'stride': The number of pixels between adjacent receptive fields in the horizontal and vertical directions.
  - 'pad': The number of pixels that will be used to zero-pad the input.
    During padding, 'pad' zeros should be placed symmetrically (i.e. equally on both sides) along the height and width axes of the input.
 22  
 23  Returns a tuple of:
 24  - out: Output data, of shape (F, H', W')
 25  */
 26  tensor* conv_k_(tensor* input, tensor* kernel, tensor* bias, tensor* out) {
 27  
 28      assert_input(out, 3);
 29      assert_input(input, 3);
 30      assert_input(kernel, 4);
 31      assert_input(bias, 2);
 32  
 33      int C = input->shape[0], H = input->shape[1], W = input->shape[2];
 34      int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3];
 35  
 36      int h_out = 1 + (H - HH) / STRIDE_CONV;
 37      int w_out = 1 + (W - WW) / STRIDE_CONV;
 38  
 39      if (IS_DEBUG){
 40          printf("[conv_k_] h_out: %i\n", h_out);
 41          printf("[conv_k_] w_out: %i\n", w_out);
 42      }
 43  
 44      for (int f=0; f<F; f++){
 45          for (int hight=0; hight<h_out; hight++){
 46              for (int width=0; width<w_out; width++){
 47  
 48                  // 0. select current filter
 49  
 50                  // todo: slice
 51                  //   - need 4d slice? not necessarily, can just do pointer arithmetic to skip to the right kernel -- but problem is that it unwraps tensor to float, but kernels expect tensor
 52                  //   - have the unified mechansm for all slices -- replace "f*grad_kernels->stride[0];" with slice 4d
 53  
 54                  if (IS_DEBUG)
 55                      printf("[conv_k_] f*C*HH*WW: %i\n", f*C*HH*WW);
 56  
 57                  // simple pointer arithmetic to skip from "f" kernels
 58                  float* curr_kernel = kernel->data + f*kernel->stride[0];
 59  
 60                  // workaround to put the data into tensor type
 61                  // todo-high:: constructor does NOT take care of setting correct strides
 62                  tensor* curr_filter = TensorNoData(C, HH, WW);
 63                  curr_filter->data = curr_kernel;
 64  
 65                  if (IS_DEBUG){
 66                      set_name(curr_filter, "curr_filter"); print(curr_filter);
 67                      printf("[conv_k_] f*C*HH*WW: %i\n", f*kernel->stride[0]);
 68                  }
 69  
 70                  // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at
 71                  int vert_start = hight * STRIDE_CONV;
 72                  int vert_end = vert_start + HH;
 73                  int horiz_start = width * STRIDE_CONV;
 74                  int horiz_end = horiz_start + WW;
 75  
 76                  // e.g. "0:3, 30:32, 30:32" is 18 digits
 77  
 78                  // todo: use view instead of slice
 79                  //   - mul_k (used below) -- needs to use .at when looping over a->size
 80                  tensor* x_slice = slice(input, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end));
 81  
 82                  if (IS_DEBUG){
 83                      // printf("[conv_k_] buffer x_slice %s\n", buffer);
 84                      printf("[conv_k_] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]);
 85                      set_name(x_slice, "x_slice"); print(x_slice);
 86                  }
 87  
 88                  // 2. element-wise multiply and sum
 89                  tensor* curr_out = mul_k(x_slice, curr_filter);
 90                  if (IS_DEBUG){
 91                      printf("[conv_k_] curr_out->shape: %i, %i, %i\n", curr_out->shape[0], curr_out->shape[1], curr_out->shape[2]);
 92                      set_name(curr_out, "curr_out"); print(curr_out);
 93                  }
 94  
 95                  // todo: add flag to tensors, is_contiguous and make each op check that flag and err otherwise
 96                  curr_out = reduce_sum_k(curr_out);
 97  
 98                  // 3. add bias
 99                  curr_out->data[0] += bias->data[f];
100  
101                  // 4. write to the current location at the output
102                  if (IS_DEBUG)
103                      printf("[conv_k_] curr_out->data[0]: %f\n", curr_out->data[0]);
104                  out->data[index(out, f, hight, width)] = curr_out->data[0];
105              }
106          }
107      }
108      return out;
109  }
110  
111  // input (C, H, W)
112  // kernel (F, C, HH, WW)
113  tensor* conv_k(tensor* input, tensor* kernel, tensor* bias) {
114      int H = input->shape[1], W = input->shape[2];
115      int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3];
116  
117      // todo: de-duplicate same calculations of h_out, w_out for all the kernels in this file
118      int h_out = 1 + (H - HH) / STRIDE_CONV;
119      int w_out = 1 + (W - WW) / STRIDE_CONV;
120  
121      if (IS_DEBUG){
122          printf("[conv_k] h_out: %i\n", h_out);
123          printf("[conv_k] w_out: %i\n", w_out);
124      }
125  
126      tensor* out = EmptyTensor(F, h_out, w_out);
127      return  conv_k_(input, kernel, bias, out);
128  }
129  
// Backward pass of conv_k_ for a single image.
//   upstream: grad w.r.t. the conv output, shape (F, h_out, w_out)
//   out: the forward output tensor; out->inputs = {input, kernel, bias} (set by the forward op)
// Side effects: allocates fresh zero-filled grad tensors and OVERWRITES
// input->grad, kernel->grad, bias->grad (no accumulation into pre-existing grads).
// NOTE(review): per-pixel temporaries (x_slice, curr_upstream, curr_downstream, ...) are not
// freed here — presumably owned by an arena/GC elsewhere; confirm.
void bwd_conv_k(tensor* upstream, tensor* out) {
    assert_input(upstream, out->num_dims);
    tensor* input = out->inputs[0];
    tensor* kernel = out->inputs[1];
    tensor* bias = out->inputs[2];

    if (IS_DEBUG)
        printf("[bwd_conv_k] input->shape: %i, %i, %i\n", input->shape[0], input->shape[1], input->shape[2]);

    int C = input->shape[0], H = input->shape[1], W = input->shape[2];
    int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3];

    // same output-size formula as the forward pass: 1 + (H - HH) / stride
    int h_out = 1 + (H - HH) / STRIDE_CONV;
    int w_out = 1 + (W - WW) / STRIDE_CONV;

    // make sure it's not initialised with garbage
    tensor* grad_kernels = TensorLikeFill(kernel, 0.0);
    tensor* grad_bias = TensorLikeFill(bias, 0.0);
    tensor* grad_x = TensorLikeFill(input, 0.0);

    for (int f=0; f<F; f++){
        for (int hight=0; hight<h_out; hight++){
            for (int width=0; width<w_out; width++){

                // select the chunk of input that this location in the ouput (f, h, w) is looking at
                int vert_start = hight * STRIDE_CONV;
                int vert_end = vert_start + HH;
                int horiz_start = width * STRIDE_CONV;
                int horiz_end = horiz_start + WW;

                // python: x_slice = input[:, vert_start:vert_end, horiz_start:horiz_end]
                tensor* x_slice = slice(input, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); // corresponding slice (was used in forward)
                if (IS_DEBUG){
                    // printf("[bwd_conv_k] buffer x_slice %s\n", buffer);
                    printf("[bwd_conv_k] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]);
                    set_name(x_slice, "x_slice"); print(x_slice);
                }

                // python: curr_filter = index(kernel, f); // (F, C, HH, WW) -> (C, HH, WW)
                // non-owning wrapper: data points into `kernel` at filter f's offset
                tensor* curr_filter = TensorNoData(C, HH, WW); // filter (was used in forward)
                curr_filter->data = kernel->data + f*kernel->stride[0];
                if (IS_DEBUG){
                    printf("[bwd_conv_k] f*C*HH*WW: %i\n", f*C*HH*WW);
                    set_name(curr_filter, "curr_filter"); print(curr_filter);
                }


                // python:
                //    current_dout = dout[i, f, hight, width]                           # current up-stream grad scalar
                //    current_x = x[i, :, vert_start:vert_end, horiz_start:horiz_end]   # corresponding slice (was used in forward)
                //    dw[f] += current_x * current_dout
                // 1. local grad of a kernel applied to an x_slice is x_slice
                // 2. then element-wise multiply that local grad with upstream
                // 3. then because there's multiple locations in x (x patches) we slid the kernel through -- elementwise sum the grad

                // python: curr_upstream = upstream[f,h,w]
                float curr_upstream_float = upstream->data[index(upstream, f, hight, width)]; // scalar
                tensor* curr_upstream = TensorLikeFill(x_slice, curr_upstream_float); // broadcast scalar grad to the shape of the slice
                if (IS_DEBUG){
                    printf("[bwd_conv_k] curr_upstream_float: %f", curr_upstream_float);
                    printf("\n[bwd_conv_k] curr_upstream.shape: %i, %i, %i", curr_upstream->shape[0], curr_upstream->shape[1], curr_upstream->shape[2]);
                    set_name(curr_upstream, "curr_upstream"); print(curr_upstream);
                }

                // kernel->grad = add(kernel->grad, mul_k(x_slice, curr_upstream))
                tensor* curr_downstream = mul_k(x_slice, curr_upstream);

                // record downstream grad of the current slice, into the larger tensor (for the downstream grad)

                // workaround for not having non owning slice 4d
                //    todo-high: constructor does not set correct strides in this case
                tensor* curr_downstream_slice_in_larger_tensor = TensorNoData(C, HH, WW);
                curr_downstream_slice_in_larger_tensor->data = grad_kernels->data + f*grad_kernels->stride[0];

                // in-place accumulate: grad_kernels[f] += curr_downstream
                add_k_(curr_downstream_slice_in_larger_tensor, curr_downstream, curr_downstream_slice_in_larger_tensor);


                // python:
                //    dx[i, :, vert_start:vert_end, horiz_start:horiz_end] += current_w * current_dout
                // 1. local grad of a x_patch applied to an kernel is kernel
                // 2. then element-wise multiply that local grad with upstream
                // 3. then because there's multiple locations in x (x patches) we slid the kernel through -- elementwise sum the grad

                curr_downstream = mul_k(curr_filter, curr_upstream);

                // record downstream grad of the current slice, into the larger tensor (for the downstream grad)
                // view (not slice): writes through to grad_x's storage
                curr_downstream_slice_in_larger_tensor = view(grad_x, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end));

                // in-place accumulate: grad_x[:, vs:ve, hs:he] += curr_downstream
                add_k_(curr_downstream_slice_in_larger_tensor, curr_downstream, curr_downstream_slice_in_larger_tensor);


                // grad wrt bias
                // d(out)/d(bias[f]) = 1 for every output pixel of filter f, so just sum the upstream
                grad_bias->data[f] += 1. * curr_upstream_float;
            }
        }
    }

    kernel->grad = grad_kernels;
    bias->grad = grad_bias;
    input->grad = grad_x;
}
232  
233  
234  
235  // x (B, C, H, W)
236  // w (F, C, HH, WW)
237  tensor* batched_conv_k(tensor* input, tensor* kernel, tensor* bias){
238      assert_input(input, 4);
239      assert_input(kernel, 4);
240      assert_input(bias, 2);
241  
242      int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3];
243      int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3];
244  
245      int h_out = 1 + (H - HH) / STRIDE_CONV;
246      int w_out = 1 + (W - WW) / STRIDE_CONV;
247  
248      if (IS_DEBUG){
249          printf("[batched_conv_k] h_out: %i\n", h_out);
250          printf("[batched_conv_k] w_out: %i\n", w_out);
251      }
252  
253      tensor* out = EmptyTensor(B, F, h_out, w_out);
254  
255      for (int i=0; i<B; i++){
256          // comment: same semantics as in batched_matmul_k
257  
258          tensor* curr_out = TensorNoData(F, h_out, w_out);
259          curr_out->data = out->data + (i * out->stride[0]);
260  
261          tensor* curr_x = TensorNoData(C, H, W);
262          curr_x->data = input->data + (i * input->stride[0]);
263  
264          // comment: add support for 5d tensors? -- NO, w stays 4d
265          conv_k_(curr_x, kernel, bias, curr_out);
266      }
267      return out;
268  }
269  
// x (B, C, H, W)
// w (F, C, HH, WW)
// conv output; upstream: (B, F, h_out, w_out)
// Backward pass of batched_conv_k: calls the single-image bwd_conv_k per batch
// element and accumulates the per-sample grads into the full-batch grad tensors.
// Side effects: OVERWRITES input->grad, kernel->grad, bias->grad at the end.
// NOTE: during the loop, bwd_conv_k repeatedly clobbers kernel->grad / bias->grad
// with fresh per-sample tensors; the accumulated totals are installed only after
// the loop finishes.
void bwd_batched_conv_k(tensor* upstream, tensor* out) {
    assert_input(upstream, out->num_dims);

    tensor* input = out->inputs[0];
    tensor* kernel = out->inputs[1];
    tensor* bias = out->inputs[2];

    int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3];
    int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3];

    int h_out = 1 + (H - HH) / STRIDE_CONV;
    int w_out = 1 + (W - WW) / STRIDE_CONV;
    if (IS_DEBUG){
        printf("[bwd_batched_conv_k] h_out: %i\n", h_out);
        printf("[bwd_batched_conv_k] w_out: %i\n", w_out);
    }

    // make sure it's not initialised with garbage
    tensor* grad_x = TensorLikeFill(input, 0.0);
    tensor* grad_bias = TensorLikeFill(bias, 0.0);
    tensor* grad_kernels = TensorLikeFill(kernel, 0.0);

    for (int i=0; i<B; i++){
        // comment: same semantics as in batched_matmul_k

        // non-owning 3d wrappers over the i-th slabs of input / upstream / out
        tensor* curr_x = TensorNoData(C, H, W);
        curr_x->data = input->data + (i * input->stride[0]);

        tensor* curr_upstream = TensorNoData(F, h_out, w_out);
        curr_upstream->data = upstream->data + (i * upstream->stride[0]);

        tensor* curr_out = TensorNoData(F, h_out, w_out);
        curr_out->data = out->data + (i * out->stride[0]);

        // bwd_conv_k unpacks this
        curr_out->inputs[0] = curr_x;
        curr_out->inputs[1] = kernel;
        curr_out->inputs[2] = bias;

        bwd_conv_k(curr_upstream, curr_out);
        // set by bwd_conv_k
        tensor* curr_grad_x = curr_x->grad;
        tensor* curr_grad_filter = kernel->grad;
        tensor* curr_grad_bias = bias->grad;

        // grad_x: each sample owns its own slab, so write at the batch offset
        for (int ii=0; ii<curr_grad_x->size; ii++){
            int offset_batch = i * grad_x->stride[0];
            grad_x->data[offset_batch + ii] = grad_x->data[offset_batch + ii] + curr_grad_x->data[ii];
        }

        // can iterate over the grad_kernels bc regardless that we're feeding to bwd_conv_k inputs
        // only for the current b, we're feeding entire kernel (bc there's no batch dim in kernels)
        for (int ii=0; ii<grad_kernels->size; ii++){
            grad_kernels->data[ii] = grad_kernels->data[ii] + curr_grad_filter->data[ii];
        }
        // todo-low:
        // add_k_(curr_grad_filter, grad_kernels, grad_kernels);

        // bias grad sums across the batch as well
        for (int ii=0; ii<grad_bias->size; ii++){
            grad_bias->data[ii] = grad_bias->data[ii] + curr_grad_bias->data[ii];
        }
    }

    input->grad = grad_x;
    kernel->grad = grad_kernels;
    bias->grad = grad_bias;
}
340  
341  
342  
343  // x (C, H, W)
344  tensor* maxpool_k_(tensor* input, tensor* out) {
345      assert_input(input, 3);
346      assert_input(out, 3);
347  
348      // todo: up until and including "int horiz_end" line, was copied from conv. Reduce duplication.
349  
350      int C = input->shape[0], H = input->shape[1], W = input->shape[2];
351  
352      // hyperparameters
353      int HH = 2, WW = 2;
354  
355      int h_out = 1 + (H - HH) / STRIDE_MAXPOOL;
356      int w_out = 1 + (W - WW) / STRIDE_MAXPOOL;
357  
358      if (IS_DEBUG_MP){
359          printf("[maxpool_k_] h_out: %i\n", h_out);
360          printf("[maxpool_k_] w_out: %i\n", w_out);
361      }
362  
363      for (int c=0; c<C; c++){
364          for (int hight=0; hight<h_out; hight++){
365              for (int width=0; width<w_out; width++){
366  
367                  if (IS_DEBUG_MP)
368                      printf("[maxpool_k_] c*C*HH*WW: %i\n", c*C*HH*WW);
369                  // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at
370                  int vert_start = hight * STRIDE_MAXPOOL;
371                  int vert_end = vert_start + HH;
372                  int horiz_start = width * STRIDE_MAXPOOL;
373                  int horiz_end = horiz_start + WW;
374  
375                  // select only 1 channel here
376                  // todo:  pass as array of axis?
377                  // axis slice_axis[3] = {axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end)};
378                  tensor* x_slice = view(input, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end));
379                  if (IS_DEBUG_MP){
380                      // printf("[maxpool_k_] buffer x_slice %s\n", buffer);
381                      printf("[maxpool_k_] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]);
382                      set_name(x_slice, "x_slice"); print(x_slice);
383                  }
384  
385                  // select maximum element
386                  // need to recompute this during backward (so that you put local grad 1 into the lications where there was maximum element)
387                  float max = x_slice->data[0];
388                  for (int i=0; i<x_slice->size; i++){
389                      if (x_slice->data[at(x_slice, i)] > max) {
390                          // comment: 
391                          //  crucial to use at here and not simple "x_slice->data[i]" bc x_slice is a view thus it's NON contiguous!
392                          //  alternatively, can create x_slice with "slice" instead of "view" -- the former will make a contiguous copy, in which case fine to index with "x_slice->data[i]
393                          max = x_slice->data[at(x_slice, i)];
394                      }
395                  }
396                  if (IS_DEBUG_MP)
397                      printf("[maxpool_k_] max: %f\n", max);
398                  out->data[index(out, c, hight, width)] = max;
399              }
400          }
401      }
402      return out;
403  }
404  
405  // x (C, H, W)
406  // w (F, C, HH, WW)
407  // todo: was copy/pate from conv_k -- reduce duplication
408  tensor* maxpool_k(tensor* input) {
409      int C = input->shape[0], H = input->shape[1], W = input->shape[2];
410  
411      // hyperparameters
412      int HH = 2, WW = 2;
413  
414      int h_out = 1 + (H - HH) / STRIDE_MAXPOOL;
415      int w_out = 1 + (W - WW) / STRIDE_MAXPOOL;
416  
417      if (IS_DEBUG_MP){
418          printf("[maxpool_k] h_out: %i\n", h_out);
419          printf("[maxpool_k] w_out: %i\n", w_out);
420      }
421  
422      tensor* out = EmptyTensor(C, h_out, w_out);
423      return  maxpool_k_(input, out);
424  }
425  
426  // todo-high:
427  //  - most of below is copy-pasted from the bwd_conv_k
428  //  - it's wasteful to have maxpool logic (65 lines) duplicated exactly in its bwd (which needed for re-computing local grad) -- instead just add another filed on the tensor called local_grad
429  void bwd_maxpool_k(tensor* upstream, tensor* out) {
430      assert_input(upstream, out->num_dims);
431      tensor* input = out->inputs[0];
432  
433      int C = input->shape[0], H = input->shape[1], W = input->shape[2];
434  
435      // hyperparams
436      int HH = 2, WW = 2;
437  
438      int h_out = 1 + (H - HH) / STRIDE_MAXPOOL;
439      int w_out = 1 + (W - WW) / STRIDE_MAXPOOL;
440  
441      tensor* downstream = TensorLikeFill(input, 0.0);
442  
443      for (int c=0; c<C; c++){
444          for (int hight=0; hight<h_out; hight++){
445              for (int width=0; width<w_out; width++){
446  
447                  if (IS_DEBUG_MP)
448                      printf("[bwd_maxpool_k] c*C*HH*WW: %i\n", c*C*HH*WW);
449  
450                  // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at
451                  int vert_start = hight * STRIDE_MAXPOOL;
452                  int vert_end = vert_start + HH;
453                  int horiz_start = width * STRIDE_MAXPOOL;
454                  int horiz_end = horiz_start + WW;
455  
456                  tensor* x_slice = view(input, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end));
457  
458                  // local
459                  tensor* local = TensorLikeFill(x_slice, 0.0);
460                  int idx_max = 0;
461                  float max = x_slice->data[0];
462                  for (int i=0; i<x_slice->size; i++){
463                      if (x_slice->data[at(x_slice, i)] > max) {
464                          max = x_slice->data[at(x_slice, i)];
465                          // bc local is contiguous, record contiguous idx (not the x_slice's idx)
466                          idx_max = i;
467                      }
468                  }
469                  local->data[idx_max] = 1.0;
470                  if (IS_DEBUG_MP)
471                      set_name(local, "local"), print(local);
472  
473                  // upstream
474                  float curr_upstream_float = upstream->data[index(upstream, c, hight, width)]; // scalar
475                  tensor* curr_upstream = TensorLikeFill(x_slice, curr_upstream_float); // broadcast scalar grad to the shape of the slice
476  
477                  // downstream
478                  tensor* curr_downstream = mul_k(local, curr_upstream);
479                  if (IS_DEBUG_MP)
480                      set_name(curr_downstream, "curr_downstream"), print(curr_downstream);
481  
482                  // record downstream grad of the current slice, into the larger tensor (corresponding to the downstream grad)
483                  tensor* downstream_slice = view(downstream, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end));
484                  // todo-low: use _copy_arr instead of the below; modify that fn to use at instead of t->data[i]
485                  for (int i=0; i<downstream_slice->size; i++){
486                      // note you use at on the downstream_slice bc it's a slice and therefore it's not contiguous, on the
487                      // other hand, curr_downstream is contiguous, so simple curr_downstream->data[i] suffices
488                      downstream_slice->data[at(downstream_slice, i)] = curr_downstream->data[i];
489                  }
490  
491              }
492          }
493      }
494      input->grad = downstream;
495  }
496  
497  
498  
499  // x (B, C, H, W)
500  // w (F, C, HH, WW)
501  // todo: copy/paste from batched_conv_k -- reduce duplication
502  tensor* batched_maxpool_k(tensor* input){
503      assert_input(input, 4);
504  
505      int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3];
506  
507      int HH = 2, WW = 2;
508  
509      int h_out = 1 + (H - HH) / STRIDE_MAXPOOL;
510      int w_out = 1 + (W - WW) / STRIDE_MAXPOOL;
511  
512      if (IS_DEBUG_MP){
513          printf("[batched_maxpool_k] h_out: %i\n", h_out);
514          printf("[batched_maxpool_k] w_out: %i\n", w_out);
515      }
516  
517      tensor* out = EmptyTensor(B, C, h_out, w_out);
518  
519      for (int i=0; i<B; i++){
520          tensor* curr_out = TensorNoData(C, h_out, w_out);
521          curr_out->data = out->data + (i * out->stride[0]);
522  
523          tensor* curr_x = TensorNoData(C, H, W);
524          curr_x->data = input->data + (i * input->stride[0]);
525  
526          maxpool_k_(curr_x, curr_out);
527      }
528      return out;
529  }
530  
531  // todo: copy from bwd_batched_conv_k -- reduce duplication
532  void bwd_batched_maxpool_k(tensor* upstream, tensor* out) {
533      assert_input(upstream, out->num_dims);
534  
535      tensor* input = out->inputs[0];
536  
537      int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3];
538      int HH = 2, WW = 2;
539  
540      int h_out = 1 + (H - HH) / STRIDE_MAXPOOL;
541      int w_out = 1 + (W - WW) / STRIDE_MAXPOOL;
542      if (IS_DEBUG_MP){
543          printf("[bwd_batched_maxpool_k] h_out: %i\n", h_out);
544          printf("[bwd_batched_maxpool_k] w_out: %i\n", w_out);
545      }
546  
547      // make sure it's not initialised with garbage
548      tensor* downstream = TensorLikeFill(input, 0.0);
549  
550      for (int i=0; i<B; i++){
551          // comment: same semantics as in batched_matmul_k
552  
553          tensor* curr_x = TensorNoData(C, H, W);
554          curr_x->data = input->data + (i * input->stride[0]);
555  
556          tensor* curr_upstream = TensorNoData(C, h_out, w_out);
557          curr_upstream->data = upstream->data + (i * upstream->stride[0]);
558  
559          tensor* curr_out = TensorNoData(C, h_out, w_out);
560          curr_out->data = out->data + (i * out->stride[0]);
561  
562          // bwd_maxpool_k unpacks this
563          curr_out->inputs[0] = curr_x;
564  
565          bwd_maxpool_k(curr_upstream, curr_out);
566          tensor* curr_downstream = curr_x->grad; // set by bwd_maxpool_k
567  
568          for (int ii=0; ii<curr_downstream->size; ii++){
569              int offset_batch = i*downstream->stride[0];
570              downstream->data[offset_batch + ii] = curr_downstream->data[ii];
571          }
572          // todo-low:
573          // add_k_(curr_downstream, downstream, downstream);
574      }
575      input->grad = downstream;
576  }