src/liger_kernel/transformers/kl_div.py
kl_div.py
 1  import torch.nn as nn
 2  
 3  from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 4  
 5  
 6  class LigerKLDIVLoss(nn.KLDivLoss):
 7      def __init__(self, eps: float = 1e-10, *args, **kwargs):
 8          super(LigerKLDIVLoss, self).__init__(*args, **kwargs)
 9          self.eps = eps
10  
11      def forward(self, y_pred, y_true):
12          return LigerKLDivLossFunction.apply(
13              y_pred, y_true, self.reduction, self.log_target, self.eps
14          )