// kernels_conv.cpp
1 // #include "indexing.cpp" // imported though main -> ops.cpp -> indexing.cpp 2 #include <stdio.h> // sprintf 3 4 #define IS_DEBUG false 5 #define IS_DEBUG_MP false 6 7 #define STRIDE 2 8 9 #define STRIDE_CONV 1 10 #define STRIDE_MAXPOOL 2 11 12 13 14 /* 15 x: Input data of shape (C, H, W) 16 w: Filter weights of shape (F, C, HH, WW) 17 18 todo: implement pytorch's C optional as way to support tunable args 19 - 'stride': The number of pixels between adjacent receptive fields in the horizontal and vertical directions. 20 - 'pad': The number of pixels that will be used to zero-pad the input. 21 During padding, 'pad' zeros should be placed symmetrically (i.e equally on both sides) along the height and width axes of the input. 22 23 Returns a tuple of: 24 - out: Output data, of shape (F, H', W') 25 */ 26 tensor* conv_k_(tensor* input, tensor* kernel, tensor* bias, tensor* out) { 27 28 assert_input(out, 3); 29 assert_input(input, 3); 30 assert_input(kernel, 4); 31 assert_input(bias, 2); 32 33 int C = input->shape[0], H = input->shape[1], W = input->shape[2]; 34 int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3]; 35 36 int h_out = 1 + (H - HH) / STRIDE_CONV; 37 int w_out = 1 + (W - WW) / STRIDE_CONV; 38 39 if (IS_DEBUG){ 40 printf("[conv_k_] h_out: %i\n", h_out); 41 printf("[conv_k_] w_out: %i\n", w_out); 42 } 43 44 for (int f=0; f<F; f++){ 45 for (int hight=0; hight<h_out; hight++){ 46 for (int width=0; width<w_out; width++){ 47 48 // 0. select current filter 49 50 // todo: slice 51 // - need 4d slice? 
not necessarily, can just do pointer arithmetic to skip to the right kernel -- but problem is that it unwraps tensor to float, but kernels expect tensor 52 // - have the unified mechansm for all slices -- replace "f*grad_kernels->stride[0];" with slice 4d 53 54 if (IS_DEBUG) 55 printf("[conv_k_] f*C*HH*WW: %i\n", f*C*HH*WW); 56 57 // simple pointer arithmetic to skip from "f" kernels 58 float* curr_kernel = kernel->data + f*kernel->stride[0]; 59 60 // workaround to put the data into tensor type 61 // todo-high:: constructor does NOT take care of setting correct strides 62 tensor* curr_filter = TensorNoData(C, HH, WW); 63 curr_filter->data = curr_kernel; 64 65 if (IS_DEBUG){ 66 set_name(curr_filter, "curr_filter"); print(curr_filter); 67 printf("[conv_k_] f*C*HH*WW: %i\n", f*kernel->stride[0]); 68 } 69 70 // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at 71 int vert_start = hight * STRIDE_CONV; 72 int vert_end = vert_start + HH; 73 int horiz_start = width * STRIDE_CONV; 74 int horiz_end = horiz_start + WW; 75 76 // e.g. "0:3, 30:32, 30:32" is 18 digits 77 78 // todo: use view instead of slice 79 // - mul_k (used below) -- needs to use .at when looping over a->size 80 tensor* x_slice = slice(input, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); 81 82 if (IS_DEBUG){ 83 // printf("[conv_k_] buffer x_slice %s\n", buffer); 84 printf("[conv_k_] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]); 85 set_name(x_slice, "x_slice"); print(x_slice); 86 } 87 88 // 2. element-wise multiply and sum 89 tensor* curr_out = mul_k(x_slice, curr_filter); 90 if (IS_DEBUG){ 91 printf("[conv_k_] curr_out->shape: %i, %i, %i\n", curr_out->shape[0], curr_out->shape[1], curr_out->shape[2]); 92 set_name(curr_out, "curr_out"); print(curr_out); 93 } 94 95 // todo: add flag to tensors, is_contiguous and make each op check that flag and err otherwise 96 curr_out = reduce_sum_k(curr_out); 97 98 // 3. 
add bias 99 curr_out->data[0] += bias->data[f]; 100 101 // 4. write to the current location at the output 102 if (IS_DEBUG) 103 printf("[conv_k_] curr_out->data[0]: %f\n", curr_out->data[0]); 104 out->data[index(out, f, hight, width)] = curr_out->data[0]; 105 } 106 } 107 } 108 return out; 109 } 110 111 // input (C, H, W) 112 // kernel (F, C, HH, WW) 113 tensor* conv_k(tensor* input, tensor* kernel, tensor* bias) { 114 int H = input->shape[1], W = input->shape[2]; 115 int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3]; 116 117 // todo: de-duplicate same calculations of h_out, w_out for all the kernels in this file 118 int h_out = 1 + (H - HH) / STRIDE_CONV; 119 int w_out = 1 + (W - WW) / STRIDE_CONV; 120 121 if (IS_DEBUG){ 122 printf("[conv_k] h_out: %i\n", h_out); 123 printf("[conv_k] w_out: %i\n", w_out); 124 } 125 126 tensor* out = EmptyTensor(F, h_out, w_out); 127 return conv_k_(input, kernel, bias, out); 128 } 129 130 // conv output, upstream: (F, h_out, w_out) 131 void bwd_conv_k(tensor* upstream, tensor* out) { 132 assert_input(upstream, out->num_dims); 133 tensor* input = out->inputs[0]; 134 tensor* kernel = out->inputs[1]; 135 tensor* bias = out->inputs[2]; 136 137 if (IS_DEBUG) 138 printf("[bwd_conv_k] input->shape: %i, %i, %i\n", input->shape[0], input->shape[1], input->shape[2]); 139 140 int C = input->shape[0], H = input->shape[1], W = input->shape[2]; 141 int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3]; 142 143 int h_out = 1 + (H - HH) / STRIDE_CONV; 144 int w_out = 1 + (W - WW) / STRIDE_CONV; 145 146 // make sure it's not initialised with garbage 147 tensor* grad_kernels = TensorLikeFill(kernel, 0.0); 148 tensor* grad_bias = TensorLikeFill(bias, 0.0); 149 tensor* grad_x = TensorLikeFill(input, 0.0); 150 151 for (int f=0; f<F; f++){ 152 for (int hight=0; hight<h_out; hight++){ 153 for (int width=0; width<w_out; width++){ 154 155 // select the chunk of input that this location in the ouput (f, h, w) is looking 
at 156 int vert_start = hight * STRIDE_CONV; 157 int vert_end = vert_start + HH; 158 int horiz_start = width * STRIDE_CONV; 159 int horiz_end = horiz_start + WW; 160 161 // python: x_slice = input[:, vert_start:vert_end, horiz_start:horiz_end] 162 tensor* x_slice = slice(input, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); // corresponding slice (was used in forward) 163 if (IS_DEBUG){ 164 // printf("[bwd_conv_k] buffer x_slice %s\n", buffer); 165 printf("[bwd_conv_k] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]); 166 set_name(x_slice, "x_slice"); print(x_slice); 167 } 168 169 // python: curr_filter = index(kernel, f); // (F, C, HH, WW) -> (C, HH, WW) 170 tensor* curr_filter = TensorNoData(C, HH, WW); // filter (was used in forward) 171 curr_filter->data = kernel->data + f*kernel->stride[0]; 172 if (IS_DEBUG){ 173 printf("[bwd_conv_k] f*C*HH*WW: %i\n", f*C*HH*WW); 174 set_name(curr_filter, "curr_filter"); print(curr_filter); 175 } 176 177 178 // python: 179 // current_dout = dout[i, f, hight, width] # current up-stream grad scalar 180 // current_x = x[i, :, vert_start:vert_end, horiz_start:horiz_end] # corresponding slice (was used in forward) 181 // dw[f] += current_x * current_dout 182 // 1. local grad of a kernel applied to an x_slice is x_slice 183 // 2. then element-wise multiply that local grad with upstream 184 // 3. 
then because there's multiple locations in x (x patches) we slid the kernel through -- elementwise sum the grad 185 186 // python: curr_upstream = upstream[f,h,w] 187 float curr_upstream_float = upstream->data[index(upstream, f, hight, width)]; // scalar 188 tensor* curr_upstream = TensorLikeFill(x_slice, curr_upstream_float); // broadcast scalar grad to the shape of the slice 189 if (IS_DEBUG){ 190 printf("[bwd_conv_k] curr_upstream_float: %f", curr_upstream_float); 191 printf("\n[bwd_conv_k] curr_upstream.shape: %i, %i, %i", curr_upstream->shape[0], curr_upstream->shape[1], curr_upstream->shape[2]); 192 set_name(curr_upstream, "curr_upstream"); print(curr_upstream); 193 } 194 195 // kernel->grad = add(kernel->grad, mul_k(x_slice, curr_upstream)) 196 tensor* curr_downstream = mul_k(x_slice, curr_upstream); 197 198 // record downstream grad of the current slice, into the larger tensor (for the downstream grad) 199 200 // workaround for not having non owning slice 4d 201 // todo-high: constructor does not set correct strides in this case 202 tensor* curr_downstream_slice_in_larger_tensor = TensorNoData(C, HH, WW); 203 curr_downstream_slice_in_larger_tensor->data = grad_kernels->data + f*grad_kernels->stride[0]; 204 205 add_k_(curr_downstream_slice_in_larger_tensor, curr_downstream, curr_downstream_slice_in_larger_tensor); 206 207 208 // python: 209 // dx[i, :, vert_start:vert_end, horiz_start:horiz_end] += current_w * current_dout 210 // 1. local grad of a x_patch applied to an kernel is kernel 211 // 2. then element-wise multiply that local grad with upstream 212 // 3. 
then because there's multiple locations in x (x patches) we slid the kernel through -- elementwise sum the grad 213 214 curr_downstream = mul_k(curr_filter, curr_upstream); 215 216 // record downstream grad of the current slice, into the larger tensor (for the downstream grad) 217 curr_downstream_slice_in_larger_tensor = view(grad_x, axis(0, C), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); 218 219 add_k_(curr_downstream_slice_in_larger_tensor, curr_downstream, curr_downstream_slice_in_larger_tensor); 220 221 222 // grad wrt bias 223 grad_bias->data[f] += 1. * curr_upstream_float; 224 } 225 } 226 } 227 228 kernel->grad = grad_kernels; 229 bias->grad = grad_bias; 230 input->grad = grad_x; 231 } 232 233 234 235 // x (B, C, H, W) 236 // w (F, C, HH, WW) 237 tensor* batched_conv_k(tensor* input, tensor* kernel, tensor* bias){ 238 assert_input(input, 4); 239 assert_input(kernel, 4); 240 assert_input(bias, 2); 241 242 int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3]; 243 int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3]; 244 245 int h_out = 1 + (H - HH) / STRIDE_CONV; 246 int w_out = 1 + (W - WW) / STRIDE_CONV; 247 248 if (IS_DEBUG){ 249 printf("[batched_conv_k] h_out: %i\n", h_out); 250 printf("[batched_conv_k] w_out: %i\n", w_out); 251 } 252 253 tensor* out = EmptyTensor(B, F, h_out, w_out); 254 255 for (int i=0; i<B; i++){ 256 // comment: same semantics as in batched_matmul_k 257 258 tensor* curr_out = TensorNoData(F, h_out, w_out); 259 curr_out->data = out->data + (i * out->stride[0]); 260 261 tensor* curr_x = TensorNoData(C, H, W); 262 curr_x->data = input->data + (i * input->stride[0]); 263 264 // comment: add support for 5d tensors? 
-- NO, w stays 4d 265 conv_k_(curr_x, kernel, bias, curr_out); 266 } 267 return out; 268 } 269 270 // x (B, C, H, W) 271 // w (F, C, HH, WW) 272 // conv output; upstream: (B, F, h_out, w_out) 273 void bwd_batched_conv_k(tensor* upstream, tensor* out) { 274 assert_input(upstream, out->num_dims); 275 276 tensor* input = out->inputs[0]; 277 tensor* kernel = out->inputs[1]; 278 tensor* bias = out->inputs[2]; 279 280 int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3]; 281 int F = kernel->shape[0], HH = kernel->shape[2], WW = kernel->shape[3]; 282 283 int h_out = 1 + (H - HH) / STRIDE_CONV; 284 int w_out = 1 + (W - WW) / STRIDE_CONV; 285 if (IS_DEBUG){ 286 printf("[bwd_batched_conv_k] h_out: %i\n", h_out); 287 printf("[bwd_batched_conv_k] w_out: %i\n", w_out); 288 } 289 290 // make sure it's not initialised with garbage 291 tensor* grad_x = TensorLikeFill(input, 0.0); 292 tensor* grad_bias = TensorLikeFill(bias, 0.0); 293 tensor* grad_kernels = TensorLikeFill(kernel, 0.0); 294 295 for (int i=0; i<B; i++){ 296 // comment: same semantics as in batched_matmul_k 297 298 tensor* curr_x = TensorNoData(C, H, W); 299 curr_x->data = input->data + (i * input->stride[0]); 300 301 tensor* curr_upstream = TensorNoData(F, h_out, w_out); 302 curr_upstream->data = upstream->data + (i * upstream->stride[0]); 303 304 tensor* curr_out = TensorNoData(F, h_out, w_out); 305 curr_out->data = out->data + (i * out->stride[0]); 306 307 // bwd_conv_k unpacks this 308 curr_out->inputs[0] = curr_x; 309 curr_out->inputs[1] = kernel; 310 curr_out->inputs[2] = bias; 311 312 bwd_conv_k(curr_upstream, curr_out); 313 // set by bwd_conv_k 314 tensor* curr_grad_x = curr_x->grad; 315 tensor* curr_grad_filter = kernel->grad; 316 tensor* curr_grad_bias = bias->grad; 317 318 for (int ii=0; ii<curr_grad_x->size; ii++){ 319 int offset_batch = i * grad_x->stride[0]; 320 grad_x->data[offset_batch + ii] = grad_x->data[offset_batch + ii] + curr_grad_x->data[ii]; 321 } 322 323 // 
can iterate over the grad_kernels bc regardless that we're feeding to bwd_conv_k inputs 324 // only for the current b, we're feeding entire kernel (bc there's no batch dim in kernels) 325 for (int ii=0; ii<grad_kernels->size; ii++){ 326 grad_kernels->data[ii] = grad_kernels->data[ii] + curr_grad_filter->data[ii]; 327 } 328 // todo-low: 329 // add_k_(curr_grad_filter, grad_kernels, grad_kernels); 330 331 for (int ii=0; ii<grad_bias->size; ii++){ 332 grad_bias->data[ii] = grad_bias->data[ii] + curr_grad_bias->data[ii]; 333 } 334 } 335 336 input->grad = grad_x; 337 kernel->grad = grad_kernels; 338 bias->grad = grad_bias; 339 } 340 341 342 343 // x (C, H, W) 344 tensor* maxpool_k_(tensor* input, tensor* out) { 345 assert_input(input, 3); 346 assert_input(out, 3); 347 348 // todo: up until and including "int horiz_end" line, was copied from conv. Reduce duplication. 349 350 int C = input->shape[0], H = input->shape[1], W = input->shape[2]; 351 352 // hyperparameters 353 int HH = 2, WW = 2; 354 355 int h_out = 1 + (H - HH) / STRIDE_MAXPOOL; 356 int w_out = 1 + (W - WW) / STRIDE_MAXPOOL; 357 358 if (IS_DEBUG_MP){ 359 printf("[maxpool_k_] h_out: %i\n", h_out); 360 printf("[maxpool_k_] w_out: %i\n", w_out); 361 } 362 363 for (int c=0; c<C; c++){ 364 for (int hight=0; hight<h_out; hight++){ 365 for (int width=0; width<w_out; width++){ 366 367 if (IS_DEBUG_MP) 368 printf("[maxpool_k_] c*C*HH*WW: %i\n", c*C*HH*WW); 369 // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at 370 int vert_start = hight * STRIDE_MAXPOOL; 371 int vert_end = vert_start + HH; 372 int horiz_start = width * STRIDE_MAXPOOL; 373 int horiz_end = horiz_start + WW; 374 375 // select only 1 channel here 376 // todo: pass as array of axis? 
377 // axis slice_axis[3] = {axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end)}; 378 tensor* x_slice = view(input, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); 379 if (IS_DEBUG_MP){ 380 // printf("[maxpool_k_] buffer x_slice %s\n", buffer); 381 printf("[maxpool_k_] x_slice->shape: %i, %i, %i\n", x_slice->shape[0], x_slice->shape[1], x_slice->shape[2]); 382 set_name(x_slice, "x_slice"); print(x_slice); 383 } 384 385 // select maximum element 386 // need to recompute this during backward (so that you put local grad 1 into the lications where there was maximum element) 387 float max = x_slice->data[0]; 388 for (int i=0; i<x_slice->size; i++){ 389 if (x_slice->data[at(x_slice, i)] > max) { 390 // comment: 391 // crucial to use at here and not simple "x_slice->data[i]" bc x_slice is a view thus it's NON contiguous! 392 // alternatively, can create x_slice with "slice" instead of "view" -- the former will make a contiguous copy, in which case fine to index with "x_slice->data[i] 393 max = x_slice->data[at(x_slice, i)]; 394 } 395 } 396 if (IS_DEBUG_MP) 397 printf("[maxpool_k_] max: %f\n", max); 398 out->data[index(out, c, hight, width)] = max; 399 } 400 } 401 } 402 return out; 403 } 404 405 // x (C, H, W) 406 // w (F, C, HH, WW) 407 // todo: was copy/pate from conv_k -- reduce duplication 408 tensor* maxpool_k(tensor* input) { 409 int C = input->shape[0], H = input->shape[1], W = input->shape[2]; 410 411 // hyperparameters 412 int HH = 2, WW = 2; 413 414 int h_out = 1 + (H - HH) / STRIDE_MAXPOOL; 415 int w_out = 1 + (W - WW) / STRIDE_MAXPOOL; 416 417 if (IS_DEBUG_MP){ 418 printf("[maxpool_k] h_out: %i\n", h_out); 419 printf("[maxpool_k] w_out: %i\n", w_out); 420 } 421 422 tensor* out = EmptyTensor(C, h_out, w_out); 423 return maxpool_k_(input, out); 424 } 425 426 // todo-high: 427 // - most of below is copy-pasted from the bwd_conv_k 428 // - it's wasteful to have maxpool logic (65 lines) duplicated exactly in its bwd 
(which needed for re-computing local grad) -- instead just add another filed on the tensor called local_grad 429 void bwd_maxpool_k(tensor* upstream, tensor* out) { 430 assert_input(upstream, out->num_dims); 431 tensor* input = out->inputs[0]; 432 433 int C = input->shape[0], H = input->shape[1], W = input->shape[2]; 434 435 // hyperparams 436 int HH = 2, WW = 2; 437 438 int h_out = 1 + (H - HH) / STRIDE_MAXPOOL; 439 int w_out = 1 + (W - WW) / STRIDE_MAXPOOL; 440 441 tensor* downstream = TensorLikeFill(input, 0.0); 442 443 for (int c=0; c<C; c++){ 444 for (int hight=0; hight<h_out; hight++){ 445 for (int width=0; width<w_out; width++){ 446 447 if (IS_DEBUG_MP) 448 printf("[bwd_maxpool_k] c*C*HH*WW: %i\n", c*C*HH*WW); 449 450 // 1. select the chunk of input that this location in the ouput (f, h, w) is looking at 451 int vert_start = hight * STRIDE_MAXPOOL; 452 int vert_end = vert_start + HH; 453 int horiz_start = width * STRIDE_MAXPOOL; 454 int horiz_end = horiz_start + WW; 455 456 tensor* x_slice = view(input, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); 457 458 // local 459 tensor* local = TensorLikeFill(x_slice, 0.0); 460 int idx_max = 0; 461 float max = x_slice->data[0]; 462 for (int i=0; i<x_slice->size; i++){ 463 if (x_slice->data[at(x_slice, i)] > max) { 464 max = x_slice->data[at(x_slice, i)]; 465 // bc local is contiguous, record contiguous idx (not the x_slice's idx) 466 idx_max = i; 467 } 468 } 469 local->data[idx_max] = 1.0; 470 if (IS_DEBUG_MP) 471 set_name(local, "local"), print(local); 472 473 // upstream 474 float curr_upstream_float = upstream->data[index(upstream, c, hight, width)]; // scalar 475 tensor* curr_upstream = TensorLikeFill(x_slice, curr_upstream_float); // broadcast scalar grad to the shape of the slice 476 477 // downstream 478 tensor* curr_downstream = mul_k(local, curr_upstream); 479 if (IS_DEBUG_MP) 480 set_name(curr_downstream, "curr_downstream"), print(curr_downstream); 481 482 // record downstream grad 
of the current slice, into the larger tensor (corresponding to the downstream grad) 483 tensor* downstream_slice = view(downstream, axis(c, c+1), axis(vert_start, vert_end), axis(horiz_start, horiz_end)); 484 // todo-low: use _copy_arr instead of the below; modify that fn to use at instead of t->data[i] 485 for (int i=0; i<downstream_slice->size; i++){ 486 // note you use at on the downstream_slice bc it's a slice and therefore it's not contiguous, on the 487 // other hand, curr_downstream is contiguous, so simple curr_downstream->data[i] suffices 488 downstream_slice->data[at(downstream_slice, i)] = curr_downstream->data[i]; 489 } 490 491 } 492 } 493 } 494 input->grad = downstream; 495 } 496 497 498 499 // x (B, C, H, W) 500 // w (F, C, HH, WW) 501 // todo: copy/paste from batched_conv_k -- reduce duplication 502 tensor* batched_maxpool_k(tensor* input){ 503 assert_input(input, 4); 504 505 int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3]; 506 507 int HH = 2, WW = 2; 508 509 int h_out = 1 + (H - HH) / STRIDE_MAXPOOL; 510 int w_out = 1 + (W - WW) / STRIDE_MAXPOOL; 511 512 if (IS_DEBUG_MP){ 513 printf("[batched_maxpool_k] h_out: %i\n", h_out); 514 printf("[batched_maxpool_k] w_out: %i\n", w_out); 515 } 516 517 tensor* out = EmptyTensor(B, C, h_out, w_out); 518 519 for (int i=0; i<B; i++){ 520 tensor* curr_out = TensorNoData(C, h_out, w_out); 521 curr_out->data = out->data + (i * out->stride[0]); 522 523 tensor* curr_x = TensorNoData(C, H, W); 524 curr_x->data = input->data + (i * input->stride[0]); 525 526 maxpool_k_(curr_x, curr_out); 527 } 528 return out; 529 } 530 531 // todo: copy from bwd_batched_conv_k -- reduce duplication 532 void bwd_batched_maxpool_k(tensor* upstream, tensor* out) { 533 assert_input(upstream, out->num_dims); 534 535 tensor* input = out->inputs[0]; 536 537 int B = input->shape[0], C = input->shape[1], H = input->shape[2], W = input->shape[3]; 538 int HH = 2, WW = 2; 539 540 int h_out = 1 + (H - HH) / 
STRIDE_MAXPOOL; 541 int w_out = 1 + (W - WW) / STRIDE_MAXPOOL; 542 if (IS_DEBUG_MP){ 543 printf("[bwd_batched_maxpool_k] h_out: %i\n", h_out); 544 printf("[bwd_batched_maxpool_k] w_out: %i\n", w_out); 545 } 546 547 // make sure it's not initialised with garbage 548 tensor* downstream = TensorLikeFill(input, 0.0); 549 550 for (int i=0; i<B; i++){ 551 // comment: same semantics as in batched_matmul_k 552 553 tensor* curr_x = TensorNoData(C, H, W); 554 curr_x->data = input->data + (i * input->stride[0]); 555 556 tensor* curr_upstream = TensorNoData(C, h_out, w_out); 557 curr_upstream->data = upstream->data + (i * upstream->stride[0]); 558 559 tensor* curr_out = TensorNoData(C, h_out, w_out); 560 curr_out->data = out->data + (i * out->stride[0]); 561 562 // bwd_maxpool_k unpacks this 563 curr_out->inputs[0] = curr_x; 564 565 bwd_maxpool_k(curr_upstream, curr_out); 566 tensor* curr_downstream = curr_x->grad; // set by bwd_maxpool_k 567 568 for (int ii=0; ii<curr_downstream->size; ii++){ 569 int offset_batch = i*downstream->stride[0]; 570 downstream->data[offset_batch + ii] = curr_downstream->data[ii]; 571 } 572 // todo-low: 573 // add_k_(curr_downstream, downstream, downstream); 574 } 575 input->grad = downstream; 576 }