src/liger_kernel/ops/kl_div.py
from typing import Literal

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous, is_hip


def get_num_warps(BLOCK_SIZE):
    # Scale the warp count with the block size. HIP (AMD) wavefronts are 64
    # threads wide, so 16 "warps" already fill a 1024-thread block there.
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32 if not is_hip() else 16
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8

    return num_warps


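# For example (illustrative values): get_num_warps(1024) == 4,
# get_num_warps(4096) == 8, and get_num_warps(16384) == 16 on CUDA.
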
MAX_FUSED_SIZE = 65536 // 4  # 65536 // 4 or 65536 // 8 empirically works best

REDUCTION_LITERAL = Literal["none", "sum", "mean", "batchmean"]

# Integer codes for the reduction modes, so they can be passed to the kernels
# as tl.constexpr values.
_REDUCTION_MODE_NONE = tl.constexpr(0)
_REDUCTION_MODE_SUM = tl.constexpr(1)
_REDUCTION_MODE_MEAN = tl.constexpr(2)
_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)

_str_to_reduction_mode = {
    "none": _REDUCTION_MODE_NONE.value,
    "sum": _REDUCTION_MODE_SUM.value,
    "mean": _REDUCTION_MODE_MEAN.value,
    "batchmean": _REDUCTION_MODE_BATCHMEAN.value,
}


@triton.jit
def _kldiv_kernel_forward(
    y_ptr,  # [B, S], prediction ptr, the kernel expects the prediction in log-space
    y_stride,  # int, prediction stride
    gt_ptr,  # [B, S], ground truth ptr
    gt_stride,  # int, ground truth stride
    loss_ptr,  # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr
    loss_stride,  # int, output stride
    n_cols,  # int, number of columns in the input tensor
    eps,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
    reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,
):
    pid = tl.program_id(0).to(tl.int64)
    y_ptr += pid * y_stride
    gt_ptr += pid * gt_stride
    loss_ptr += pid * loss_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    loss_sum = 0.0
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols
        y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
        y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)

        # KL(y_true || y) = y_true * (log(y_true) - log(y))
        # We compute KL(y_true || y) with y already in log-space.
        if not log_target:
            loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)
        else:
            loss = tl.exp(y_true) * (y_true - y)

        if reduction == _REDUCTION_MODE_NONE:
            tl.store(loss_ptr + offsets, loss, mask=mask)
        else:
            loss_sum += tl.sum(loss, axis=0)

    if reduction != _REDUCTION_MODE_NONE:
        tl.store(loss_ptr, loss_sum)


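# Illustrative reference, not used by the kernels: what _kldiv_kernel_forward
# computes for one row, written in eager PyTorch. The function name and
# signature below are ours, not part of Liger's API.
def _kldiv_forward_reference(y_pred_row, y_true_row, log_target=False, eps=1e-10):
    # y_pred_row holds log-probabilities; y_true_row holds probabilities
    # (or log-probabilities when log_target=True).
    if not log_target:
        return y_true_row * (torch.clamp(y_true_row, min=eps).log() - y_pred_row)
    # With t = log(p): p * (log(p) - y) == exp(t) * (t - y); no eps clamp needed.
    return torch.exp(y_true_row) * (y_true_row - y_pred_row)

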
@triton.jit
def _kldiv_kernel_backward(
    target_ptr,
    target_stride,
    new_grads_ptr,
    new_grads_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    log_target: tl.constexpr = False,
):
    pid = tl.program_id(0).to(tl.int64)

    target_ptr += pid * target_stride
    new_grads_ptr += pid * new_grads_stride

    base_offsets = tl.arange(0, BLOCK_SIZE)

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + base_offsets
        mask = offsets < n_cols

        target = tl.load(target_ptr + offsets, mask=mask, other=0.0)

        if not log_target:
            res = -target
        else:
            res = -tl.exp(target)

        tl.store(new_grads_ptr + offsets, res, mask=mask)


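# Why _kldiv_kernel_backward never reads y_pred: with loss = p * (log(p) - y)
# for a probability target p (or exp(t) * (t - y) for a log-space target t),
# d(loss)/dy is simply -p (respectively -exp(t)), so the per-element gradient
# depends on the target alone.
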
def kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps):  # [BT, V]
    BT, V = y_pred.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)
    reduction = _str_to_reduction_mode[reduction]

    out_size = (BT, V) if reduction == _REDUCTION_MODE_NONE.value else (BT,)
    output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)

    _kldiv_kernel_forward[grid](
        y_pred,
        y_pred.stride(0),
        y_true,
        y_true.stride(0),
        output_tensor,
        output_tensor.stride(0),
        V,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
        reduction=reduction,
    )

    # Reduce according to the reduction mode, matching PyTorch's behavior. Note
    # that in future PyTorch versions, `mean` will behave like `batchmean`.
    # https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
    # https://github.com/pytorch/pytorch/blob/d7b57c4d63edb42e1deeeba9497fcb5f1f748ff2/torch/nn/functional.py#L3372
    if reduction == _REDUCTION_MODE_BATCHMEAN.value:
        return output_tensor.sum() / BT
    elif reduction == _REDUCTION_MODE_SUM.value:
        return output_tensor.sum(dim=0)
    elif reduction == _REDUCTION_MODE_MEAN.value:
        return output_tensor.sum() / (BT * V)
    else:
        return output_tensor


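# A quick sanity check against PyTorch (an illustrative sketch; `y_pred` and
# `y_true` are our names, and a CUDA device is assumed):
#
#     import torch.nn.functional as F
#     y_pred = torch.randn(2, 8, device="cuda").log_softmax(dim=-1)
#     y_true = torch.randn(2, 8, device="cuda").softmax(dim=-1)
#     ours = kldiv_forward_triton(y_pred, y_true, log_target=False,
#                                 reduction="batchmean", eps=1e-10)
#     ref = F.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)
#     # `ours` and `ref` should agree to floating-point tolerance.
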
def kldiv_backward_triton(target, grad_output, new_grads, log_target):
    BT, V = target.shape

    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    num_warps = get_num_warps(BLOCK_SIZE)

    grid = (BT,)

    # The kernel writes the per-element gradient of the loss with respect to
    # y_pred into new_grads.
    _kldiv_kernel_backward[grid](
        target,
        target.stride(0),
        new_grads,
        new_grads.stride(0),
        V,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        log_target=log_target,
    )

    # If the KL divergence loss is the last layer, grad_output is 1.0. Skip the
    # multiplication in that case.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return new_grads

    return new_grads * grad_output


class LigerKLDivLossFunction(torch.autograd.Function):
    """
    Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:
    ```python
    if not log_target:
        loss = target * (target.log() - input)
    else:
        loss = target.exp() * (target - input)
    ```
    The loss is then reduced according to the `reduction` parameter, as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html

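    Example (illustrative usage; our variable names, tensors assumed on a CUDA device):

        log_q = torch.randn(4, 128, device="cuda").log_softmax(dim=-1)
        p = torch.randn(4, 128, device="cuda").softmax(dim=-1)
        # Arguments are positional: y_pred, y_true, reduction, log_target, eps.
        loss = LigerKLDivLossFunction.apply(log_q, p, "batchmean", False, 1e-10)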
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        y_pred: torch.Tensor,
        y_true: torch.Tensor,
        reduction: REDUCTION_LITERAL = "batchmean",
        log_target: bool = False,
        eps: float = 1e-10,
    ) -> torch.Tensor:
        """A forward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            y_pred (torch.Tensor): A tensor of shape (BT, V) containing the predicted values, expected to be log-probabilities.
            y_true (torch.Tensor): A tensor of shape (BT, V) containing the target values, expected to be either probabilities or log-probabilities, depending on the value of `log_target`.
            reduction (REDUCTION_LITERAL, optional): Reduction to be used. Defaults to "batchmean".
            log_target (bool, optional): If set to True, expects the ground truth to already be log-probabilities. Defaults to False.
            eps (float, optional): A small value used to clamp the target away from zero before taking its log. Defaults to 1e-10.

        Returns:
            torch.Tensor: The computed KL Divergence Loss, with shape (BT, V) if `reduction` is "none", else a scalar.
        """
        ctx.save_for_backward(y_true)
        ctx.reduction = reduction
        ctx.log_target = log_target
        return kldiv_forward_triton(
            y_pred, y_true, log_target=log_target, reduction=reduction, eps=eps
        )

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        """A backward pass for the KL Divergence Loss.

        Args:
            ctx: Torch autograd context
            grad_output (torch.Tensor): The gradient of the loss with respect to the output.

        Returns:
            tuple[torch.Tensor, None, None, None, None]: The gradient of the loss with respect to `y_pred`, and None for the other arguments of the forward method.
        """
        (y_true,) = ctx.saved_tensors

        new_grads = torch.empty_like(y_true)

        derivative = kldiv_backward_triton(
            y_true, grad_output, new_grads, ctx.log_target
        )

        if ctx.reduction == "batchmean":
            derivative = derivative / y_true.shape[0]
        elif ctx.reduction == "sum" or ctx.reduction == "none":
            pass
        elif ctx.reduction == "mean":
            derivative = derivative / (y_true.shape[0] * y_true.shape[1])

        return (
            derivative,
            None,
            None,
            None,
            None,
        )
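

# End-to-end smoke test (an illustrative sketch, kept as a comment so that
# importing this module stays side-effect free; assumes a CUDA device):
#
#     log_q = torch.randn(4, 128, device="cuda", requires_grad=True)
#     p = torch.randn(4, 128, device="cuda").softmax(dim=-1)
#     loss = LigerKLDivLossFunction.apply(
#         log_q.log_softmax(dim=-1), p, "batchmean", False, 1e-10
#     )
#     loss.backward()  # gradients flow to log_q through the Triton backward kernel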