// backends/cuda/kernels.cu
#include "../../nn.h"



void assert_device(tensor* a){
    if (a->device!=CUDA){
        printf("[assert_device] Error: expected device cuda. t->name: %s\n", a->name);
        exit(1);
    }
}



// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ backward defined in common ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



// binary elementwise


typedef void (*BinaryElementwiseKernel)(float* a, float* b, float* out, int size);

// note: I think it's unnecessary to launch 2d blocks for binary/unary ops -- since the data is contiguous, 1d blocks are enough
tensor* _launch_binary_elementwise(BinaryElementwiseKernel kernel, tensor* a, tensor* b, tensor* out){

    assert_binary_elementwise(a, b);

    // added out arg to _launch_binary_elementwise, so that this fn can be re-used in add_k_
    if (!out){
        out = TensorLikeFill(a, 0.0);
    }

    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(out->size/num_threads), 1, 1);
    dim3 dimBlock(num_threads, 1, 1);

    if (CUDA_DEBUG){
        printf("[cuda binary_elementwise] grid: (%f, 1, 1)\n", ceil(out->size/num_threads));
        printf("[cuda binary_elementwise] block: (%f, 1, 1)\n", num_threads);
    }

    kernel<<<dimGrid, dimBlock>>>(a->data, b->data, out->data, out->size);
    return out;
}
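
// Worked example of the launch shape above (descriptive only; assumes NUM_THREADS is 256):
// for out->size = 1000 the launcher starts ceil(1000/256) = 4 one-dimensional blocks of 256
// threads, and the last 24 threads fall past the end of the tensor -- the `if (idx<size)`
// guard inside each elementwise kernel masks them out.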


// todo: unify this as well
__global__ void AddKernel(float* a, float* b, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = a[idx] + b[idx];
    }
}

tensor* add_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[add_k]\n");
    return _launch_binary_elementwise(AddKernel, a, b, NULL);
}

// need the below bc add_k_ is used in backends/common div_bwd:
// AddKernel is semantically similar to cpu's add_k_, except the
// latter expects tensors (not float*).
// Maybe in the future change my cuda kernels to expect tensors to
// get rid of this; on the other hand, add_k_ and mul_k_ are useful in
// impls of some bwd funcs
tensor* add_k_(tensor* a, tensor* b, tensor* c){
    if (CUDA_DEBUG) printf("[add_k_]\n");
    return _launch_binary_elementwise(AddKernel, a, b, c);
}

// does not verify that input shapes match, only that input sizes match --
// hence the "unsafe_" prefix
tensor* unsafe_add_k_(tensor* a, tensor* b, tensor* c){
    if (CUDA_DEBUG) printf("[unsafe_add_k_]\n");
    IS_INPUT_DIM_CHECK = false;
    tensor* out = _launch_binary_elementwise(AddKernel, a, b, c);
    IS_INPUT_DIM_CHECK = true;
    return out;
}

__global__ void SubKernel(float* a, float* b, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = a[idx] - b[idx];
    }
}

tensor* sub_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[sub_k]\n");
    return _launch_binary_elementwise(SubKernel, a, b, NULL);
}

tensor* sub_k_(tensor* a, tensor* b, tensor* c){
    if (CUDA_DEBUG) printf("[sub_k_]\n");
    return _launch_binary_elementwise(SubKernel, a, b, c);
}


__global__ void MulKernel(float* a, float* b, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = a[idx] * b[idx];
    }
}

tensor* mul_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[mul_k]\n");
    return _launch_binary_elementwise(MulKernel, a, b, NULL);
}

// used in exp_bwd, log_bwd
tensor* mul_k_(tensor* a, tensor* b, tensor* c){
    if (CUDA_DEBUG) printf("[mul_k_]\n");
    return _launch_binary_elementwise(MulKernel, a, b, c);
}


__global__ void DivKernel(float* a, float* b, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = a[idx] / b[idx];
    }
}

tensor* div_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[div_k]\n");
    return _launch_binary_elementwise(DivKernel, a, b, NULL);
}


// binary


// a(?B, N, M) @ b(?B, M, D) = out(?B, N, D)
__global__ void MatMulKernel(float* a, float* b, float* out, int N, int M, int D, bool is_batched){
    // (block idx * num threads per block) + thread idx
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int d = blockIdx.y * blockDim.y + threadIdx.y;

    // cancels out when not batched
    int batch = blockIdx.z * is_batched;

    if ((n<N) && (d<D)){
        float curr_out = 0.0;
        for (int m=0; m<M; m++){
            // todo-low: calling a __host__ function ("index_2d") from a __global__ function ("MatMulKernel") is not allowed.
            //   But it doesn't really make sense to use index_2d here anyway, since this is contiguous cuda memory
            //   (I believe it's contiguous, as it is the output of the cuda malloc called from inside the Tensor constructor).
            // out += a->data[index_2d(a, n, k)] * b->data[index_2d(b, k, d)];
            curr_out += a[batch*N*M + n*M + m] * b[batch*M*D + m*D + d];
        }
        out[batch*N*D + n*D + d] = curr_out;
    }
}
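
// Index arithmetic example (descriptive only): in row-major layout a single (N=2, M=3)
// matrix stores element (n=1, m=2) at offset n*M + m = 1*3 + 2 = 5; batch b just adds a
// further b*N*M offset, which is exactly what the kernel above computes.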

// todo: the _k naming is to maintain parity between the names of the cpu kernels and the cuda kernels (so that both can be used polymorphically in ops),
//     but the below isn't a "kernel" in the CUDA sense of the word -- it's a stub that launches the actual kernel (MatMulKernel)

// a(N, M) @ b(M, D) = out(N, D)
tensor* matmul_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[matmul_k]\n");
    assert_input(a, 2);
    assert_input(b, 2);
    if (a->shape[1] != b->shape[0]){
        printf("[cuda MatMul] Error: inner dim doesn't match, saw: a(%i, %i) b(%i, %i)\n", a->shape[0], a->shape[1], b->shape[0], b->shape[1]);
        exit(1);
    }

    int N = a->shape[0], M = a->shape[1], D = b->shape[1];
    // todo: fill with 0, here and in other stubs
    tensor* out = Tensor(N, D);

    // todo: unify the 7 lines below into a fn (e.g. compute_launch_shapes) and re-use it across all stubs (a sketch follows this function)
    // important to have it float to avoid int division
    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(N/num_threads), ceil(D/num_threads), 1);
    dim3 dimBlock(num_threads, num_threads, 1);

    if (CUDA_DEBUG){
        printf("[cuda MatMul] grid: (%f, %f, 1)\n", ceil(N/num_threads), ceil(D/num_threads));
        printf("[cuda MatMul] block: (%f, %f, 1)\n", num_threads, num_threads);
    }

    // todo: to avoid passing shapes, copy tensor structs to cuda and pass them to the kernel?
    MatMulKernel<<<dimGrid, dimBlock>>>(a->data, b->data, out->data, N, M, D, false);
    return out;
}
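
// A minimal sketch of the compute_launch_shapes helper suggested in the todo above. The name
// and signature are hypothetical and it is not yet wired into any of the stubs: it covers an
// N x D output with NUM_THREADS x NUM_THREADS blocks and uses the grid's z dimension for the batch.
static void compute_launch_shapes(int N, int D, int B, dim3* dimGrid, dim3* dimBlock){
    // kept as float to avoid int division when computing the grid extent
    float num_threads = (float)NUM_THREADS;
    *dimGrid = dim3(ceil(N/num_threads), ceil(D/num_threads), B);
    *dimBlock = dim3(num_threads, num_threads, 1);
}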

// a(B, N, M) @ b(B, M, D) = out(B, N, D)
tensor* batched_matmul_k(tensor* a, tensor* b){
    if (CUDA_DEBUG) printf("[batched_matmul_k]\n");
    assert_input(a, 3);
    assert_input(b, 3);
    if (a->shape[2] != b->shape[1]){
        printf("[cuda BatchedMatMul] Error: inner dim doesn't match\n");
        exit(1);
    }

    int B = a->shape[0], N = a->shape[1], M = a->shape[2], D = b->shape[2];
    tensor* out = Tensor(B, N, D);

    // important to have it float to avoid int division
    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(N/num_threads), ceil(D/num_threads), B);
    dim3 dimBlock(num_threads, num_threads, 1);

    if (CUDA_DEBUG){
        printf("[cuda BatchedMatMul] grid: (%f, %f, %i)\n", ceil(N/num_threads), ceil(D/num_threads), B);
        printf("[cuda BatchedMatMul] block: (%f, %f, 1)\n", num_threads, num_threads);
    }

    MatMulKernel<<<dimGrid, dimBlock>>>(a->data, b->data, out->data, N, M, D, true);
    return out;
}


// unary


typedef void (*UnaryKernel)(float* a, float* out, int size);

tensor* _launch_unary_elementwise(UnaryKernel kernel, tensor* a){

    // don't assert n_dims == 2, bc in conv_net.cu 4d input is fed to the relu kernel, which calls _launch_unary_elementwise
    assert_device(a);
    assert_contiguous(a);

    tensor* out = TensorLikeFill(a, 0.0);

    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(out->size/num_threads), 1, 1);
    dim3 dimBlock(num_threads, 1, 1);

    if (CUDA_DEBUG){
        printf("[cuda unary] grid: (%f, 1, 1)\n", ceil(out->size/num_threads));
        printf("[cuda unary] block: (%f, 1, 1)\n", num_threads);
    }

    kernel<<<dimGrid, dimBlock>>>(a->data, out->data, out->size);
    return out;
}
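
// Note on the pattern below (descriptive only): each unary op is a __global__ kernel with the
// (float* a, float* out, int size) signature plus a thin *_k stub that forwards it to
// _launch_unary_elementwise, so adding a new unary op only requires those two pieces.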



__device__ int pow_exponent;

__global__ void PowKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = powf(a[idx], (float)pow_exponent);
    }
}

tensor* pow_k(tensor* a, int exponent){
    if (CUDA_DEBUG) printf("[pow_k]\n");
    // cpu's pow_k takes the exponent as an arg, but here, because of the standardized
    // _launch_unary_elementwise interface, it is passed via a __device__ global instead
    cudaMemcpyToSymbol(pow_exponent, &exponent, sizeof(int));
    return _launch_unary_elementwise(PowKernel, a);
}
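
// Example (descriptive only, with a hypothetical call): pow_k(a, 2) squares every element --
// the host int 2 is copied into the __device__ symbol pow_exponent by cudaMemcpyToSymbol, and
// PowKernel reads it on the device during the subsequent launch.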



__global__ void SqrtKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = sqrtf(a[idx]);
    }
}

tensor* sqrt_k(tensor* a){
    if (CUDA_DEBUG) printf("[sqrt_k]\n");
    return _launch_unary_elementwise(SqrtKernel, a);
}


__global__ void ExpKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = expf(a[idx]);
    }
}

tensor* exp_k(tensor* a){
    if (CUDA_DEBUG) printf("[exp_k]\n");
    return _launch_unary_elementwise(ExpKernel, a);
}


__global__ void LogKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = logf(a[idx]);
    }
}

tensor* log_k(tensor* a){
    if (CUDA_DEBUG) printf("[log_k]\n");
    return _launch_unary_elementwise(LogKernel, a);
}


__global__ void NegKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = -a[idx];
    }
}

tensor* neg_k(tensor* a){
    if (CUDA_DEBUG) printf("[neg_k]\n");
    return _launch_unary_elementwise(NegKernel, a);
}


__global__ void ReluKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = (a[idx] < 0.0) ? 0.0 : a[idx];
    }
}

tensor* relu_k(tensor* a){
    if (CUDA_DEBUG) printf("[relu_k]\n");
    return _launch_unary_elementwise(ReluKernel, a);
}


__global__ void ReluBwdKernel(float* a, float* out, int size){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx<size){
        out[idx] = (a[idx] > 0.0) ? 1.0 : 0.0;
    }
}

void relu_bwd(tensor* upstream, tensor* out) {
    if (CUDA_DEBUG) printf("[relu_bwd]\n");
    tensor* a = out->inputs[0];
    tensor* local = _launch_unary_elementwise(ReluBwdKernel, a);
    a->grad = mul_k(local, upstream);
}
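
// Descriptive note: ReluBwdKernel computes the local gradient d relu(x)/dx, which is 1 where
// the input was positive and 0 elsewhere; relu_bwd then applies the chain rule by multiplying
// it elementwise with the upstream gradient (a->grad = local * upstream).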



// todo: for transpose, launch 1d blocks so that it can re-use _launch_unary_elementwise?
// todo-high: this kernel (TransposeKernel, used by both transpose_k and batched_transpose_k) basically does: swap strides + "contigify" -- can I get rid of it once I
//   - support non-contiguous data in my cuda kernels (which means using "at" instead of t[idx] to index into tensors)
//   - which implies changing kernels to take tensors, not floats, to access strides
// a(?B, N, M) -> out(?B, M, N)
__global__ void TransposeKernel(float* a, float* out, int M, int N, bool is_batched){
    int m = blockIdx.x * blockDim.x + threadIdx.x;
    int n = blockIdx.y * blockDim.y + threadIdx.y;
    int batch = blockIdx.z * is_batched;

    if (m<M && n<N){
        out[batch*M*N + m*N + n] = a[batch*N*M + n*M + m];
    }
}
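
// Index mapping example (descriptive only): for a(N=2, M=3), element a[1][2] sits at offset
// n*M + m = 1*3 + 2 = 5 and is written to out[2][1] at offset m*N + n = 2*2 + 1 = 5 in the
// (M, N) output; the batch index just adds the same constant N*M offset on both sides.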

tensor* transpose_k(tensor* a){
    if (CUDA_DEBUG) printf("[transpose_k]\n");
    assert_input(a, 2);

    int N = a->shape[0], M = a->shape[1];
    // todo: allocate empty
    tensor* out = Tensor(M, N);

    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(M/num_threads), ceil(N/num_threads), 1);
    dim3 dimBlock(num_threads, num_threads, 1);

    if (CUDA_DEBUG){
        printf("[cuda Transpose] grid: (%f, %f, 1)\n", ceil(M/num_threads), ceil(N/num_threads));
        printf("[cuda Transpose] block: (%f, %f, 1)\n", num_threads, num_threads);
    }

    TransposeKernel<<<dimGrid, dimBlock>>>(a->data, out->data, M, N, false);
    return out;
}




// a(B, 1) -> out(B, N)
__global__ void RepeatDim1Kernel(float* a, float* out, int num_repeats, int B){
    // (block idx * num threads per block) + thread idx
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    // the repeat idx (0..num_repeats-1) that this thread handles
    int i = blockIdx.y; // * blockDim.y + threadIdx.y;
    if (b<B && i<num_repeats){
        out[b*num_repeats + i] = a[b];
    }
}
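
// Example (descriptive only): repeating a(B=3, 1) along axis 1 with num_repeats=4 yields
// out(3, 4), where every element of row b equals a[b]; RepeatDim0Kernel below is the mirror
// case, copying a single (1, N) row into each of the num_repeats output rows.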

// question-now: or use this kernel?
// // a(B, 1) -> out(B, N)
// __global__ void RepeatKernel(float* a, float* out, int num_repeats, int B){
//     // (block idx * num threads per block) + thread idx
//     int b = blockIdx.x * blockDim.x + threadIdx.x;
//     printf("[kernel] b=%i\n", b);
//     if (b<B){
//         for (int i=0; i<num_repeats; i++){
//             // Indexing into out: since out(B, num_repeats), to get to the next batch element
//             // (IOW out->stride[0]) need to skip "num_repeats" locations in memory;
//             // Indexing into a: since a(B, 1), stride a->stride[0] is just 1 so can omit it
//             out[b*num_repeats + i] = a[b];
//         }
//     }
// }


// a(1, N) -> out(B, N)
__global__ void RepeatDim0Kernel(float* a, float* out, int num_repeats, int N){
    // (block idx * num threads per block) + thread idx
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    // here b is the repeat idx (output row), and i is the column idx (0..N-1) within that row
    int i = blockIdx.y; // * blockDim.y + threadIdx.y;
    if (b<num_repeats && i<N){
        out[b*N + i] = a[i];
    }
}



tensor* repeat_k(tensor* a, int axis, int num_repeats){
    if (CUDA_DEBUG) printf("[repeat_k]\n");
    assert_input(a, 2);
    if (axis != 0 && axis != 1){
        printf("[CUDA RepeatKernel] Unexpected axis\n");
        exit(1);
    }
    if (a->shape[axis] != 1){
        printf("[CUDA RepeatKernel] Shape error\n");
        exit(1);
    }

    tensor* out;

    // a.shape (1, N)
    if (axis==0){

        int N = a->shape[1];
        out = Tensor(num_repeats, N);

        float num_threads = (float)NUM_THREADS;
        dim3 dimGrid(ceil(num_repeats/num_threads), N, 1);
        dim3 dimBlock(num_threads, 1, 1);

        if (CUDA_DEBUG){
            printf("[cuda RepeatKernel] grid: (%f, %i, 1)\n", ceil(num_repeats/num_threads), N);
            printf("[cuda RepeatKernel] block: (%f, 1, 1)\n", num_threads);
        }

        RepeatDim0Kernel<<<dimGrid, dimBlock>>>(a->data, out->data, num_repeats, N);


    // a.shape (B, 1)
    } else if (axis==1){

        int B = a->shape[0];
        out = Tensor(B, num_repeats);

        float num_threads = (float)NUM_THREADS;
        dim3 dimGrid(ceil(B/num_threads), num_repeats, 1);
        dim3 dimBlock(num_threads, 1, 1);

        if (CUDA_DEBUG){
            printf("[cuda RepeatKernel] grid: (%f, %i, 1)\n", ceil(B/num_threads), num_repeats);
            printf("[cuda RepeatKernel] block: (%f, 1, 1)\n", num_threads);
        }

        RepeatDim1Kernel<<<dimGrid, dimBlock>>>(a->data, out->data, num_repeats, B);

    }
    return out;
}


// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ backward NOT defined in common ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


__global__ void SelectKernel(float* input, float* idx, int B, int N, float* out);
__global__ void SelectSetKernel(float* input, float* idx, int B, int N, float value);

// value == -1 acts as a sentinel (it's what select_k passes), so that this fn can be re-used
// for both the read path (SelectKernel) and the write path (SelectSetKernel)
tensor* _launch_select(tensor* a, tensor* idx, float value){
    if (CUDA_DEBUG) printf("[_launch_select]\n");
    assert_input(a, 2);
    assert_input(idx, 2);
    if (idx->shape[1]!=1 || idx->shape[0]!=a->shape[0]) {
        printf("[_launch_select] Error: shape mismatch\n");
        exit(1);
    }

    int B = a->shape[0], N = a->shape[1];

    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(B/num_threads), 1, 1);
    dim3 dimBlock(num_threads, 1, 1);

    if (CUDA_DEBUG){
        printf("[cuda _launch_select] grid: (%f, 1, 1)\n", ceil(B/num_threads));
        printf("[cuda _launch_select] block: (%f, 1, 1)\n", num_threads);
    }

    if (value==-1){
        tensor* out = Tensor(B, 1);
        SelectKernel<<<dimGrid, dimBlock>>>(a->data, idx->data, B, N, out->data);
        return out;
    } else {
        SelectSetKernel<<<dimGrid, dimBlock>>>(a->data, idx->data, B, N, value);
        return a;
    }
}

// input(s1, s2), idx(s1, 1) -> out(s1, 1)
__global__ void SelectKernel(float* input, float* idx, int B, int N, float* out){
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b<B){
        int input_idx = idx[b];
        // bc out and idx are (B, 1), can simply index into each of them with arr[b]
        out[b] = input[b*N + input_idx];
    }
}
tensor* select_k(tensor* a, tensor* idx){
    return _launch_select(a, idx, -1);
}
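
// Example (descriptive only): for input(2, 3) = [[1, 2, 3], [4, 5, 6]] and idx = [[2], [0]],
// select_k returns [[3], [4]] -- row b keeps only the element at column idx[b] (e.g. picking
// the target-class entry per batch row); select_set_ instead writes `value` into those positions.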

__global__ void SelectSetKernel(float* input, float* idx, int B, int N, float value){
    int b = blockIdx.x * blockDim.x + threadIdx.x;
    if (b<B){
        int input_idx = idx[b];
        input[b*N + input_idx] = value;
    }
}
tensor* select_set_(tensor* a, tensor* idx, float value){
    return _launch_select(a, idx, value);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ kernels that do not have op wrappers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

tensor* batched_transpose_k(tensor* a){
    if (CUDA_DEBUG) printf("[batched_transpose_k]\n");
    assert_input(a, 3);

    int B = a->shape[0], N = a->shape[1], M = a->shape[2];
    // todo: allocate empty
    tensor* out = Tensor(B, M, N);

    float num_threads = (float)NUM_THREADS;
    dim3 dimGrid(ceil(M/num_threads), ceil(N/num_threads), B);
    dim3 dimBlock(num_threads, num_threads, 1);

    if (CUDA_DEBUG){
        printf("[cuda BatchedTranspose] grid: (%f, %f, %i)\n", ceil(M/num_threads), ceil(N/num_threads), B);
        printf("[cuda BatchedTranspose] block: (%f, %f, 1)\n", num_threads, num_threads);
    }

    TransposeKernel<<<dimGrid, dimBlock>>>(a->data, out->data, M, N, true);
    return out;
}