// models/mlp.cu
 1  #include <iostream> // todo: use C only
 2  
 3  #define DEVICE CPU
 4  
 5  
 6  #include "../nn.h"
 7  #include "../tensor.cpp"
 8  #include "../ops.cpp"
 9  #include "../composite_ops.cpp"
10  #include "../cifar10.cpp"
11  #include "../print.cpp"
12  #include "../optim.cpp"
13  #include "../codegen.cpp"
14  
15  #define NUM_EP 1
16  #define LR 0.02
17  
18  
19  float train_step(tensor* x, tensor* w1, tensor* w2)
20  {
21      // *** Net ***
22  
23      // x(N, M) @ w1(M, D) = out1(N, D)
24      tensor* out1 = matmul(x, w1);
25      set_name(out1, "matmul_1");
26  
27      // out2(N, D)
28      tensor* out2 = relu(out1);
29      set_name(out2, "relu");
30  
31      // out2(N, D) @ w2(D, O) = out3(N, O)
32      tensor* out3 = matmul(out2, w2);
33      set_name(out3, "matmul_2");
34  
35      // *** Loss fn ***
36      tensor* y = TensorLikeFill(out3, 0.5); // dummy label
37      tensor* loss = reduce_sum(pow(sub(y, out3), 2));
38  
39      // *** Zero-out grads ***
40      zero_grads();
41  
42      // *** Backward ***
43      save_num_uses(loss);
44      loss->backward(loss);
45  
46      generate_test(loss);
47  
48      // *** Optim Step ***
49      sgd(LR);
50  
51      graphviz(loss);
52  
53      return COPY_FROM_DEVICE(loss)->data[0];
54  }
55  
56  
57  int main(void) {
58      // random num generator init, must be called once
59      // srand(time(NULL));
60      srand(123);
61      set_backend_device();
62  
63      int N = 16;
64      int M = 2;
65      int D = 4;
66      int O = 1;
67  
68      // *** Init ***
69      tensor* x = Tensor(N, M);
70      set_name(x, "x"); print(x);
71  
72      tensor* w1 = Tensor(M, D);
73      set_name(w1, "w1"); print(w1);
74      add_param(w1);
75  
76      tensor* w2 = Tensor(D, O);
77      set_name(w2, "w2"); print(w2);
78      add_param(w2);
79  
80      // *** Train ***
81      for (int ep_idx=0; ep_idx<NUM_EP; ep_idx++) {
82          float loss = train_step(x, w1, w2);
83          cout << "\nep: " << ep_idx << "; loss: " << loss << endl;
84      }
85  
86      return 0;
87  }