# src/liger_kernel/chunked_loss/cpo_loss.py
  1  import torch
  2  import torch.nn.functional as F
  3  
  4  from liger_kernel.chunked_loss.fused_linear_preference import (
  5      LigerFusedLinearPreferenceBase,
  6  )
  7  
  8  
  9  class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 10  
 11      @staticmethod
 12      def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
 13          """
 14          Compute odds-ratio loss.
 15          Args:
 16              chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
 17              rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
 18              beta (float): Weight for the odds ratio loss.
 19          """
 20          logits = beta * (chosen_logps - rejected_logps)
 21          loss = F.logsigmoid(logits).mean()
 22          return loss
 23  
 24      @staticmethod
 25      def forward(
 26          ctx,
 27          _input,
 28          weight,
 29          target,
 30          bias=None,
 31          ignore_index=-100,
 32          beta=0.1,
 33          alpha=1.0,
 34          compute_nll_loss=True,
 35          compiled=True,
 36      ):
 37          """
 38          Fused linear layer with CPO (Odds-Ratio Preference Optimization) loss.
 39          Handles both the forward and backward pass of the final linear layer with CPO loss.
 40          Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.
 41          """
 42  
 43          return LigerFusedLinearPreferenceBase.forward(
 44              ctx,
 45              _input,
 46              weight,
 47              target,
 48              bias,
 49              loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
 50              ignore_index=ignore_index,
 51              alpha=alpha,
 52              beta=beta,
 53              compute_nll_loss=compute_nll_loss,
 54              compiled=compiled,
 55          )
 56  
 57      @staticmethod
 58      def backward(ctx, grad_output):
 59          # Get gradients for _input, weight, bias, and target from the base class
 60          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
 61          # Return these gradients, followed by None for the remaining inputs
 62          return *grads, None, None, None, None, None
 63  
 64  
 65  class LigerFusedLinearCPOLoss(torch.nn.Module):
 66      """
 67      Fused linear layer with CPO loss.
 68      """
 69  
 70      def __init__(
 71          self,
 72          ignore_index: int = -100,
 73          beta: float = 0.1,
 74          alpha: float = 1.0,
 75          compute_nll_loss: bool = True,
 76          compiled: bool = True,
 77      ):
 78          """
 79          Args:
 80              ignore_index (int): Index to ignore in the loss.
 81              beta (float): Weight for the odds ratio loss.
 82          """
 83          super().__init__()
 84          self.ignore_index = ignore_index
 85          self.beta = beta
 86          self.alpha = alpha
 87          self.compute_nll_loss = compute_nll_loss
 88          self.compiled = compiled
 89  
 90      def forward(self, lin_weight, _input, target, bias=None):
 91          return LigerFusedLinearCPOFunction.apply(
 92              _input,
 93              lin_weight,
 94              target,
 95              bias,
 96              self.ignore_index,
 97              self.beta,
 98              self.alpha,
 99              self.compute_nll_loss,
100              self.compiled,
101          )