# src/liger_kernel/ops/jsd.py
from typing import Optional

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # Generalized JSD with mixing weight beta (classic JSD is beta = 0.5):
    #   JSD(beta)(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M),
    #   with M = beta * P + (1 - beta) * Q = beta * e^Y + (1 - beta) * e^X.
    # Expanding the KL terms gives the per-element loss computed below:
    #   loss = beta * P * Y + (1 - beta) * Q * X - M * log(M)
    # Since Q = e^X and dM/dX = (1 - beta) * Q, the gradient w.r.t. X is:
    #   dX = (1 - beta) * Q * (X - log(M))
    pid = tl.program_id(0).to(tl.int64)
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored rows contribute no loss (the loss buffer is pre-zeroed),
            # but dX is allocated with empty_like, so zero its row before exiting.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return
    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        if beta == 0.0:  # forward KL: KL(P || Q)
            Y_prob = tl.exp(Y)
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL: KL(Q || P)
            X_prob = tl.exp(X)
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:  # generalized JSD
            Q = tl.exp(X)
            P = tl.exp(Y)
            M = beta * P + (1 - beta) * Q
            log_M = tl.log(M)

            loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
            dX = (1 - beta) * Q * (X - log_M)

        loss = loss / n_non_ignore
        dX = dX / n_non_ignore
        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)

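
# A pure-PyTorch reference of the per-element math in `_jsd_kernel`. This is a
# sketch for testing and readability only, not part of the kernel path, and the
# helper name `_jsd_reference_elementwise` is ours rather than an existing
# liger_kernel API. `x` mirrors X = log Q (student) and `y` mirrors Y = log P
# (teacher), both log-probabilities.
def _jsd_reference_elementwise(x: torch.Tensor, y: torch.Tensor, beta: float) -> torch.Tensor:
    if beta == 0.0:  # forward KL: P * (log P - log Q)
        return y.exp() * (y - x)
    if beta == 1.0:  # reverse KL: Q * (log Q - log P)
        return x.exp() * (x - y)
    q, p = x.exp(), y.exp()  # Q = e^X, P = e^Y
    m = beta * p + (1 - beta) * q
    return beta * p * y + (1 - beta) * q * x - m * m.log()
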
MAX_FUSED_SIZE = 65536

def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # unreduced, elementwise loss buffer; summed after the kernel runs
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    # one Triton program per row of the (BT, V) inputs
    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(
            shift_labels if has_label else torch.empty(1, device=_input.device)
        ),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX

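
# A quick numerical sanity check for `jsd_forward`, offered as a hedged sketch:
# it assumes a CUDA device (the kernel is Triton-only), and relies on the fact
# that with beta = 0.0 and no labels the loss is forward KL summed and divided
# by BT, which matches F.kl_div with reduction="batchmean" and log_target=True.
# The helper name `_check_jsd_forward_against_kl_div` is illustrative, not API.
def _check_jsd_forward_against_kl_div() -> None:
    import torch.nn.functional as F

    x = torch.randn(4, 8, device="cuda").log_softmax(dim=-1)
    y = torch.randn(4, 8, device="cuda").log_softmax(dim=-1)
    loss, _ = jsd_forward(x, y, None, 0.0, -100, False)
    ref = F.kl_div(x, y, reduction="batchmean", log_target=True)
    assert torch.allclose(loss, ref, atol=1e-5)
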
def jsd_backward(dX, grad_output):
    # When JSD is the last layer, grad_output is 1.0. Skip the multiplication to save time.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return dX
    else:
        return grad_output * dX

class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        Like all the other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions, the output of the student model, in log-space
        and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)` where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predicted values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indices of the next predicted token with shape (BT,), where each value is in [0, V-1] or equals :attr:`ignore_index`.
            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It reduces to forward KL when beta equals 0 and to reverse KL when beta equals 1. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (
                _input.shape[0],
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, dX = jsd_forward(
            _input, target, shift_labels, beta, ignore_index, has_label
        )
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        (dX,) = ctx.saved_tensors
        dX = jsd_backward(dX, grad_output)
        # one gradient per forward input; non-tensor args get None
        return (
            dX,
            None,
            None,
            None,
            None,
        )
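
# Minimal usage sketch, assuming a CUDA device (the kernel is Triton-only);
# shapes and tensors are illustrative. Both arguments must already be
# log-probabilities, e.g. the outputs of log_softmax.
if __name__ == "__main__":
    BT, V = 4, 16
    student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
    teacher_logits = torch.randn(BT, V, device="cuda")
    loss = LigerJSDFunction.apply(
        student_logits.log_softmax(dim=-1),  # X = log Q
        teacher_logits.log_softmax(dim=-1),  # Y = log P
        None,  # shift_labels
        0.5,  # beta
        -100,  # ignore_index
    )
    loss.backward()
    print(loss.item(), student_logits.grad.shape)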