/ src / liger_kernel / chunked_loss / simpo_loss.py
simpo_loss.py
  1  import torch
  2  import torch.nn.functional as F
  3  
  4  from liger_kernel.chunked_loss.fused_linear_preference import (
  5      LigerFusedLinearPreferenceBase,
  6  )
  7  
  8  
  9  class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 10  
 11      @staticmethod
 12      def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
 13          """
 14          Compute odds-ratio loss.
 15          Args:
 16              chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
 17              rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
 18              beta (float): Weight for the odds ratio loss.
 19              gamma (float): The simpo gamma, margin term.
 20          """
 21          logits = beta * (chosen_logps - rejected_logps) - gamma
 22          loss = F.logsigmoid(logits).mean()
 23          return loss
 24  
 25      @staticmethod
 26      def forward(
 27          ctx,
 28          _input,
 29          weight,
 30          target,
 31          bias=None,
 32          ignore_index=-100,
 33          beta=0.1,
 34          alpha=1.0,
 35          compute_nll_loss=False,
 36          compiled=True,
 37          gamma=0.5,
 38      ):
 39          """
 40          Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
 41          Handles both the forward and backward pass of the final linear layer with SimPO loss.
 42          Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
 43          """
 44  
 45          return LigerFusedLinearPreferenceBase.forward(
 46              ctx,
 47              _input,
 48              weight,
 49              target,
 50              bias,
 51              loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
 52              compute_nll_loss=compute_nll_loss,
 53              ignore_index=ignore_index,
 54              alpha=alpha,
 55              beta=beta,
 56              compiled=compiled,
 57              gamma=gamma,
 58          )
 59  
 60      @staticmethod
 61      def backward(ctx, grad_output):
 62          # Get gradients for _input, weight, bias, and target from the base class
 63          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
 64          # Return these gradients, followed by None for the remaining inputs
 65          return *grads, None, None, None, None, None, None
 66  
 67  
 68  class LigerFusedLinearSimPOLoss(torch.nn.Module):
 69      """
 70      Fused linear layer with SimPO loss.
 71      """
 72  
 73      def __init__(
 74          self,
 75          ignore_index: int = -100,
 76          beta: float = 0.1,
 77          alpha: float = 1.0,
 78          compute_nll_loss: bool = True,
 79          compiled: bool = True,
 80          gamma: float = 0.5,
 81      ):
 82          """
 83          Args:
 84              ignore_index (int): Index to ignore in the loss.
 85              beta (float): Weight for the odds ratio loss.
 86          """
 87          super().__init__()
 88          self.ignore_index = ignore_index
 89          self.beta = beta
 90          self.alpha = alpha
 91          self.compute_nll_loss = compute_nll_loss
 92          self.compiled = compiled
 93          self.gamma = gamma
 94  
 95      def forward(self, lin_weight, _input, target, bias=None):
 96          return LigerFusedLinearSimPOFunction.apply(
 97              _input,
 98              lin_weight,
 99              target,
100              bias,
101              self.ignore_index,
102              self.beta,
103              self.alpha,
104              self.compute_nll_loss,
105              self.compiled,
106              self.gamma,
107          )