# src/liger_kernel/transformers/fused_linear_jsd.py
from typing import Optional

import torch

from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction


class LigerFusedLinearJSD(torch.nn.Module):
    r"""Fuses the last linear layer with the generalized JSD loss.

    Handles the forward and backward pass of the final linear layer via JSD while
    avoiding materialization of the large logits tensor.

    Args:
        jsd_beta (float): coefficient beta of the generalized JSD, in the interval [0, 1]. It implements forward KL when beta equals 0 and reverse KL when beta equals 1 (see the formula and the unfused reference sketch below). Default: `0.5`
        ignore_index (int): The index to ignore in the target. Default: `-100`
        temperature (float): temperature in the softmax function to control the output probability distribution. Default: `1.0`
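
    For :math:`0 < \beta < 1`, a standard weighted-mixture form of the generalized JSD is
    shown below for orientation (an illustrative convention, not necessarily the kernel's
    exact implementation); :math:`P` denotes the teacher and :math:`Q` the student
    distribution, both obtained from the temperature-scaled logits:

    .. math::
        \mathrm{JSD}_{\beta}(P \| Q) = \beta \, \mathrm{KL}(P \| M) + (1 - \beta) \, \mathrm{KL}(Q \| M),
        \qquad M = \beta P + (1 - \beta) Q

    The endpoints :math:`\beta = 0` and :math:`\beta = 1` correspond to the pure forward and
    reverse KL divergences mentioned above.
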
    Shape:
        - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
        - student_weight: :math:`(V, H)`, where V is vocab size.
        - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
        - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
        - shift_labels: :math:`(BT,)`
        - Output: a scalar.

    Examples:
    ```python
    >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
    >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
    >>> # generate inputs and weights
    >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
    >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
    >>> # teacher input doesn't require grad, hidden_dim can be different from student's
    >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
    >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
    >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
    >>> output.backward()
    >>>
    >>> # Example with labels for supervised fine-tuning (SFT) context:
    >>>
    >>> # Assume hidden_states, lm_heads and corresponding labels are given
    >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False, device="cuda")
    >>> student_hidden_states = torch.randn(B, T, H_s, device="cuda", requires_grad=True)
    >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False, device="cuda")
    >>> teacher_hidden_states = torch.randn(B, T, H_t, device="cuda")
    >>> labels = torch.randint(0, V, (B, T), device="cuda", dtype=torch.long)
    >>>
    >>> # Shift so that tokens < n predict n
    >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
    >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
    >>> shift_labels = labels[..., 1:].contiguous()
    >>>
    >>> # Flatten tokens
    >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s)
    >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t)
    >>> shift_labels = shift_labels.view(-1)
    >>>
    >>> # Calculate loss
    >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
    >>> loss = loss_fct(
    >>>     shift_student_hidden_states,
    >>>     student_lm_head.weight,
    >>>     shift_teacher_hidden_states,
    >>>     teacher_lm_head.weight,
    >>>     shift_labels
    >>> )
    ```
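
    For orientation, here is a naive, unfused PyTorch sketch of the quantity this module
    computes. The helper name, the masking, and the final mean reduction are illustrative
    assumptions rather than the kernel's exact implementation; the fused kernel computes
    this kind of loss without materializing the full logits tensor.

    ```python
    import torch

    def reference_fused_linear_jsd(
        student_input, student_weight, teacher_input, teacher_weight,
        shift_labels=None, beta=0.5, ignore_index=-100, temperature=1.0,
    ):
        # Materialize the full (BT, V) logits -- this is the memory cost the fused kernel avoids.
        log_q = torch.log_softmax(student_input @ student_weight.t() / temperature, dim=-1)  # student
        log_p = torch.log_softmax(teacher_input @ teacher_weight.t() / temperature, dim=-1)  # teacher
        p, q = log_p.exp(), log_q.exp()
        if beta == 0.0:    # forward KL: KL(P || Q)
            per_token = (p * (log_p - log_q)).sum(dim=-1)
        elif beta == 1.0:  # reverse KL: KL(Q || P)
            per_token = (q * (log_q - log_p)).sum(dim=-1)
        else:              # generalized JSD against the mixture M = beta * P + (1 - beta) * Q
            log_m = (beta * p + (1 - beta) * q).log()
            per_token = (beta * p * (log_p - log_m) + (1 - beta) * q * (log_q - log_m)).sum(dim=-1)
        if shift_labels is not None:
            mask = shift_labels != ignore_index
            return (per_token * mask).sum() / mask.sum().clamp(min=1)
        return per_token.mean()
    ```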
    """

    def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
        super().__init__()
        assert temperature != 0, "temperature cannot be 0."
        self.jsd_beta = jsd_beta
        self.temperature = temperature
        self.ignore_index = ignore_index

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.LongTensor] = None,
    ):
        return LigerFusedLinearJSDFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            shift_labels,
            self.jsd_beta,
            self.ignore_index,
            self.temperature,
        )