// param.cpp
#include "nn.h"

#include <stdio.h>   // printf
#include <stdlib.h>  // malloc, exit
#include <string.h>  // strcmp

 3  #define PARAM_DEBUG false
 4  
 5  // HEAD
 6  param* param_head = NULL;
 7  
 8  void log_params(void){
 9      param* temp = param_head;
10      while (temp){
11          tensor* w = temp->value;
12          lprint(w);
13          temp = temp->next;
14      }
15  }
16  
17  void print_num_params(void){
18      param* temp = param_head;
19      int num_params = 0;
20      while (temp){
21          tensor* w = temp->value;
22          num_params += w->size;
23          temp = temp->next;
24      }
25      printf("\nNum trainable params: %i\n", num_params);
26  }
27  
28  int count_params(void){
29      param* temp = param_head;
30      int num_params = 0;
31      while (temp){
32          num_params += 1;
33          temp = temp->next;
34      }
35      return num_params;
36  }
37  
38  void add_param(tensor* t){
39      param* new_param = (param*)checkMallocErrors(malloc(sizeof(param)));
40      new_param->value = t;
41      new_param->next = NULL;
42      // todo: to reduce memory usage, use "velocity" field on the param
43      //  struct to store "first_moments" for adam -- bc when adam is used
44      //  currently "velocity" serves no purpose but does take up memory.
45      // Even if SGD's "velocity" is slightly semantically different from Adam's
46      //  "first moments", in both cases it's just a tensor with same shape as w
47      //  -- can use it for both
48      new_param->velocity = TensorLikeFill(t, 0.0);
49      // used by adam:
50      new_param->t = 0;
51      new_param->beta1 = 0.9;
52      new_param->beta2 = 0.999;
53      new_param->epsilon = 1e-8;
54      new_param->first_moment = TensorLikeFill(t, 0.0);
55      new_param->second_moment = TensorLikeFill(t, 0.0);
56  
57      // append to the beginning of the linked list
58      new_param->next = param_head;
59      param_head = new_param;
60  }
61  
62  // expects "name" be a NULL terminated string
63  tensor* get_param(const char* name){
64      if (!param_head){
65          printf("[get_param] linked list of params is empty\n");
66          exit(1);
67      }
68  
69      param* temp = param_head;
70      while (temp){
71          if (PARAM_DEBUG){
72              printf("[get_param] iterating over %s\n", temp->value->name);
73          }
74          if (strcmp(name, temp->value->name) == 0){
75              return temp->value;
76          }
77          temp = temp->next;
78      }
79  
80      printf("[get_param] couldn't find %s in the linked list of params\n", name);
81      exit(1);
82  }