// composite_ops.cpp
 1  #include "nn.h"
 2  
 3  tensor* log_softmax(tensor* logits){
 4      // min-max trick for numerical stability, python: "logits -= np.max(logits, axis=1, keepdims=True)"
 5      int n_repeats = logits->shape[1];
 6      tensor* maxes = repeat(batched_reduce_max(logits), /*axis = */ 1, /*num_repeats = */ n_repeats);
 7      set_name(maxes, "maxes"); // sprint(maxes);
 8      tensor* su = sub(logits, maxes);
 9      set_name(su, "su"); // sprint(su);
10  
11      tensor* ex = exp(su);                           // (B, ?)
12      set_name(ex, "ex"); // sprint(ex);
13      tensor* re = batched_reduce_sum(ex);            // (B, 1)
14      set_name(re, "re"); // sprint(denom);
15  
16      // https://github.com/pytorch/pytorch/blob/de484134e4700f95a8a9db5b15daf57d28496a6b/aten/src/ATen/native/vulkan/ops/Softmax.cpp#L196-L203
17      //
18      // note: this is invisible to the generated code, bc uses k_, and not the add op
19      add_k_(re, TensorLikeFill(re, 6e-8), re);
20  
21      tensor* denom = log(re);                        // (B, 1)
22      set_name(denom, "denom"); // sprint(denom);
23      // print(denom);
24      n_repeats = ex->shape[1];
25      tensor* denom_broadcasted = repeat(denom, /*axis = */ 1, /*num_repeats = */ n_repeats);
26      set_name(denom_broadcasted, "denom_broadcasted"); // sprint(denom_broadcasted);
27  
28      tensor* log_sm = sub(su, denom_broadcasted);    // (B, ?)
29      set_name(log_sm, "log_sm"); // sprint(log_sm);
30      return log_sm;
31  }
32  
33  // expects log probabilities (output of LOGsoftmax) as input
34  tensor* NLL(tensor* log_probs, tensor* label){
35      int B = label->shape[0];
36      set_name(label, "label"); // sprint(label);
37      tensor* se = select(log_probs, label);      // (B, 1)
38      set_name(se, "se"); // sprint(se);
39      tensor* lgsum = reduce_sum(se);         // (, )
40      set_name(lgsum, "lgsum");  // sprint(lgsum);
41      tensor* nll = neg(lgsum);               // (, )
42      set_name(nll, "nll"); // sprint(nll);
43      // divide by the batch size
44      tensor* nll_normalized = div(nll, TensorScalarFill(B)); // (, )
45      set_name(nll_normalized, "nll_normalized"); // sprint(nll_normalized);
46      return nll_normalized;
47  }
48  
49  
// comment: log_softmax followed by NLL (which expects log probs) -- is the same as "nn.CrossEntropyLoss"
//  - cross entropy takes in raw logits and internally applies softmax
//  - NLL takes log-softmax'ed values as input (log probabilities)