// asserts.cpp -- input checks for tensor ops; each assert_* prints a diagnostic and calls exit(1) on failure.
1 2 bool IS_INPUT_DIM_CHECK = true; 3 4 extern void assert_device(tensor* a); 5 6 7 8 void assert_contiguous(tensor* a){ 9 // https://github.com/pytorch/pytorch/blob/dc7461d6f571abb8a6649d0c026793e77d0fd411/torch/_prims_common/__init__.py#L249-L271 10 11 // iterate in the reverse order of shapes and strides 12 13 // a tensor is not contiguous if (to access elements in next dim) 14 // you need to skip more/less elements, than num elements in all 15 // the previous dims of the tensor 16 17 // "-1" bc if e.g. t->num_dims is 2, then only valid locations 18 // for shape are t->shape[0] and t->shape[1] -- IOW t->num_dims - 1 19 for (int expected_stride = 1, i=a->num_dims-1; i>=0; i--){ 20 int x = a->shape[i]; 21 int y = a->stride[i]; 22 // Skips checking strides when a dimension has length 1 23 if (x == 1){ 24 continue; 25 } 26 if (UTILS_DEBUG) printf("[assert_contiguous] (i=%i) y=%i , expected_stride=%i\n", i, y, expected_stride); 27 if (y != expected_stride){ 28 printf("[assert_contiguous] Error: expected contiguous data. Saw:\n"); 29 sprint(a); 30 exit(1); 31 } 32 expected_stride = expected_stride * x; 33 } 34 35 } 36 37 void assert_dim(tensor* a, int expected_dim){ 38 if (a->num_dims!=expected_dim){ 39 printf("[assert_dim] Error: expected %i-dim inputs, saw %i-dim\n", expected_dim, a->num_dims); 40 exit(1); 41 } 42 } 43 44 // todo: would be convenient if this fn also expected a string to be printed (in case of err raised) as an argument 45 // e.g. 
"[cuda conv_k] expected 3-d input and 4-d kernel\n" 46 // current workaround: launch cuda-gdb; b exit; run; bt 47 void assert_input(tensor* a, int expected_dim){ 48 assert_contiguous(a); 49 assert_device(a); 50 // use -1 when there's not requirement on a particular input ndims -- avoids to reuse assert_input for these cases instead of using "assert_contiguous(a); assert_device(a);" 51 if (expected_dim != -1){ 52 assert_dim(a, expected_dim); 53 } 54 } 55 56 57 58 void assert_same_size(tensor* a, tensor* b){ 59 if (a->size != b->size){ 60 printf("[assert_same_size] Error: expected inputs sizes to match. Saw:\n"); 61 sprint(a); 62 sprint(b); 63 exit(1); 64 } 65 } 66 67 void assert_same_shape(tensor* a, tensor* b){ 68 if (a->num_dims != b->num_dims){ 69 printf("[assert_same_shape] Error: expected inputs of same dimensionality. Saw:\n"); 70 sprint(a); 71 sprint(b); 72 exit(1); 73 } 74 75 for (int i=0; i<a->num_dims; i++){ 76 if (a->shape[i] == b->shape[i]){ 77 continue; 78 } 79 printf("[assert_same_shape] Error: expected input shapes to match. Saw:\n"); 80 sprint(a); 81 sprint(b); 82 exit(1); 83 } 84 } 85 86 void assert_binary_elementwise(tensor* a, tensor* b){ 87 // don't assert n_dim == 2, it's expected that 3d and 4d input will be fed to them, 88 // e.g. 
mul_k_ is used in many bwd functions; add_k_ is used in batched_flatten_bwd 89 assert_contiguous(a); 90 assert_device(a); 91 92 assert_contiguous(b); 93 assert_device(b); 94 95 // this is basically a side-hatch for _unsafe_add_k; 96 // batched_flatten_k calls add_k_ with a(B, 24) b(B, 6, 2, 2) 97 // because kernels iterate over size, seems the below is more suitable check when comparing shapes 98 if (!IS_INPUT_DIM_CHECK){ 99 printf("[assert_binary_elementwise] NOT checking for shape!\n"); 100 assert_same_size(a, b); 101 return; 102 } 103 104 assert_same_shape(a, b); 105 } 106 107 108 void assert_binary_elementwise_non_contiguous(tensor* a, tensor* b){ 109 assert_device(a); 110 assert_device(b); 111 112 // this is basically a side-hatch for _unsafe_add_k 113 if (!IS_INPUT_DIM_CHECK){ 114 printf("[assert_binary_elementwise_non_contiguous] NOT checking for shape!\n"); 115 assert_same_size(a, b); 116 return; 117 } 118 119 assert_same_shape(a, b); 120 }