// indexing.cpp
// Tensor indexing helpers: flat offsets, owning (copying) slices, and non-owning views.
1 #include <stdarg.h> 2 3 4 // todo: support omitting at the both ends 5 // auto-filling missing dims -- allows "kernel[0]" instead of below 6 // curr_filter = slice_4d(kernel, "f, 0:C, 0:HH, 0:WW"); // (F, C, HH, WW) -> (C, HH, WW) 7 8 9 // todo: 10 // support ":" 11 // support ":n" and "n:" 12 // support omitting at the both ends 13 // auto-filling missing dims -- allows "kernel[0]" instead of below 14 // curr_filter = slice_4d(kernel, "f, 0:C, 0:HH, 0:WW"); // (F, C, HH, WW) -> (C, HH, WW) 15 16 struct ax { 17 int start; 18 int end; 19 }; 20 21 ax axis(int start, int end){ 22 ax out = {start, end}; 23 return out; 24 } 25 26 27 28 29 // todo-low: use C++ to implement these as methods on the struct, so that don't need to explicitly pass first arg 30 // todo-high: implement "index", "slice", "view" funcs -- recursively? Removing need to manually impl for each new n_dim 31 32 int index_2d(tensor* t, int y, int z){ 33 return t->stride[0]*y + t->stride[1]*z; 34 } 35 36 int index_3d(tensor* t, int x, int y, int z){ 37 return t->stride[0]*x + t->stride[1]*y + t->stride[2]*z; 38 } 39 40 int index_4d(tensor* t, int o, int x, int y, int z){ 41 return t->stride[0]*o + t->stride[1]*x + t->stride[2]*y + t->stride[3]*z; 42 } 43 44 // most of the logic here unpacks varying number of args so that idx_Nd can be called 45 int index(tensor* t, ...){ 46 va_list args; 47 va_start(args, t); 48 49 int idx; 50 int s0 = va_arg(args, int); 51 int s1 = va_arg(args, int); 52 53 if (t->num_dims==2){ 54 idx = index_2d(t, s0, s1); 55 } else if (t->num_dims==3){ 56 int s2 = va_arg(args, int); 57 idx = index_3d(t, s0, s1, s2); 58 } else if (t->num_dims==4){ 59 int s2 = va_arg(args, int); 60 int s3 = va_arg(args, int); 61 idx = index_4d(t, s0, s1, s2, s3); 62 } else { 63 printf("[index] unexpected t->num_dims: %i\n", t->num_dims); 64 sprint(t); 65 exit(1); 66 } 67 va_end(args); 68 return idx; 69 } 70 71 72 // todo: name it "contigify" 73 // owning views: 74 75 76 tensor* slice_2d(tensor* t, ax 
axis0, ax axis1){ 77 78 // lowercase to denote sizes of the slice, not of t 79 int y = axis0.end - axis0.start; 80 int z = axis1.end - axis1.start; 81 82 tensor* out = EmptyTensor(y, z); 83 84 // lower-case to denote dims of the slice 85 for (int yi=0; yi<y; yi++){ 86 for (int zi=0; zi<z; zi++){ 87 int out_idx = index_2d(out, yi, zi); 88 int inp_idx = index_2d(t, yi+axis0.start, zi+axis1.start); 89 out->data[out_idx] = t->data[inp_idx]; 90 } 91 } 92 return out; 93 } 94 95 // todo: can make this re-use slice_2d? 96 tensor* slice_3d(tensor* t, ax axis0, ax axis1, ax axis2){ 97 98 // lowercase to denote sizes of the slice, not of t 99 int x = axis0.end - axis0.start; 100 int y = axis1.end - axis1.start; 101 int z = axis2.end - axis2.start; 102 103 tensor* out = EmptyTensor(x, y, z); 104 105 for (int xi=0; xi<x; xi++){ 106 for (int yi=0; yi<y; yi++){ 107 for (int zi=0; zi<z; zi++){ 108 int out_idx = index_3d(out, xi, yi, zi); 109 int inp_idx = index_3d(t, xi+axis0.start, yi+axis1.start, zi+axis2.start); 110 out->data[out_idx] = t->data[inp_idx]; 111 } 112 } 113 } 114 return out; 115 } 116 117 118 #define slice(t, ...) CONCAT(slice_, CONCAT(VA_NARGS(__VA_ARGS__), d))(t, __VA_ARGS__) 119 120 121 122 // non-owning views: 123 124 tensor* view_2d(tensor* t, ax axis0, ax axis1){ 125 126 // lowercase to denote sizes of the slice, not of t 127 int y = axis0.end - axis0.start; 128 int z = axis1.end - axis1.start; 129 130 tensor* out = TensorNoData(y, z); 131 // the default constructor sets strides based on the shapes provided to the constructor. 
132 // This is correct in general, however here heed to change 133 out->stride[0] = t->stride[0]; 134 out->stride[1] = t->stride[1]; // this is more general than setting to 1 135 136 // data should point to the first element (of the view) in the original tensor 137 out->data = &t->data[index_2d(t, axis0.start, axis1.start)]; 138 139 // comment: 140 // no need to loop since in this fn no need to copy or even access the elements 141 // as oppose to (slice_2d, slice_3d) 142 143 return out; 144 } 145 146 tensor* view_3d(tensor* t, ax axis0, ax axis1, ax axis2){ 147 // lowercase to denote sizes of the slice, not of t 148 int x = axis0.end - axis0.start; 149 int y = axis1.end - axis1.start; 150 int z = axis2.end - axis2.start; 151 152 tensor* out = TensorNoData(x, y, z); 153 154 out->stride[0] = t->stride[0]; 155 out->stride[1] = t->stride[1]; 156 out->stride[2] = t->stride[2]; 157 158 out->data = &t->data[index_3d(t, axis0.start, axis1.start, axis2.start)]; 159 return out; 160 } 161 162 #define view(t, ...) 
CONCAT(view_, CONCAT(VA_NARGS(__VA_ARGS__), d))(t, __VA_ARGS__) 163 164 165 166 /* 167 Used in elementwise ops, which were previously implemented (see below), and this is not valid when input is not contiguous 168 > for (int i=0; i<out->size; i++) 169 > out->data[i] = a->data[i] + b->data[i]; 170 */ 171 172 173 int at_2d(tensor* t, int idx){ 174 int z = t->shape[1]; 175 // z instead of stride -- bc want num elements in shapes here 176 int y_idx = idx / z; 177 int z_idx = idx % z; 178 return index_2d(t, y_idx, z_idx); 179 } 180 181 int at_3d(tensor* t, int idx){ 182 // int x = t->shape[0]; 183 int y = t->shape[1]; 184 int z = t->shape[2]; 185 186 // num elements in x: y*z 187 int x_idx = idx / (y*z); 188 // remaining idx 189 idx -= x_idx * (y*z); 190 191 // num elements in y: z 192 int y_idx = idx / z; 193 idx -= (y_idx * z); 194 195 // remaining z 196 int z_idx = idx % z; 197 198 return index_3d(t, x_idx, y_idx, z_idx); 199 } 200 201 int at_4d(tensor* t, int idx){ 202 // int x = t->shape[0]; 203 int y = t->shape[1]; 204 int z = t->shape[2]; 205 int o = t->shape[3]; 206 207 // num elements in x: y*z*o 208 int x_idx = idx / (y*z*o); 209 // remaining idx 210 idx -= x_idx * (y*z*o); 211 212 // num elements in y: z*o 213 int y_idx = idx / (z*o); 214 idx -= y_idx * (z*o); 215 216 // num elements in z: o 217 int z_idx = idx / o; 218 idx -= (z_idx * o); 219 220 // remaining o 221 int o_idx = idx % o; 222 223 return index_4d(t, x_idx, y_idx, z_idx, o_idx); 224 } 225 226 int at(tensor* t, int idx){ 227 if (idx > t->size){ 228 printf("[at] index cannot be greater than t->size\n"); 229 exit(1); 230 } 231 if (t->num_dims==2) return at_2d(t, idx); 232 else if (t->num_dims==3) return at_3d(t, idx); 233 else if (t->num_dims==4) return at_4d(t, idx); 234 else { 235 printf("[at] unexpected t->num_dims (2, 3, or 4)\n"); 236 exit(1); 237 }; 238 }