# fused_linear_cross_entropy.py
from typing import Optional

import torch

from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)


class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
    """Fused linear-projection + cross-entropy loss module.

    Thin ``nn.Module`` wrapper that validates and stores the loss
    hyper-parameters, then delegates the actual computation to
    ``LigerFusedLinearCrossEntropyFunction.apply``.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        """Validate and record the hyper-parameters.

        Args:
            ignore_index: target value whose positions are excluded from the loss.
            lse_square_scale: scale applied to the squared log-sum-exp term.
            label_smoothing: smoothing factor; must lie in ``[0.0, 1.0]``.
            reduction: one of ``"mean"``, ``"sum"`` or ``"none"``.
            softcap: optional soft-capping value; must be positive when given.
        """
        super().__init__()
        # NOTE(review): these guards are `assert` statements, so they are
        # stripped under `python -O`; kept as-is to preserve behavior.
        assert 0 <= label_smoothing <= 1, (
            f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
        )
        assert reduction in {"mean", "sum", "none"}, (
            f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
        )
        assert softcap is None or softcap > 0, (
            f"softcap must greater than 0.0 or None. Got: {softcap}"
        )
        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap

    def forward(self, lin_weight, _input, target, bias=None):
        """Compute the fused linear + cross-entropy loss.

        Args:
            lin_weight: weight matrix of the final linear projection.
            _input: hidden states fed into the projection.
            target: ground-truth class indices.
            bias: optional bias vector of the projection.

        Returns:
            Whatever ``LigerFusedLinearCrossEntropyFunction.apply`` produces
            for the stored hyper-parameters (the reduced or per-token loss).
        """
        loss = LigerFusedLinearCrossEntropyFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
        )
        return loss