kl_div.py
from typing import Literal

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous, is_hip


def get_num_warps(BLOCK_SIZE):
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 65536 // 8 works best

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


@triton.jit
def _kldiv_kernel_forward(
    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
    y_stride,  # int, prediction stride
    gt_ptr,  # [B, S], ground truth ptr
    gt_stride,  # int, ground truth stride
    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
    loss_stride,  # int, output stride
    n_cols,  # int, number of columns in the input tensor
    eps,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    pid = tl.program_id(0).to(tl.int64)
    y_ptr += pid * y_stride
    gt_ptr += pid * gt_stride
    loss_ptr += pid * loss_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)

        # KL(y_true || y) = y_true * (log(y_true) - log(y))
        # We compute KL(y_true || y) with y in the log-space
        if not log_target:
            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
        else:
            loss = tl.exp(y_true) * (y_true - y)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, loss, mask=mask)
        else:
            loss_sum += tl.sum(loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


@triton.jit
def _kldiv_kernel_backward(
    target_ptr,
    target_stride,
    new_grads_ptr,
    new_grads_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
):
    pid = tl.program_id(0).to(tl.int64)

    target_ptr += pid * target_stride
    new_grads_ptr += pid * new_grads_stride

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols

        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)

        # d/d(input) of KL(target || input), with input in log-space:
        #   -target       if target holds probabilities
        #   -exp(target)  if target holds log-probabilities
        if not log_target:
            res = target * -1
        else:
            res = -tl.exp(target)

        tl.store(new_grads_ptr + offsets, res, mask=mask)


def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
    BT, V = y_pred.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)
    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)

    _kldiv_kernel_forward[grid](
        y_pred,
        y_pred.stride(0),
        y_true,
        y_true.stride(0),
        output_tensor,
        output_tensor.stride(0),
        V,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
        reduction=reduction,
    )

    # Reduce according to the reduction mode, matching PyTorch's behavior.
    # In later PyTorch versions, `mean` will be changed to behave like `batchmean`.
    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / BT
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0)
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (BT * V)
    else:
        return output_tensor


def kldiv_backward_triton(target, grad_output, new_grads, log_target):
    BT, V = target.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    # The gradients are stored in-place in the `new_grads` tensor
    _kldiv_kernel_backward[grid](
        target,
        target.stride(0),
        new_grads,
        new_grads.stride(0),
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
    )

    # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the mul then.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return new_grads

    return new_grads * grad_output


class LigerKLDivLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
    ```python
    if log_target:
        loss = target.exp() * (target - input)
    else:
        loss = target * (target.log() - input)
    ```
    The loss is then reduced according to the `reduction` parameter,
    as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        reduction: REDUCTION_LITERAL = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """A forward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
            log_target (bool, optional): If set to True, expects the ground truth to already be log-probabilities. Defaults to False.
            eps (float, optional): A small value used to clamp the target away from zero before taking its log, avoiding log(0). Defaults to 1e-10.

        Returns:
            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
217 """ 218 ctx.save_for_backward(y_true) 219 ctx.reduction = reduction 220 ctx.log_target = log_target 221 return kldiv_forward_triton( 222 y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps 223 ) 224 225 @staticmethod 226 @ensure_contiguous 227 def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: 228 """A backward pass for the KL Divergence Loss. 229 230 Args: 231 ctx: Torch autograd context 232 grad_output (torch.Tensor): The gradient of the loss with respect to the output. 233 234 Returns: 235 tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to the inputs and None for the other arguments of the forward method. 236 """ 237 (y_true,) = ctx.saved_tensors 238 239 new_grads = torch.empty_like(y_true) 240 241 derivative = kldiv_backward_triton( 242 y_true, grad_output, new_grads, ctx.log_target 243 ) 244 245 if ctx.reduction == "batchmean": 246 derivative = derivative / y_true.shape[0] 247 elif ctx.reduction == "sum" or ctx.reduction == "none": 248 pass 249 elif ctx.reduction == "mean": 250 derivative = derivative / (y_true.shape[0] * y_true.shape[1]) 251 252 return ( 253 derivative, 254 None, 255 None, 256 None, 257 None, 258 )