// composite_ops.cpp
#include "nn.h"

tensor* log_softmax(tensor* logits){
    // max-subtraction trick for numerical stability, python: "logits -= np.max(logits, axis=1, keepdims=True)"
    int n_repeats = logits->shape[1];
    tensor* maxes = repeat(batched_reduce_max(logits), /*axis = */ 1, /*num_repeats = */ n_repeats);
    set_name(maxes, "maxes"); // sprint(maxes);
    tensor* su = sub(logits, maxes);
    set_name(su, "su"); // sprint(su);

    tensor* ex = exp(su); // (B, ?)
    set_name(ex, "ex"); // sprint(ex);
    tensor* re = batched_reduce_sum(ex); // (B, 1)
    set_name(re, "re"); // sprint(re);

    // add a small epsilon before the log, as in:
    // https://github.com/pytorch/pytorch/blob/de484134e4700f95a8a9db5b15daf57d28496a6b/aten/src/ATen/native/vulkan/ops/Softmax.cpp#L196-L203
    //
    // note: this is invisible to the generated code, because it uses the in-place add_k_ and not the add op
    add_k_(re, TensorLikeFill(re, 6e-8), re);

    tensor* denom = log(re); // (B, 1)
    set_name(denom, "denom"); // sprint(denom);
    // print(denom);
    n_repeats = ex->shape[1];
    tensor* denom_broadcasted = repeat(denom, /*axis = */ 1, /*num_repeats = */ n_repeats);
    set_name(denom_broadcasted, "denom_broadcasted"); // sprint(denom_broadcasted);

    tensor* log_sm = sub(su, denom_broadcasted); // (B, ?)
    set_name(log_sm, "log_sm"); // sprint(log_sm);
    return log_sm;
}

// expects log probabilities (the output of log_softmax) as input
tensor* NLL(tensor* log_probs, tensor* label){
    int B = label->shape[0];
    set_name(label, "label"); // sprint(label);
    tensor* se = select(log_probs, label); // (B, 1)
    set_name(se, "se"); // sprint(se);
    tensor* lgsum = reduce_sum(se); // (, )
    set_name(lgsum, "lgsum"); // sprint(lgsum);
    tensor* nll = neg(lgsum); // (, )
    set_name(nll, "nll"); // sprint(nll);
    // divide by the batch size
    tensor* nll_normalized = div(nll, TensorScalarFill(B)); // (, )
    set_name(nll_normalized, "nll_normalized"); // sprint(nll_normalized);
    return nll_normalized;
}


// comment: log_softmax followed by NLL (which expects log probs) is the same as "nn.CrossEntropyLoss"
// - cross entropy takes in raw logits and internally applies softmax
// - NLL takes log-softmax'ed values as input (log probabilities)
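
// a minimal sketch of that equivalence, assuming only the two composites defined
// above (cross_entropy is a hypothetical name, not part of the existing API):
// feed raw logits through log_softmax and hand the log probabilities to NLL,
// mirroring what "nn.CrossEntropyLoss" does internally.
tensor* cross_entropy(tensor* logits, tensor* label){
    tensor* log_probs = log_softmax(logits); // (B, ?) log probabilities
    return NLL(log_probs, label);            // (, ) scalar, averaged over the batch
}
// usage (shapes assumed): tensor* loss = cross_entropy(logits /* (B, ?) */, label /* (B, 1) class indices */);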