/ param.cpp
param.cpp
1 #include "nn.h" 2 3 #define PARAM_DEBUG false 4 5 // HEAD 6 param* param_head = NULL; 7 8 void log_params(void){ 9 param* temp = param_head; 10 while (temp){ 11 tensor* w = temp->value; 12 lprint(w); 13 temp = temp->next; 14 } 15 } 16 17 void print_num_params(void){ 18 param* temp = param_head; 19 int num_params = 0; 20 while (temp){ 21 tensor* w = temp->value; 22 num_params += w->size; 23 temp = temp->next; 24 } 25 printf("\nNum trainable params: %i\n", num_params); 26 } 27 28 int count_params(void){ 29 param* temp = param_head; 30 int num_params = 0; 31 while (temp){ 32 num_params += 1; 33 temp = temp->next; 34 } 35 return num_params; 36 } 37 38 void add_param(tensor* t){ 39 param* new_param = (param*)checkMallocErrors(malloc(sizeof(param))); 40 new_param->value = t; 41 new_param->next = NULL; 42 // todo: to reduce memory usage, use "velocity" field on the param 43 // struct to store "first_moments" for adam -- bc when adam is used 44 // currently "velocity" serves no purpose but does take up memory. 45 // Even if SGD's "velocity" is slightly semantically different from Adam's 46 // "first moments", in both cases it's just a tensor with same shape as w 47 // -- can use it for both 48 new_param->velocity = TensorLikeFill(t, 0.0); 49 // used by adam: 50 new_param->t = 0; 51 new_param->beta1 = 0.9; 52 new_param->beta2 = 0.999; 53 new_param->epsilon = 1e-8; 54 new_param->first_moment = TensorLikeFill(t, 0.0); 55 new_param->second_moment = TensorLikeFill(t, 0.0); 56 57 // append to the beginning of the linked list 58 new_param->next = param_head; 59 param_head = new_param; 60 } 61 62 // expects "name" be a NULL terminated string 63 tensor* get_param(const char* name){ 64 if (!param_head){ 65 printf("[get_param] linked list of params is empty\n"); 66 exit(1); 67 } 68 69 param* temp = param_head; 70 while (temp){ 71 if (PARAM_DEBUG){ 72 printf("[get_param] iterating over %s\n", temp->value->name); 73 } 74 if (strcmp(name, temp->value->name) == 0){ 75 return temp->value; 76 } 77 temp = temp->next; 78 } 79 80 printf("[get_param] couldn't find %s in the linked list of params\n", name); 81 exit(1); 82 }