simpo_loss.py
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
        """
        Compute the SimPO (reference-free) preference loss term.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Scaling factor for the log-probability margin.
            gamma (float): The SimPO target reward margin.
        """
        logits = beta * (chosen_logps - rejected_logps) - gamma
        loss = F.logsigmoid(logits).mean()
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=False,
        compiled=True,
        gamma=0.5,
    ):
        """
        Fused linear layer with SimPO (Simple Preference Optimization) loss: https://arxiv.org/pdf/2405.14734
        Handles both the forward and backward pass of the final linear layer with SimPO loss.
        Inspired by LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989), which fuses the final linear layer and CE loss.
        """

        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compiled=compiled,
            gamma=gamma,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Get gradients for _input, weight, target, and bias from the base class
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        # (ignore_index, beta, alpha, compute_nll_loss, compiled, gamma)
        return *grads, None, None, None, None, None, None


class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """
    Fused linear layer with SimPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Scaling factor for the log-probability margin.
            alpha (float): Weight for the auxiliary NLL loss on chosen responses.
            compute_nll_loss (bool): Whether to compute the auxiliary NLL loss.
            compiled (bool): Whether to compile the chunked loss computation.
            gamma (float): The SimPO target reward margin.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearSimPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
        )
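Usage sketch (not part of simpo_loss.py). This is a minimal example under two assumptions: the module is importable as liger_kernel.chunked_loss.simpo_loss, and the preference base class expects chosen and rejected sequences concatenated along the batch dimension with the chosen half first. The tensor sizes are illustrative only.

import torch

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss

# Illustrative sizes (assumptions, not taken from the file above).
num_pairs, seq_len, hidden, vocab = 2, 16, 64, 128

# Chosen and rejected sequences concatenated along the batch dimension
# (assumed convention: first half chosen, second half rejected).
_input = torch.randn(2 * num_pairs, seq_len, hidden, requires_grad=True)
target = torch.randint(0, vocab, (2 * num_pairs, seq_len))
lin_weight = torch.randn(vocab, hidden, requires_grad=True)

loss_fn = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5, compute_nll_loss=False)
loss = loss_fn(lin_weight, _input, target)
loss.backward()  # gradients flow to _input and lin_weight through the fused backward

Since the loss fuses the final projection with the preference loss, the hidden states (_input) and the lm_head weight are passed in directly rather than the logits; this avoids materializing the full (batch, seq_len, vocab) logits tensor.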