// backends/common/asserts.cpp

#include <cstdio>   // printf
#include <cstdlib>  // exit
// tensor, sprint, and UTILS_DEBUG are assumed to be provided by a shared
// project header (not shown in this file)

bool IS_INPUT_DIM_CHECK = true;

extern void assert_device(tensor* a);


void assert_contiguous(tensor* a){
    // https://github.com/pytorch/pytorch/blob/dc7461d6f571abb8a6649d0c026793e77d0fd411/torch/_prims_common/__init__.py#L249-L271

    // iterate over shapes and strides in reverse order

    // a tensor is not contiguous if, to step to the next element of some
    // dim, you need to skip more/fewer elements than are contained in all
    // the dims after it
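    // e.g. a contiguous tensor of shape (2, 3, 4) has strides (12, 4, 1):
    // stepping along dim 0 skips exactly 3*4=12 elements, along dim 1
    // exactly 4, along dim 2 exactly 1 -- any other strides mean gaps or overlap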

    // "-1" bc if e.g. a->num_dims is 2, the only valid indices into shape
    // are a->shape[0] and a->shape[1] -- IOW the last one is a->num_dims - 1
    for (int expected_stride = 1, i = a->num_dims - 1; i >= 0; i--){
        int len = a->shape[i];
        int stride = a->stride[i];
        // skip the stride check for length-1 dims: they are never stepped
        // along, so their stride is irrelevant
        if (len == 1){
            continue;
        }
        if (UTILS_DEBUG) printf("[assert_contiguous] (i=%i) stride=%i, expected_stride=%i\n", i, stride, expected_stride);
        if (stride != expected_stride){
            printf("[assert_contiguous] Error: expected contiguous data. Saw:\n");
            sprint(a);
            exit(1);
        }
        expected_stride = expected_stride * len;
    }

}

void assert_dim(tensor* a, int expected_dim){
    if (a->num_dims != expected_dim){
        printf("[assert_dim] Error: expected %i-dim inputs, saw %i-dim\n", expected_dim, a->num_dims);
        exit(1);
    }
}

// todo: it would be convenient if this fn also accepted a string to print when the error is raised,
//  e.g. "[cuda conv_k] expected 3-d input and 4-d kernel\n"
// current workaround: launch cuda-gdb; b exit; run; bt
void assert_input(tensor* a, int expected_dim){
    assert_contiguous(a);
    assert_device(a);
    // pass -1 when there is no requirement on a particular input's ndims --
    // this lets assert_input be reused for those cases instead of calling
    // "assert_contiguous(a); assert_device(a);" by hand
    if (expected_dim != -1){
        assert_dim(a, expected_dim);
    }
}
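
// A minimal sketch of the todo above. Since the checks call exit(1)
// directly, a wrapper cannot print context *after* a failure; one workable
// pattern is a file-scope context string that the error branches print
// before exiting. `g_err_context` and `assert_input_msg` are hypothetical
// names, not part of the existing API.
static const char* g_err_context = "";

void assert_input_msg(tensor* a, int expected_dim, const char* err_context){
    g_err_context = err_context;  // e.g. "[cuda conv_k] expected 3-d input and 4-d kernel\n"
    assert_input(a, expected_dim);
    g_err_context = "";           // reset; only reached if all checks passed
}
// each error branch above would additionally print the context before exit(1):
//     printf("%s", g_err_context);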


void assert_same_size(tensor* a, tensor* b){
    if (a->size != b->size){
        printf("[assert_same_size] Error: expected input sizes to match. Saw:\n");
        sprint(a);
        sprint(b);
        exit(1);
    }
}

void assert_same_shape(tensor* a, tensor* b){
    if (a->num_dims != b->num_dims){
        printf("[assert_same_shape] Error: expected inputs of same dimensionality. Saw:\n");
        sprint(a);
        sprint(b);
        exit(1);
    }

    for (int i = 0; i < a->num_dims; i++){
        if (a->shape[i] != b->shape[i]){
            printf("[assert_same_shape] Error: expected input shapes to match. Saw:\n");
            sprint(a);
            sprint(b);
            exit(1);
        }
    }
}
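
// note: same size does not imply same shape -- e.g. shapes (2, 6) and (3, 4)
// both have 12 elements, so they pass assert_same_size but fail assert_same_shape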

void assert_binary_elementwise(tensor* a, tensor* b){
    // don't assert n_dim == 2: 3d and 4d inputs are expected to be fed to these kernels,
    // e.g. mul_k_ is used in many bwd functions; add_k_ is used in batched_flatten_bwd
    assert_contiguous(a);
    assert_device(a);

    assert_contiguous(b);
    assert_device(b);

    // this is basically an escape hatch for _unsafe_add_k:
    // batched_flatten_k calls add_k_ with a(B, 24) and b(B, 6, 2, 2);
    // because the kernels iterate over size, comparing sizes is the more
    // suitable check here than comparing shapes
    if (!IS_INPUT_DIM_CHECK){
        printf("[assert_binary_elementwise] NOT checking for shape!\n");
        assert_same_size(a, b);
        return;
    }

    assert_same_shape(a, b);
}
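
// hypothetical caller-side use of the IS_INPUT_DIM_CHECK escape hatch
// (illustrative only -- the real _unsafe_add_k call site may differ):
//     IS_INPUT_DIM_CHECK = false;
//     add_k_(a, b);   // a(B, 24) vs b(B, 6, 2, 2): size check only
//     IS_INPUT_DIM_CHECK = true;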

void assert_binary_elementwise_non_contiguous(tensor* a, tensor* b){
    assert_device(a);
    assert_device(b);

    // this is basically an escape hatch for _unsafe_add_k
    if (!IS_INPUT_DIM_CHECK){
        printf("[assert_binary_elementwise_non_contiguous] NOT checking for shape!\n");
        assert_same_size(a, b);
        return;
    }

    assert_same_shape(a, b);
}