"""Functional (stateless) entry points for Liger Triton kernels.

Every helper here is a thin forwarding wrapper: it takes plain arguments,
passes them positionally to the matching ``torch.autograd.Function`` via
``.apply(...)``, and returns the result unchanged.  Where an equivalent
``torch.nn.functional`` API exists, the wrapper mirrors that signature so
it can be used as a drop-in replacement.
"""

from typing import Optional

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
from liger_kernel.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.group_norm import LigerGroupNormFunction
from liger_kernel.ops.jsd import LigerJSDFunction
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.ops.rope import LigerRopeFunction
from liger_kernel.ops.swiglu import LigerSiLUMulFunction


def liger_cross_entropy(
    input,
    target,
    weight=None,
    size_average=None,
    ignore_index: int = -100,
    reduce=None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
    lse_square_scale: float = 0.0,
    softcap: Optional[float] = None,
    return_z_loss: bool = False,
):
    """Cross-entropy loss backed by the Liger kernel.

    The signature conforms to
    https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html;
    ``weight``, ``size_average``, and ``reduce`` are accepted for signature
    compatibility only and are NOT forwarded to the kernel (not implemented
    yet).

    Returns the loss tensor, or a ``(loss, z_loss)`` pair when
    ``return_z_loss`` is true.
    """
    loss, z_loss = LigerCrossEntropyFunction.apply(
        input,
        target,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
        return_z_loss,
    )
    # The kernel always yields both values; hand back only what was asked for.
    return (loss, z_loss) if return_z_loss else loss


def liger_fused_linear_cross_entropy(
    input,
    weight,
    target,
    bias=None,
    ignore_index: int = -100,
    lse_square_scale: float = 0.0,
    label_smoothing: float = 0.0,
    reduction: str = "mean",
    softcap: Optional[float] = None,
):
    """Fused linear-projection + cross-entropy loss.

    Forwards all arguments positionally to
    ``LigerFusedLinearCrossEntropyFunction.apply``.
    """
    return LigerFusedLinearCrossEntropyFunction.apply(
        input,
        weight,
        target,
        bias,
        ignore_index,
        lse_square_scale,
        label_smoothing,
        reduction,
        softcap,
    )


def liger_fused_linear_jsd(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels=None,
    jsd_beta: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
):
    """Fused linear-projection + Jensen-Shannon divergence (student vs. teacher).

    Forwards all arguments positionally to ``LigerFusedLinearJSDFunction.apply``.
    """
    return LigerFusedLinearJSDFunction.apply(
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        shift_labels,
        jsd_beta,
        ignore_index,
        temperature,
    )


def liger_geglu(a, b):
    """GEGLU activation: GELU(a) * b, via the fused Liger kernel."""
    return LigerGELUMulFunction.apply(a, b)


def liger_group_norm(
    X,
    affine_scaling_weight,
    affine_shifting_bias,
    num_channels,
    num_groups,
    eps,
):
    """Group normalization of ``X`` with affine scale/shift parameters."""
    return LigerGroupNormFunction.apply(
        X,
        affine_scaling_weight,
        affine_shifting_bias,
        num_channels,
        num_groups,
        eps,
    )


def liger_jsd(
    input,
    target,
    shift_labels=None,
    beta: float = 0.5,
    ignore_index: int = -100,
):
    """Jensen-Shannon divergence between ``input`` and ``target`` distributions."""
    return LigerJSDFunction.apply(
        input,
        target,
        shift_labels,
        beta,
        ignore_index,
    )


def liger_kl_div(
    input,
    target,
    size_average: bool = True,
    reduce: bool = True,
    reduction: str = "mean",
    log_target: bool = False,
    eps: float = 1e-10,
):
    """KL-divergence loss backed by the Liger kernel.

    The signature conforms to
    https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html#torch.nn.functional.kl_div;
    ``size_average`` and ``reduce`` are being deprecated in the torch API and
    are placeholders here — they are NOT forwarded to the kernel.

    Note: torch's default reduction is ``mean``, whereas Liger's kernel-side
    default is ``batchmean``.
    """
    return LigerKLDivLossFunction.apply(
        input,
        target,
        reduction,
        log_target,
        eps,
    )


def liger_layer_norm(X, W, B, eps):
    """Layer normalization of ``X`` with weight ``W`` and bias ``B``."""
    return LigerLayerNormFunction.apply(X, W, B, eps)


def liger_qwen2vl_mrope(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Qwen2-VL multimodal rotary position embedding applied to ``q`` and ``k``."""
    return LigerQwen2VLMRopeFunction.apply(q, k, cos, sin, mrope_section, unsqueeze_dim)


def liger_rms_norm(
    X, W, eps, offset: float = 0.0, casting_mode: str = "llama", in_place: bool = True
):
    """RMS normalization of ``X`` with weight ``W``.

    ``offset``, ``casting_mode``, and ``in_place`` are forwarded to the kernel
    unchanged.
    """
    return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotary position embedding applied to ``q`` and ``k``."""
    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)


def liger_swiglu(a, b):
    """SwiGLU activation: SiLU(a) * b, via the fused Liger kernel."""
    return LigerSiLUMulFunction.apply(a, b)