fused_linear_jsd.py
from typing import Optional

import torch

from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction


class LigerFusedLinearJSD(torch.nn.Module):
    r"""Fusing the last linear layer with generalized JSD

    Handles the forward and backward pass of the final linear layer via JSD by avoiding
    the materialization of the large logits tensor.

    Args:
        jsd_beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
        ignore_index (int): The index to ignore in the target. Default: `-100`
        temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`

    Shape:
        - student_input: :math:`(BT, H)`, where B is batch size, T is sequence length, H is hidden dimension.
        - student_weight: :math:`(V, H)`, where V is vocab size.
        - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model.
        - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different.
        - shift_labels: :math:`(BT,)`
        - Output: a scalar.

    Examples:
    ```python
    >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10)
    >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0)
    >>> # generate inputs and weights
    >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True)
    >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda")
    >>> # teacher input doesn't require grad, hidden_dim can be different from student's
    >>> teacher_input = torch.rand(B * T, H_t, device="cuda")
    >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda")
    >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight)
    >>> output.backward()
    >>>
    >>> # Example with labels for supervised fine-tuning (SFT) context:
    >>>
    >>> # Assume hidden_states, lm_heads and corresponding labels are given
    >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False)
    >>> student_hidden_states = torch.randn(B, T, H_s, requires_grad=True)
    >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False)
    >>> teacher_hidden_states = torch.randn(B, T, H_t)
    >>> labels = torch.randint(0, V, (B, T), dtype=torch.long)
    >>>
    >>> # Shift so that tokens < n predict n
    >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous()
    >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous()
    >>> shift_labels = labels[..., 1:].contiguous()
    >>>
    >>> # Flatten tokens
    >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s)
    >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t)
    >>> shift_labels = shift_labels.view(-1)
    >>>
    >>> # Calculate loss
    >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1)
    >>> loss = loss_fct(
    ...     shift_student_hidden_states,
    ...     student_lm_head.weight,
    ...     shift_teacher_hidden_states,
    ...     teacher_lm_head.weight,
    ...     shift_labels,
    ... )
    ```
    """

    def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0):
        super().__init__()
        # logits are scaled by 1 / temperature before the softmax, so a zero
        # temperature would divide by zero
        assert temperature != 0, "temperature cannot be 0."
        self.jsd_beta = jsd_beta
        self.temperature = temperature
        self.ignore_index = ignore_index

    def forward(
        self,
        student_input: torch.Tensor,
        student_weight: torch.Tensor,
        teacher_input: torch.Tensor,
        teacher_weight: torch.Tensor,
        shift_labels: Optional[torch.LongTensor],
    ):
        return LigerFusedLinearJSDFunction.apply(
            student_input,
            student_weight,
            teacher_input,
            teacher_weight,
            shift_labels,
            self.jsd_beta,
            self.ignore_index,
            self.temperature,
        )
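

# ---------------------------------------------------------------------------
# For intuition only -- a minimal, unfused sketch of the loss this module
# computes, written in plain PyTorch. It materializes the two (BT, V) logits
# tensors that the fused kernel exists to avoid, so it is a readability and
# testing aid, not a replacement. Assumptions not taken from this file: the
# helper name `naive_linear_jsd`, the mean-over-non-ignored-tokens reduction,
# and the reading of "forward KL" as KL(teacher || student). For the
# authoritative math, see liger_kernel.ops.fused_linear_jsd.
# ---------------------------------------------------------------------------
def naive_linear_jsd(
    student_input: torch.Tensor,
    student_weight: torch.Tensor,
    teacher_input: torch.Tensor,
    teacher_weight: torch.Tensor,
    shift_labels: Optional[torch.LongTensor] = None,
    jsd_beta: float = 0.5,
    ignore_index: int = -100,
    temperature: float = 1.0,
) -> torch.Tensor:
    # Materialize the full (BT, V) logits -- exactly what the fused op avoids.
    student_logits = student_input @ student_weight.t() / temperature
    teacher_logits = teacher_input @ teacher_weight.t() / temperature
    log_q = torch.log_softmax(student_logits, dim=-1)  # student: log Q
    log_p = torch.log_softmax(teacher_logits, dim=-1)  # teacher: log P
    p, q = log_p.exp(), log_q.exp()

    if jsd_beta == 0.0:
        # forward KL(P || Q): teacher as target
        pointwise = p * (log_p - log_q)
    elif jsd_beta == 1.0:
        # reverse KL(Q || P): student as target
        pointwise = q * (log_q - log_p)
    else:
        # generalized JSD against the mixture M = beta * P + (1 - beta) * Q
        log_m = torch.log(jsd_beta * p + (1.0 - jsd_beta) * q)
        pointwise = jsd_beta * p * (log_p - log_m) + (1.0 - jsd_beta) * q * (log_q - log_m)

    loss_per_token = pointwise.sum(dim=-1)  # (BT,)
    if shift_labels is not None:
        mask = shift_labels != ignore_index
        # assumed reduction: average over tokens whose label is not ignore_index
        return loss_per_token[mask].sum() / mask.sum().clamp(min=1)
    return loss_per_token.mean()


# If the assumed reduction matches the kernel's, this sketch should agree with
# LigerFusedLinearJSD on the docstring example (same inputs, same jsd_beta and
# temperature), which is one way to sanity-check the fused path in a unit test.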