jsd.py
from typing import Optional

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # Generalized JSD with P = e^Y, Q = e^X, and M = beta * P + (1 - beta) * Q:
    #   JSD(beta)(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M)
    #                     = sum(beta * P * Y + (1 - beta) * Q * X - M * log M)
    # For beta = 0.5 this reduces to the standard
    #   JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2 with M = (P + Q) / 2.
    # Gradient w.r.t. X: grad_x_i = (1 - beta) * Q_i * (X_i - log M_i)
    pid = tl.program_id(0).to(tl.int64)
    X_ptr += pid * X_stride
    dX_ptr += pid * dX_stride
    Y_ptr += pid * Y_stride
    loss_ptr += pid * loss_stride
    label_ptr += pid

    if HAS_LABEL:
        label = tl.load(label_ptr)
        if label == ignore_index:
            # Ignored row: zero out its gradient and skip the loss computation.
            for i in range(0, n_cols, BLOCK_SIZE):
                offsets = i + tl.arange(0, BLOCK_SIZE)
                tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols)
            return

    for i in range(0, n_cols, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_cols
        X = tl.load(X_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)
        Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

        if beta == 0.0:  # forward KL: KL(P || Q)
            Y_prob = tl.exp(Y)
            loss = Y_prob * (Y - X)
            dX = -Y_prob
        elif beta == 1.0:  # reverse KL: KL(Q || P)
            X_prob = tl.exp(X)
            loss = X_prob * (X - Y)
            dX = loss + X_prob
        else:
            Q = tl.exp(X)
            P = tl.exp(Y)
            M = beta * P + (1 - beta) * Q
            log_M = tl.log(M)

            loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
            dX = (1 - beta) * Q * (X - log_M)

        # Normalize per element so the final sum yields a mean over non-ignored rows.
        loss = loss / n_non_ignore
        dX = dX / n_non_ignore
        tl.store(loss_ptr + offsets, loss, mask=mask)
        tl.store(dX_ptr + offsets, dX, mask=mask)


MAX_FUSED_SIZE = 65536


def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    BT, V = _input.shape
    n_rows = BT
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
    # unreduced loss, one entry per (row, vocab) element
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
    else:
        n_non_ignore = BT

    _jsd_kernel[(n_rows,)](
        X_ptr=_input,  # input in logspace, X = log Q
        X_stride=_input.stride(-2),
        Y_ptr=target,  # ground truth in logspace, Y = log P
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=(
            shift_labels if has_label else torch.empty(1, device=_input.device)
        ),  # dummy ptr if no label
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_cols=V,
        BLOCK_SIZE=BLOCK_SIZE,
        HAS_LABEL=has_label,
    )

    loss = torch.sum(loss)
    return loss.to(_input.dtype), dX


def jsd_backward(dX, grad_output):
    # If JSD is the last layer, grad_output is 1.0. Skip the mul to save time.
    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return dX
    else:
        return grad_output * dX

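
# Illustrative reference (an assumption, not part of the library): a plain
# PyTorch sketch of what `jsd_forward` computes for 0 < beta < 1, handy for
# sanity-checking the Triton kernel. `_jsd_reference` is a hypothetical helper
# name; it omits the shift_labels / ignore_index masking for brevity.
def _jsd_reference(x: torch.Tensor, y: torch.Tensor, beta: float = 0.5) -> torch.Tensor:
    # x = log Q (student) and y = log P (teacher), both with shape (BT, V)
    p, q = y.exp(), x.exp()
    m = beta * p + (1 - beta) * q
    # element-wise beta * P * Y + (1 - beta) * Q * X - M * log M, summed,
    # then divided by the row count (matches n_non_ignore = BT when no labels)
    loss = beta * p * y + (1 - beta) * q * x - m * m.log()
    return loss.sum() / x.shape[0]
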

class LigerJSDFunction(torch.autograd.Function):
    r"""
    This class implements the forward and backward pass for the generalized
    Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q))
            + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As with all other losses in PyTorch, this function expects the first argument,
        :attr:`_input`, to be the predictions (the output of the student model) in log-space,
        and the second, :attr:`target`, to be the observations (the output of the teacher
        model) in log-space. This differs from the standard mathematical notation
        :math:`JSD(P || Q)`, where :math:`P` denotes the teacher model and :math:`Q`
        denotes the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Args:
            _input (torch.Tensor): predicted values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of the next predicted
                vocab item, with shape (BT,), where each value is in [0, V - 1] or equals
                :attr:`ignore_index`.
            beta (float): coefficient beta of the generalized JSD in the interval [0, 1].
                It implements the forward and reverse KL divergence when beta equals 0
                and 1 respectively. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = False
        if shift_labels is not None:
            assert shift_labels.shape == (
                _input.shape[0],
            ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            shift_labels = shift_labels.contiguous()
            has_label = True

        loss, dX = jsd_forward(
            _input, target, shift_labels, beta, ignore_index, has_label
        )
        ctx.save_for_backward(dX)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (dX,) = ctx.saved_tensors
        dX = jsd_backward(dX, grad_output)
        return (
            dX,
            None,
            None,
            None,
            None,
        )
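

# Minimal usage sketch (illustrative; tensor names and shapes are made up).
# The Triton kernel runs on GPU, hence device="cuda". Gradients flow only to
# the student's log-probabilities; `target` receives None in backward.
if __name__ == "__main__":
    torch.manual_seed(0)
    BT, V = 4, 8  # (batch * sequence length, vocabulary size)
    student_logits = torch.randn(BT, V, device="cuda", requires_grad=True)
    teacher_logits = torch.randn(BT, V, device="cuda")

    log_q = torch.log_softmax(student_logits, dim=-1)  # _input: student, log-space
    log_p = torch.log_softmax(teacher_logits, dim=-1)  # target: teacher, log-space

    loss = LigerJSDFunction.apply(log_q, log_p)  # shift_labels=None, beta=0.5
    loss.backward()
    print(loss.item(), student_logits.grad.shape)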