src/liger_kernel/transformers/functional.py
from typing import Optional

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
from liger_kernel.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.group_norm import LigerGroupNormFunction
from liger_kernel.ops.jsd import LigerJSDFunction
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.ops.swiglu import LigerSiLUMulFunction


# Conforms to the signature of torch.nn.functional.cross_entropy:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html
# `weight`, `size_average`, and `reduce` are placeholders and not implemented yet
def liger_cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    ignore_index: int = -100,
    reduce=None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
):
    loss, z_loss = LigerCrossEntropyFunction.apply(
        input,
        target,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
    )
    if not return_z_loss:
        return loss
    return loss, z_loss
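
# Illustrative usage, a sketch rather than part of the module (assumes a CUDA
# device, since the underlying Triton kernels run on GPU; V = 32000 is an
# arbitrary vocab size):
#
#   import torch
#   logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (N, V)
#   labels = torch.randint(0, 32000, (8,), device="cuda")              # (N,)
#   loss = liger_cross_entropy(logits, labels)  # scalar under reduction="mean"
#   loss, z_loss = liger_cross_entropy(
#       logits, labels, lse_square_scale=1e-4, return_z_loss=True
#   )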


def liger_fused_linear_cross_entropy(
    input,
    weight,
    target,
    bias=None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
):
    return LigerFusedLinearCrossEntropyFunction.apply(
        input,
        weight,
        target,
        bias,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
    )
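
# Illustrative usage (a sketch): `input` holds hidden states, not logits; the
# projection with the (V, H) classifier weight is fused into the loss, so the
# full (N, V) logits matrix is never materialized. H and V below are arbitrary.
#
#   import torch
#   H, V = 4096, 32000
#   hidden = torch.randn(8, H, device="cuda", requires_grad=True)  # (N, H)
#   lm_head_weight = torch.randn(V, H, device="cuda")              # (V, H)
#   labels = torch.randint(0, V, (8,), device="cuda")              # (N,)
#   loss = liger_fused_linear_cross_entropy(hidden, lm_head_weight, labels)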


def liger_fused_linear_jsd(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels=None,
    jsd_beta: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
):
    return LigerFusedLinearJSDFunction.apply(
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        shift_labels,
        jsd_beta,
        ignore_index,
        temperature,
    )
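
# Illustrative usage (a sketch): student and teacher hidden sizes may differ,
# but both projection weights must map to the same vocab size V. The sizes
# below are arbitrary.
#
#   import torch
#   V = 32000
#   s_hidden = torch.randn(8, 2048, device="cuda", requires_grad=True)
#   s_weight = torch.randn(V, 2048, device="cuda")
#   t_hidden = torch.randn(8, 4096, device="cuda")
#   t_weight = torch.randn(V, 4096, device="cuda")
#   loss = liger_fused_linear_jsd(s_hidden, s_weight, t_hidden, t_weight)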


def liger_geglu(a, b):
    return LigerGELUMulFunction.apply(a, b)
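
# Illustrative usage (a sketch): computes GELU(a) * b element-wise, so `a` and
# `b` must share a shape, e.g. the gate_proj and up_proj outputs of an MLP:
#
#   import torch
#   gate = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)
#   up = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)
#   out = liger_geglu(gate, up)  # same shape as the inputs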


def liger_group_norm(
    X,
    affine_scaling_weight,
    affine_shifting_bias,
    num_channels,
    num_groups,
    eps,
):
    return LigerGroupNormFunction.apply(
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    )
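
# Illustrative usage (a sketch; the (batch, num_channels, hidden) input layout
# is an assumption, not confirmed by this file):
#
#   import torch
#   C, G = 32, 4
#   x = torch.randn(8, C, 64, device="cuda", requires_grad=True)
#   weight = torch.ones(C, device="cuda")   # per-channel scale
#   bias = torch.zeros(C, device="cuda")    # per-channel shift
#   out = liger_group_norm(x, weight, bias, C, G, 1e-6)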


def liger_jsd(
    input,
    target,
    shift_labels=None,
    beta: float = 0.5,
    ignore_index: int = -100,
):
    return LigerJSDFunction.apply(
        input,
        target,
        shift_labels,
        beta,
        ignore_index,
    )
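
# Illustrative usage (a sketch): `input` (student) and `target` (teacher) are
# both expected in log-space; beta=0.5 gives the symmetric JSD.
#
#   import torch
#   student = torch.randn(8, 32000, device="cuda").log_softmax(-1).requires_grad_()
#   teacher = torch.randn(8, 32000, device="cuda").log_softmax(-1)
#   loss = liger_jsd(student, teacher)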


# Conforms to the signature of torch.nn.functional.kl_div:
# https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div
# `size_average` and `reduce` are being deprecated in the torch API and are placeholders here
def liger_kl_div(
    input,
    target,
    size_average: bool = True,
    reduce: bool = True,
    reduction: str = "mean",
    log_target: bool = False,
    eps: float = 1e-10,
):
    # Note: the default reduction is `mean` in torch, but `batchmean` in Liger
    return LigerKLDivLossFunction.apply(
        input,
        target,
        reduction,
        log_target,
        eps,
    )
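
# Illustrative usage (a sketch): mirrors torch.nn.functional.kl_div, i.e.
# `input` in log-space and `target` as probabilities unless log_target=True.
#
#   import torch
#   y_pred = torch.randn(8, 32000, device="cuda").log_softmax(-1).requires_grad_()
#   y_true = torch.randn(8, 32000, device="cuda").softmax(-1)
#   loss = liger_kl_div(y_pred, y_true, reduction="batchmean")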


def liger_layer_norm(X, W, B, eps):
    return LigerLayerNormFunction.apply(X, W, B, eps)
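
# Illustrative usage (a sketch): normalizes over the last dimension; W and B
# have shape (hidden_size,).
#
#   import torch
#   x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)
#   w = torch.ones(4096, device="cuda")
#   b = torch.zeros(4096, device="cuda")
#   out = liger_layer_norm(x, w, b, 1e-6)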


def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)
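
# Illustrative usage (a sketch; the [16, 24, 24] section split and the
# (3, bsz, seq_len, head_dim) cos/sin layout follow Qwen2-VL's multimodal RoPE,
# where the leading 3 indexes the temporal/height/width components; treat the
# exact shapes as assumptions):
#
#   import torch
#   q = torch.randn(1, 28, 128, 128, device="cuda")  # (bsz, n_q_heads, seq_len, head_dim)
#   k = torch.randn(1, 4, 128, 128, device="cuda")   # (bsz, n_kv_heads, seq_len, head_dim)
#   cos = torch.randn(3, 1, 128, 128, device="cuda")
#   sin = torch.randn(3, 1, 128, 128, device="cuda")
#   q_rot, k_rot = liger_qwen2vl_mrope(q, k, cos, sin, [16, 24, 24])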


def liger_rms_norm(
    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
):
    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)
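
# Illustrative usage (a sketch): W has shape (hidden_size,); `offset` is added
# to the weight before scaling (e.g. Gemma applies 1 + w, i.e. offset=1.0), and
# `casting_mode` selects which model family's dtype handling to reproduce.
#
#   import torch
#   x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)
#   w = torch.ones(4096, device="cuda")
#   out = liger_rms_norm(x, w, 1e-6)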


def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
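
# Illustrative usage (a sketch): q/k are (bsz, n_heads, seq_len, head_dim) and
# cos/sin broadcast over the head dimension via `unsqueeze_dim`; the shapes
# below are arbitrary.
#
#   import torch
#   q = torch.randn(2, 32, 128, 64, device="cuda")
#   k = torch.randn(2, 8, 128, 64, device="cuda")
#   cos = torch.randn(1, 128, 64, device="cuda")
#   sin = torch.randn(1, 128, 64, device="cuda")
#   q_rot, k_rot = liger_rope(q, k, cos, sin)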


def liger_swiglu(a, b):
    return LigerSiLUMulFunction.apply(a, b)
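
# Illustrative usage (a sketch): computes SiLU(a) * b element-wise, the SwiGLU
# gating used in Llama-style MLPs; `a` and `b` must share a shape.
#
#   import torch
#   gate = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)
#   up = torch.randn(2, 128, 11008, device="cuda", requires_grad=True)
#   out = liger_swiglu(gate, up)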