/ src / liger_kernel / transformers / fused_linear_cross_entropy.py
fused_linear_cross_entropy.py
 1  from typing import Optional
 2  
 3  import torch
 4  
 5  from liger_kernel.ops.fused_linear_cross_entropy import (
 6      LigerFusedLinearCrossEntropyFunction,
 7  )
 8  
 9  
class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
    """Module front-end for ``LigerFusedLinearCrossEntropyFunction``.

    The module validates and stores the loss configuration once, at
    construction time. Each ``forward`` call passes that configuration,
    together with the per-call arguments, to
    ``LigerFusedLinearCrossEntropyFunction.apply``.

    Args:
        ignore_index: Stored and forwarded unchanged to the fused function.
            NOTE(review): presumably the target value excluded from the
            loss -- confirm against the fused op.
        lse_square_scale: Stored and forwarded unchanged to the fused
            function. Not validated here.
        label_smoothing: Must satisfy ``0 <= label_smoothing <= 1``.
        reduction: Must be one of ``"mean"``, ``"sum"`` or ``"none"``.
        softcap: Either ``None`` or a value strictly greater than 0.
            NOTE(review): presumably a logit soft-capping value -- confirm
            against the fused op.

    Raises:
        AssertionError: If ``label_smoothing``, ``reduction`` or ``softcap``
            violates the constraints listed above.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        """Validate the configuration and store it on the instance."""
        super().__init__()

        # The checks raise explicitly instead of using `assert` statements so
        # that they are not stripped when Python runs with -O. AssertionError
        # is kept as the exception type for backward compatibility.
        if not ((label_smoothing >= 0) and (label_smoothing <= 1)):
            raise AssertionError(
                f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
            )
        if reduction not in {"mean", "sum", "none"}:
            raise AssertionError(
                f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
            )
        if not (softcap is None or softcap > 0):
            raise AssertionError(
                f"softcap must be greater than 0.0 or None. Got: {softcap}"
            )

        self.ignore_index = ignore_index
        self.lse_square_scale = lse_square_scale
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        self.softcap = softcap

    def forward(self, lin_weight, _input, target, bias=None):
        """Run the fused function with the stored configuration.

        Note the argument order: this method takes ``lin_weight`` before
        ``_input``, while the fused function receives ``_input`` first.

        Args:
            lin_weight: Passed to the fused function as its second argument.
                NOTE(review): presumably the linear-layer weight -- confirm.
            _input: Passed to the fused function as its first argument.
            target: Passed to the fused function as its third argument.
            bias: Optional; passed to the fused function as its fourth
                argument. Defaults to ``None``.

        Returns:
            Whatever ``LigerFusedLinearCrossEntropyFunction.apply`` returns.
        """
        return LigerFusedLinearCrossEntropyFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.lse_square_scale,
            self.label_smoothing,
            self.reduction,
            self.softcap,
        )