// codegen.cpp
#include "nn.h"


// todo-med: rm cprint[s] and re-use lprint[s] funcs -- modify a global pointer to whatever filename an lprint should write to, set this pointer to "test.py" before calling lprint, then set it back to "log.txt"
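//
// A minimal sketch of the todo above -- LPRINT_STREAM and the lprint hookup are
// assumptions for illustration, not existing code in this repo: have lprint write
// to a global stream instead of a hardcoded filename, then temporarily retarget
// the stream during codegen and restore it afterwards.
//
//     static FILE* LPRINT_STREAM = NULL; // hypothetical global that lprint fprintf's to
//
//     static void with_redirected_lprint(const char* path, void (*emit)(tensor*), tensor* t){
//         FILE* saved = LPRINT_STREAM;
//         LPRINT_STREAM = fopen(path, "a");  // e.g. "./generated/test.py"
//         emit(t);                           // any lprint-based printer
//         fclose(LPRINT_STREAM);
//         LPRINT_STREAM = saved;             // back to "log.txt"
//     }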



extern int index(tensor* t, ...);
extern void (*COPY_TO_DEVICE)(tensor*);
extern tensor* (*COPY_FROM_DEVICE)(tensor*);

// extern bool IS_CODEGEN;
#include <stdio.h> // declares the FILE structure, fopen, fprintf, fclose

// to print the "longest decimal approximation", use FLT_DIG from the <float.h> header.
//  FLT_DIG: the number of decimal digits of precision for the float type. This is
//  typically 6, meaning a float can represent a decimal number with up to 6
//  significant digits without loss of precision.
//
// todo: but using -- "%12.*f", FLT_DIG -- makes the py test asserts fail
// #include <float.h>
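// for example, with FLT_DIG == 6, fprintf(f, "%12.*f", FLT_DIG, 3.14159265f)
// would print "    3.141593" (precision 6, padded to width 12)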

void cprint_1d(tensor* t, FILE *f){
    tensor* t_copy = COPY_FROM_DEVICE(t);

    fprintf(f, "    [%12.*f, ]", 8, t_copy->data[0]);
}

void cprint_2d(tensor* t, FILE *f){
    tensor* t_copy = COPY_FROM_DEVICE(t);

    for (int y=0; y<t->shape[0]; y++){
        fprintf(f, "    [");
        for (int z=0; z<t->shape[1]; z++){
            int idx = index(t_copy, y, z);
            fprintf(f, "%12.*f, ", 8, t_copy->data[idx]);
        }
        fprintf(f, "],\n");
    }
}
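
// For a 2x2 tensor, cprint_2d emits the rows of a Python nested list, e.g.
// (values illustrative):
//
//     [  0.10000000,   0.20000000, ],
//     [  0.30000000,   0.40000000, ],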

void cprint_3d(tensor* t, FILE *f){
    tensor* t_copy = COPY_FROM_DEVICE(t);

    for (int x=0; x<t->shape[0]; x++){
        for (int y=0; y<t->shape[1]; y++){
            fprintf(f, "    [");
            for (int z=0; z<t->shape[2]; z++){
                int idx = index(t_copy, x, y, z);
                fprintf(f, "%12.*f, ", 8, t_copy->data[idx]);
            }
            fprintf(f, "],\n");
        }
        // if (x < t->shape[0]-1)   // avoid empty lines after the last matrix
        fprintf(f, "\n");
    }
}

void cprint_4d(tensor* t, FILE *f){
    tensor* t_copy = COPY_FROM_DEVICE(t);

    for (int o=0; o<t->shape[0]; o++){
        for (int x=0; x<t->shape[1]; x++){
            for (int y=0; y<t->shape[2]; y++){
                fprintf(f, "    [");
                for (int z=0; z<t->shape[3]; z++){
                    int idx = index(t_copy, o, x, y, z);
                    fprintf(f, "%12.*f, ", 8, t_copy->data[idx]);
                }
                fprintf(f, "],\n");
            }
            // if (x < t->shape[1]-1)
            fprintf(f, "\n");
        }
        // if (o < t->shape[0]-1)
        fprintf(f, "\n");
    }
}

void cprint(tensor* t, FILE *f){
    // handles scalar tensors, which are stored with shape (1, 1)
    if (t->num_dims==2 && t->shape[0] == 1 && t->shape[1] == 1) cprint_1d(t, f);
    else if (t->num_dims==2) cprint_2d(t, f);
    else if (t->num_dims==3) cprint_3d(t, f);
    else if (t->num_dims==4) cprint_4d(t, f);
    else {
        printf("[cprint] Error: unsupported num_dims %d\n", t->num_dims);
        exit(1);
    }
}



void codegen_op_call(tensor* t){
    FILE *f = fopen("./generated/test.py", "a");
    switch (t->op_type) {
        case 0: // add
            fprintf(f, "%s = %s + %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 1: // sub
            fprintf(f, "%s = %s - %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 2: // elementwise mul
            fprintf(f, "%s = %s * %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 3: // matmul
            fprintf(f, "%s = %s @ %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 20: // div
            fprintf(f, "%s = %s / %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 18: // repeat along an axis
        {
            int axis = t->non_grad_inputs[0];
            int num_repeats = t->non_grad_inputs[1];
            if (axis == 0){
                fprintf(f, "%s = %s.repeat(%i, 1)\n", t->name, t->inputs[0]->name, num_repeats);
            } else if (axis == 1){
                fprintf(f, "%s = %s.repeat(1, %i)\n", t->name, t->inputs[0]->name, num_repeats);
            }
            break;
        }
        case 14: // gather
            fprintf(f, "%s = torch.gather(%s, dim=1, index=%s.long())\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 4: // pow (integer exponent)
            fprintf(f, "%s = torch.pow(%s, %i)\n", t->name, t->inputs[0]->name, t->non_grad_inputs[0]);
            break;
        case 6: // relu
            fprintf(f, "%s = F.relu(%s)\n", t->name, t->inputs[0]->name);
            break;
        case 7: // transpose
            fprintf(f, "%s = torch.transpose(%s, 0, 1)\n", t->name, t->inputs[0]->name);
            break;
        case 19: // negate
            fprintf(f, "%s = - %s\n", t->name, t->inputs[0]->name);
            break;
        case 16: // exp
            fprintf(f, "%s = torch.exp(%s)\n", t->name, t->inputs[0]->name);
            break;
        case 15: // log
            fprintf(f, "%s = torch.log(%s)\n", t->name, t->inputs[0]->name);
            break;
        case 8: // matmul variant, also lowers to "@"
            fprintf(f, "%s = %s @ %s\n", t->name, t->inputs[0]->name, t->inputs[1]->name);
            break;
        case 13: // flatten
            fprintf(f, "%s = torch.flatten(%s, start_dim=1)\n", t->name, t->inputs[0]->name);
            break;
        case 5: // sum over all elements
            fprintf(f, "%s = torch.sum(%s)\n", t->name, t->inputs[0]->name);
            break;
        case 17: // sum over axis 1
            fprintf(f, "%s = torch.sum(%s, axis=1, keepdim=True)\n", t->name, t->inputs[0]->name);
            break;
        case 21: // max over all elements
            // torch.max without a dim argument returns a single tensor, not a
            // (values, indices) tuple, so there is nothing to index with [0]
            fprintf(f, "%s = torch.max(%s)\n", t->name, t->inputs[0]->name);
            break;
        case 22: // max over axis 1 (keep values, drop indices)
            fprintf(f, "%s = torch.max(%s, dim=1, keepdim=True)[0]\n", t->name, t->inputs[0]->name);
            break;
        // todo: for case 9 and 10, use STRIDE_CONV
        case 9: // conv2d
            // need squeeze because F.conv2d expects a bias of shape (F,), but since tiny-torch doesn't support 1d tensors, its shape here is (F, 1)
            fprintf(f, "%s = F.conv2d(%s, %s, bias=%s.squeeze(-1), stride=1, padding=0)\n", t->name, t->inputs[0]->name, t->inputs[1]->name, t->inputs[2]->name);
            break;
        case 10: // conv2d (second variant, emitted identically)
            fprintf(f, "%s = F.conv2d(%s, %s, bias=%s.squeeze(-1), stride=1, padding=0)\n", t->name, t->inputs[0]->name, t->inputs[1]->name, t->inputs[2]->name);
            break;
        // todo: for case 11 and 12, use STRIDE_MAXPOOL
        case 11: // max_pool2d
            fprintf(f, "%s = F.max_pool2d(%s, kernel_size=2, stride=2, padding=0)\n", t->name, t->inputs[0]->name);
            break;
        case 12: // max_pool2d (second variant, emitted identically)
            fprintf(f, "%s = F.max_pool2d(%s, kernel_size=2, stride=2, padding=0)\n", t->name, t->inputs[0]->name);
            break;

        default:
            printf("[codegen] unexpected op_type %d\n", t->op_type);
            exit(1);
    }
    fclose(f);
}
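
// For example, a tensor named "t3" with op_type 0 (add) and inputs "t1", "t2"
// appends the line:
//
//     t3 = t1 + t2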

void codegen_tensor(tensor* t){
    FILE *f = fopen("./generated/test.py", "a");

    fprintf(f, "_%s = np.array([\n", t->name);
    // IS_CODEGEN = true;
    cprint(t, f);
    // IS_CODEGEN = false;
    fprintf(f, "])\n");
    fprintf(f, "%s = torch.Tensor(_%s.reshape%s)\n", t->name, t->name, str_shape(t));
    fprintf(f, "%s.requires_grad = True\n\n", t->name);
    fclose(f);
}
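
// For a (2, 2) leaf named "w1" this emits roughly the following (assuming
// str_shape returns a string like "(2, 2)"; values illustrative):
//
//     _w1 = np.array([
//         [  0.10000000,   0.20000000, ],
//         [  0.30000000,   0.40000000, ],
//     ])
//     w1 = torch.Tensor(_w1.reshape(2, 2))
//     w1.requires_grad = True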

// todo-low: assert shapes are same
void codegen_assert_close(tensor* t){
    FILE *f = fopen("./generated/test.py", "a");

    // CODEGEN TENSOR
    fprintf(f, "_tiny_torch_%s = np.array([\n", t->name);
    cprint(t, f);
    fprintf(f, "])\n");
    fprintf(f, "_tiny_torch_%s = torch.Tensor(_tiny_torch_%s.reshape%s)\n", t->name, t->name, str_shape(t));
    // it's convenient to see how the difference changes throughout the graph
    // (later in the graph vs at the beginning)
    // print the summed elementwise absolute difference, sum(|a - b|)
    fprintf(f, "print('%s abs diff: ', torch.sum(torch.abs(%s - _tiny_torch_%s)).item())\n", t->name, t->name, t->name);
    // todo-low: use "np.testing.assert_allclose" ?
    fprintf(f, "assert torch.allclose(%s, _tiny_torch_%s, atol=1e-4)\n\n", t->name, t->name);
    fclose(f);
}
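
// For a tensor named "t3" this emits roughly (reshape string illustrative):
//
//     _tiny_torch_t3 = np.array([ ... ])
//     _tiny_torch_t3 = torch.Tensor(_tiny_torch_t3.reshape(2, 2))
//     print('t3 abs diff: ', torch.sum(torch.abs(t3 - _tiny_torch_t3)).item())
//     assert torch.allclose(t3, _tiny_torch_t3, atol=1e-4)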

void codegen_assert_grad_close(tensor* t){
    FILE *f = fopen("./generated/test.py", "a");

    // CODEGEN TENSOR
    fprintf(f, "_tiny_torch_%s_grad = np.array([\n", t->name);
    cprint(t->grad, f);
    fprintf(f, "])\n");
    fprintf(f, "_tiny_torch_%s_grad = torch.Tensor(_tiny_torch_%s_grad.reshape%s)\n", t->name, t->name, str_shape(t));
    fprintf(f, "print('%s grad abs diff: ', torch.sum(torch.abs(%s.grad - _tiny_torch_%s_grad)).item())\n", t->name, t->name, t->name);
    fprintf(f, "assert torch.allclose(%s.grad, _tiny_torch_%s_grad, atol=1e-4)\n\n", t->name, t->name);

    // sanity check -- abs diff of grad with itself (should print 0)
    // fprintf(f, "print('[SANITY CHECK] %s grad abs diff: ', torch.sum(torch.abs(%s.grad - %s.grad)).item())\n\n", t->name, t->name, t->name);
    fclose(f);
}



void codegen_imports(void){
    // clear contents of the file
    fclose(fopen("./generated/test.py", "w"));

    FILE *f = fopen("./generated/test.py", "a");
    fprintf(f, "# --------------------------\n# ATTENTION:\n# THIS FILE IS AUTOGENERATED\n# DO NOT MODIFY BY HAND.\n# --------------------------\n\n\n\n");

    fprintf(f, "import numpy as np\n");
    fprintf(f, "import torch\n");
    fprintf(f, "import torch.nn.functional as F\n");
    fprintf(f, "torch.set_printoptions(precision=8, sci_mode=False, threshold=10_000, edgeitems=100, linewidth=1000)\n");
    // fclose here so that this block appears at the beginning of the file
    fclose(f);
}

// need this fn bc in tiny-torch the call to AG is not an op, so we can't just add
// another case to the switch statement (to represent the backward call)
void codegen_backward_call(tensor* t){
    FILE *f = fopen("./generated/test.py", "a");
    fprintf(f, "%s.backward(torch.ones_like(%s))\n", t->name, t->name);
    fclose(f);
}




// comment:
// split the three passes of recursive_traverse (the 1st pass generates leaf tensors;
// the 2nd pass, op calls; the 3rd pass, asserts for intermediate tensors) -- this way
// the generated code is more readable: all tensors are grouped together and declared
// above the ops, followed by the ops, which are followed by the asserts.
// The resulting test.py layout is sketched below.
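//
// Layout of the generated test.py (section banners written by generate_test below):
//
//     # ~~~~~~~~~~ leafs ~~~~~~~~~~                         <- codegen_all_leafs
//     # ~~~~~~~~~~ ops ~~~~~~~~~~                           <- codegen_all_ops + final op + backward call
//     # ~~~~~~~~~~ intermediate tensors asserts ~~~~~~~~~~  <- codegen_all_asserts
//     # ~~~~~~~~~~ grad asserts ~~~~~~~~~~                  <- codegen_assert_grad_close per param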

void codegen_all_leafs(tensor* t){
    if (t->num_uses != 0){
        return;
    }

    for (int i=0; i<t->num_inputs; i++){
        tensor* inp = t->inputs[i];
        // printf("[codegen_all_leafs] %s\n", inp->name);
        inp->num_uses--;
        codegen_all_leafs(inp);
    }

    if (t->is_leaf){
        // todo-low: don't write "input" and "label" tensors ?
        codegen_tensor(t);
    }
}

void codegen_all_ops(tensor* t){
    if (t->num_uses != 0){
        return;
    }

    for (int i=0; i<t->num_inputs; i++){
        tensor* inp = t->inputs[i];
        // printf("[codegen_all_ops] %s\n", inp->name);
        inp->num_uses--;
        codegen_all_ops(inp);
    }

    // todo: && "t->num_uses == 0"? bc num_uses might have changed while processing the inputs
    if (!t->is_leaf){
        // note codegen_op_call is placed after the recursive calls, so it runs as the
        // recursive call stack unwinds, which emits the ops in correct dependency
        // order: each tensor's inputs are emitted before the tensor itself
        codegen_op_call(t);
    }
}
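
// Example of the unwinding order: for a graph t3 = t1 + t2; loss = torch.sum(t3),
// codegen_op_call fires for t3 first, then for loss, so the generated file reads
//
//     t3 = t1 + t2
//     loss = torch.sum(t3)
//
// The num_uses check at the top guards shared subexpressions: a tensor consumed
// by several ops is only recursed into after its last consumer decrements
// num_uses to 0, so each op line is emitted exactly once.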

void codegen_all_asserts(tensor* t){
    if (t->num_uses != 0){
        return;
    }

    for (int i=0; i<t->num_inputs; i++){
        tensor* inp = t->inputs[i];
        // printf("[codegen_all_asserts] %s\n", inp->name);
        inp->num_uses--;
        codegen_all_asserts(inp);
    }

    // moved codegen_assert_close into the recursive function -- to avoid adding it here after each instruction
    if (!t->is_leaf){
        codegen_assert_close(t);
    }
}




// todo-now: can I run recursive_traverse and AG without destroying t->num_uses?
//   - at the moment I can run either AG or codegen -- to assert grads, I have to be able to run both
// using these helper fns for now
void _save_num_uses(tensor* t){
    for (int i=0; i<t->num_inputs; i++){
        tensor* inp = t->inputs[i];
        inp->_num_uses = inp->num_uses;
        _save_num_uses(inp);
    }
}

void save_num_uses(tensor* t){
    t->num_uses = 0;
    _save_num_uses(t);
}

// restores the num_uses counts saved by save_num_uses
void rest_num_uses(tensor* t){
    for (int i=0; i<t->num_inputs; i++){
        tensor* inp = t->inputs[i];
        inp->num_uses = inp->_num_uses;
        rest_num_uses(inp);
    }
}

// this fn should be called on the final output of the graph (e.g. the loss)
void generate_test(tensor* loss){
    // my codegen relies on t->num_uses, and bc AG destroys these counts, need to reset them here
    rest_num_uses(loss);

    codegen_imports();

    FILE *f = fopen("./generated/test.py", "a");
    fprintf(f, "\n\n\n# ~~~~~~~~~~ leafs ~~~~~~~~~~\n\n\n\n");
    fclose(f);
    codegen_all_leafs(loss);
    rest_num_uses(loss);

    f = fopen("./generated/test.py", "a");
    fprintf(f, "\n\n\n# ~~~~~~~~~~ ops ~~~~~~~~~~\n\n\n\n");
    fclose(f);
    codegen_all_ops(loss);
    rest_num_uses(loss);

    codegen_op_call(loss);
    codegen_backward_call(loss);

    f = fopen("./generated/test.py", "a");
    fprintf(f, "\n\n\n# ~~~~~~~~~~ intermediate tensors asserts ~~~~~~~~~~\n\n\n\n");
    fclose(f);
    codegen_all_asserts(loss);
    rest_num_uses(loss);

    f = fopen("./generated/test.py", "a");
    fprintf(f, "\n\n\n# ~~~~~~~~~~ grad asserts ~~~~~~~~~~\n\n\n\n");
    fclose(f);

    // param_head is a global variable
    // extern param* param_head;
    param* temp = param_head;
    while (temp){
        // printf("[codegen_assert_grad_close] %s\n", temp->value->name);
        codegen_assert_grad_close(temp->value);
        temp = temp->next;
    }

    f = fopen("./generated/test.py", "a");
    fprintf(f, "print('--------------------------\\n TEST PASSED!\\n--------------------------')");
    fclose(f);
}
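
// A typical call site (names hypothetical -- a sketch of the intended flow, given
// that AG consumes num_uses and generate_test restores it via rest_num_uses):
//
//     tensor* loss = forward_pass(batch);
//     save_num_uses(loss);   // snapshot use counts before running AG
//     backward(loss);        // AG decrements num_uses as it walks the graph
//     generate_test(loss);   // emits ./generated/test.py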



/*
parts of AG relevant to recursive_traverse:

void backward(tensor* loss){

        for (int i=0; i<t->num_inputs; i++){
            tensor* inp = t->inputs[i];

            // leaf tensors have no grad_fn, so don't push them on the queue,
            // bc for each value popped from the queue at later iterations,
            // that value's grad_fn will be called
            if (!inp->is_leaf && !is_pushed) {
                ready.push_front(inp);
            }

            // bc we just called grad_fn of one of the outputs (t) of this tensor (inp)
            inp->num_uses--;

        }
    }
}
*/