kl_div.py
import torch.nn as nn

from liger_kernel.ops.kl_div import LigerKLDivLossFunction


class LigerKLDIVLoss(nn.KLDivLoss):
    """KL-divergence loss module whose forward pass is delegated to
    ``LigerKLDivLossFunction``.

    The module subclasses ``nn.KLDivLoss`` and relies on the parent
    constructor for the ``reduction`` and ``log_target`` attributes that
    ``forward`` reads. It adds a single extra setting, ``eps``.

    Args:
        eps: Value stored on the instance and handed to
            ``LigerKLDivLossFunction.apply`` as its final argument. How it
            is used numerically is defined by that function, not here.
        *args: Positional arguments passed through unchanged to
            ``nn.KLDivLoss.__init__``. Because ``eps`` is the first
            positional parameter, any positional argument meant for the
            parent must come after it.
        **kwargs: Keyword arguments passed through unchanged to
            ``nn.KLDivLoss.__init__``.
    """

    def __init__(self, eps: float = 1e-10, *args, **kwargs):
        # Let the parent class set itself up first, then attach the extra
        # setting this subclass introduces.
        super().__init__(*args, **kwargs)
        self.eps = eps

    def forward(self, y_pred, y_true):
        """Return the result of ``LigerKLDivLossFunction.apply``.

        Args:
            y_pred: First operand, forwarded as-is.
            y_true: Second operand, forwarded as-is.

        Returns:
            Whatever ``LigerKLDivLossFunction.apply`` returns for the two
            operands combined with this module's ``reduction``,
            ``log_target`` and ``eps`` settings.
        """
        # Argument order is fixed by the callee: operands first, then the
        # module-level configuration.
        call_args = (y_pred, y_true, self.reduction, self.log_target, self.eps)
        return LigerKLDivLossFunction.apply(*call_args)