jsd.py
from typing import Optional

import torch

from liger_kernel.ops.jsd import LigerJSDFunction


class LigerJSD(torch.nn.Module):
    r"""The generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        As with all other losses in PyTorch, this function expects the first argument,
        :attr:`log_q`, to be the predictions, the output of the student model in log-space,
        and the second, :attr:`log_p`, to be the observations, the output of the teacher model in log-space.
        This differs from the standard mathematical notation :math:`JSD(P || Q)`, where
        :math:`P` denotes the teacher model and :math:`Q` denotes the student model.

    Args:
        beta (float): coefficient beta of the generalized JSD in the interval [0, 1]. The loss reduces to the forward KL divergence when beta equals 0 and to the reverse KL divergence when beta equals 1. Default: `0.5`
        ignore_index (int): The index to ignore in the target. Default: `-100`

    Shape:
        - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
        - Target: :math:`(BT, V)`, same shape as the input.
        - shift_labels (Optional): :math:`(BT,)`
        - Output: a scalar.

    Examples:
    ```python
    >>> (B, T, V) = (2, 2, 5)
    >>> jsd = LigerJSD(beta=0.1)
    >>> # input should be a distribution in the log space
    >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
    >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
    >>> output = jsd(input, target)
    >>>
    >>> # Example with labels in a supervised fine-tuning (SFT) context
    >>> # Assume logits and corresponding labels are given
    >>> student_logits = torch.randn(B, T, V, requires_grad=True).log_softmax(dim=-1)
    >>> teacher_logits = torch.randn(B, T, V).log_softmax(dim=-1)
    >>> labels = torch.randint(0, V, (B, T), dtype=torch.long)
    >>> # Shift so that tokens < n predict n
    >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
    >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
    >>> shift_labels = labels[..., 1:].contiguous()
    >>> # Flatten tokens
    >>> shift_student_logits = shift_student_logits.view(-1, V)
    >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
    >>> shift_labels = shift_labels.view(-1)
    >>> # Calculate loss
    >>> loss_fct = LigerJSD(beta=0.1)
    >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
    ```
    """

    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
        super().__init__()
        self.beta = beta
        self.ignore_index = ignore_index

    def forward(
        self,
        log_q: torch.Tensor,
        log_p: torch.Tensor,
        shift_labels: Optional[torch.LongTensor] = None,
    ):
        return LigerJSDFunction.apply(
            log_q, log_p, shift_labels, self.beta, self.ignore_index
        )
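
# ---------------------------------------------------------------------------
# The sketch below is NOT part of liger_kernel: it is a minimal pure-PyTorch
# reference of the generalized JSD defined in the docstring above, useful for
# sanity-checking the fused kernel on small inputs. `naive_jsd` is a
# hypothetical helper; its reduction (mean over non-ignored tokens) is an
# assumption about what the fused kernel computes, not a documented contract.
# ---------------------------------------------------------------------------
import math


def naive_jsd(
    log_q: torch.Tensor,
    log_p: torch.Tensor,
    shift_labels: Optional[torch.LongTensor] = None,
    beta: float = 0.5,
    ignore_index: int = -100,
) -> torch.Tensor:
    # F.kl_div(input, target, log_target=True) computes KL(target || input)
    # with both arguments in log-space.
    kl = torch.nn.functional.kl_div
    if beta == 0.0:
        # Special case per the docstring: forward KL, KLDiv(P || Q).
        per_token = kl(log_q, log_p, reduction="none", log_target=True).sum(-1)
    elif beta == 1.0:
        # Special case per the docstring: reverse KL, KLDiv(Q || P).
        per_token = kl(log_p, log_q, reduction="none", log_target=True).sum(-1)
    else:
        # log M where M = beta * P + (1 - beta) * Q, computed stably in log-space.
        log_m = torch.logsumexp(
            torch.stack([log_p + math.log(beta), log_q + math.log(1.0 - beta)]),
            dim=0,
        )
        per_token = beta * kl(log_m, log_p, reduction="none", log_target=True).sum(-1) + (
            1.0 - beta
        ) * kl(log_m, log_q, reduction="none", log_target=True).sum(-1)
    if shift_labels is not None:
        mask = shift_labels != ignore_index
        # Assumed reduction: average over tokens whose label is not ignore_index.
        return (per_token * mask).sum() / mask.sum().clamp(min=1)
    # Assumed reduction: average over all BT rows when no labels are given.
    return per_token.sum() / per_token.numel()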
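
# A small smoke test comparing the fused module against the naive reference
# above. This is illustrative only: LigerJSD dispatches to a Triton kernel,
# so the comparison assumes a CUDA device is available.
if __name__ == "__main__":
    if torch.cuda.is_available():
        B, T, V = 2, 4, 8
        log_q = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
        log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)
        jsd = LigerJSD(beta=0.1)
        print("fused:", jsd(log_q, log_p).item())
        print("naive:", naive_jsd(log_q, log_p, beta=0.1).item())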