jsd.py
from typing import Optional

import torch

from liger_kernel.ops.jsd import LigerJSDFunction


class LigerJSD(torch.nn.Module):
    r"""The generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    For :math:`\beta = 0.5` this reduces to the classical symmetric Jensen-Shannon
    divergence with mixture :math:`M = (P + Q) / 2`.

    .. note::
        As with all the other losses in PyTorch, this function expects the first argument,
        :attr:`log_q`, to be the predictions, i.e. the output of the student model in log-space,
        and the second, :attr:`log_p`, to be the observations, i.e. the output of the teacher
        model in log-space. This differs from the standard mathematical notation
        :math:`JSD(P || Q)`, where :math:`P` denotes the teacher model and :math:`Q` denotes
        the student model.

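    Assuming the endpoint behavior described for :attr:`beta` below, the two boundary
    values are special-cased rather than taken as the degenerate limits of the formula
    above:

    .. math::
        JSD(0)(P || Q) = KLDiv(P || Q), \qquad JSD(1)(P || Q) = KLDiv(Q || P)
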
    Args:
        beta (float): coefficient :math:`\beta` of the generalized JSD in the interval [0, 1]. It implements the forward KL when beta equals 0 and the reverse KL when beta equals 1. Default: `0.5`
        ignore_index (int): The index to ignore in the target. Default: `-100`

    Shape:
        - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size.
        - Target: :math:`(BT, V)`, same shape as the input.
        - shift_labels (Optional): :math:`(BT,)`
        - Output: a scalar.

    Examples:
    ```python
    >>> (B, T, V) = (2, 2, 5)
    >>> jsd = LigerJSD(beta=0.1)
    >>> # input should be a distribution in the log space
    >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
    >>> target = torch.randn(B * T, V).log_softmax(dim=-1)
    >>> output = jsd(input, target)
    >>>
    >>> # Example with labels for supervised fine-tuning (SFT) context
    >>> # Assume logits and corresponding labels are given
    >>> student_logits = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1)
    >>> teacher_logits = torch.randn(B * T, V).log_softmax(dim=-1)
    >>> labels = torch.randint(0, V, (B * T,), dtype=torch.long)
    >>> # Shift so that tokens < n predict n
    >>> shift_student_logits = student_logits[..., :-1, :].contiguous()
    >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()
    >>> shift_labels = labels[..., 1:].contiguous()
    >>> # Flatten tokens
    >>> shift_student_logits = shift_student_logits.view(-1, V)
    >>> shift_teacher_logits = shift_teacher_logits.view(-1, V)
    >>> shift_labels = shift_labels.view(-1)
    >>> # Calculate loss
    >>> loss_fct = LigerJSD(beta=0.1)
    >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels)
    ```
    """

    def __init__(self, beta: float = 0.5, ignore_index: int = -100):
        super().__init__()
        self.beta = beta
        self.ignore_index = ignore_index

    def forward(
        self,
        log_q: torch.Tensor,
        log_p: torch.Tensor,
        shift_labels: Optional[torch.LongTensor] = None,
    ):
        """Compute the generalized JSD between student (`log_q`) and teacher
        (`log_p`) log-probabilities, optionally ignoring positions where
        `shift_labels` equals `ignore_index`."""
        return LigerJSDFunction.apply(
            log_q, log_p, shift_labels, self.beta, self.ignore_index
        )
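

# A minimal usage sketch (illustrative only, not part of the library API). It assumes
# a CUDA device is available, since the underlying LigerJSDFunction is a Triton kernel.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, V = 2, 4, 8

    jsd = LigerJSD(beta=0.5)
    # Both arguments are distributions in log-space, as in the docstring example.
    log_q = torch.randn(B * T, V, device="cuda", requires_grad=True).log_softmax(dim=-1)
    log_p = torch.randn(B * T, V, device="cuda").log_softmax(dim=-1)

    loss = jsd(log_q, log_p)  # scalar output
    loss.backward()  # gradients flow back to the student inputs
    print(f"JSD(beta=0.5) loss: {loss.item():.6f}")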