/ src / liger_kernel / chunked_loss / dpo_loss.py
dpo_loss.py
  1  import torch
  2  import torch.nn.functional as F
  3  
  4  from liger_kernel.chunked_loss.fused_linear_preference import (
  5      LigerFusedLinearPreferenceBase,
  6  )
  7  
  8  
  9  class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 10  
 11      @staticmethod
 12      def preference_loss_fn(
 13          chosen_logps,
 14          rejected_logps,
 15          ref_chosen_logps=None,
 16          ref_rejected_logps=None,
 17          beta=0.1,
 18      ):
 19          """
 20          Compute DPO loss (Direct Preference Optimization).
 21          Args:
 22              chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
 23              rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
 24              ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,).
 25              ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,).
 26              beta (float): Weight for the direct preference loss.
 27          """
 28          if ref_chosen_logps is None:
 29              ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
 30          if ref_rejected_logps is None:
 31              ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
 32  
 33          chosen_logratios = chosen_logps - ref_chosen_logps
 34          rejected_logratios = rejected_logps - ref_rejected_logps
 35  
 36          logits_diff = beta * (chosen_logratios - rejected_logratios)
 37          losses = -F.logsigmoid(logits_diff)
 38          return losses.sum()
 39  
 40      @staticmethod
 41      def forward(
 42          ctx,
 43          _input,
 44          weight,
 45          target,
 46          bias=None,
 47          ref_weight=None,
 48          ref_bias=None,
 49          ignore_index=-100,
 50          beta=0.1,
 51          compute_nll_loss=True,
 52          compiled=True,
 53          use_ref_model=True,
 54      ):
 55          """
 56          Fused linear layer with DPO (Direct Preference Optimization) loss.
 57          Handles both the forward and backward pass of the final linear layer with DPO loss.
 58          """
 59          return LigerFusedLinearPreferenceBase.forward(
 60              ctx=ctx,
 61              _input=_input,
 62              weight=weight,
 63              target=target,
 64              bias=bias,
 65              loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
 66              ignore_index=ignore_index,
 67              beta=beta,
 68              compute_nll_loss=compute_nll_loss,
 69              compiled=compiled,
 70              use_ref_model=use_ref_model,
 71              ref_weight=ref_weight,
 72              ref_bias=ref_bias,
 73          )
 74  
 75      @staticmethod
 76      def backward(ctx, grad_output):
 77          # Get gradients for _input, weight, bias, and target from the base class
 78          grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
 79          # Return these gradients, followed by None for the remaining inputs
 80          return *grads, None, None, None, None, None, None, None
 81  
 82  
 83  class LigerFusedLinearDPOLoss(torch.nn.Module):
 84      """
 85      Fused linear layer with DPO loss.
 86      """
 87  
 88      def __init__(
 89          self,
 90          ignore_index: int = -100,
 91          beta: float = 0.1,
 92          compute_nll_loss: bool = True,
 93          compiled: bool = True,
 94          use_ref_model: bool = False,
 95      ):
 96          """
 97          Args:
 98              ignore_index (int): Index to ignore in the loss.
 99              beta (float): Weight for the odds ratio loss.
100              compute_nll_loss (bool): Whether to compute the NLL loss.
101              compiled (bool): Whether to use the torch compiled kernel.
102              use_ref_model (bool): Whether to use a reference model for the DPO loss.
103          """
104          super().__init__()
105          self.ignore_index = ignore_index
106          self.beta = beta
107          self.compute_nll_loss = compute_nll_loss
108          self.compiled = compiled
109          self.use_ref_model = use_ref_model
110  
111      def forward(
112          self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None
113      ):
114          return LigerFusedLinearDPOFunction.apply(
115              _input,
116              lin_weight,
117              target,
118              bias,
119              ref_weight,
120              ref_bias,
121              self.ignore_index,
122              self.beta,
123              self.compute_nll_loss,
124              self.compiled,
125              self.use_ref_model,
126          )