dpo_loss.py
1 import torch 2 import torch.nn.functional as F 3 4 from liger_kernel.chunked_loss.fused_linear_preference import ( 5 LigerFusedLinearPreferenceBase, 6 ) 7 8 9 class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase): 10 11 @staticmethod 12 def preference_loss_fn( 13 chosen_logps, 14 rejected_logps, 15 ref_chosen_logps=None, 16 ref_rejected_logps=None, 17 beta=0.1, 18 ): 19 """ 20 Compute DPO loss (Direct Preference Optimization). 21 Args: 22 chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). 23 rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). 24 ref_chosen_logps (torch.Tensor, optional): Reference log probabilities of chosen tokens. Shape: (batch_size,). 25 ref_rejected_logps (torch.Tensor, optional): Reference log probabilities of rejected tokens. Shape: (batch_size,). 26 beta (float): Weight for the direct preference loss. 27 """ 28 if ref_chosen_logps is None: 29 ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device) 30 if ref_rejected_logps is None: 31 ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device) 32 33 chosen_logratios = chosen_logps - ref_chosen_logps 34 rejected_logratios = rejected_logps - ref_rejected_logps 35 36 logits_diff = beta * (chosen_logratios - rejected_logratios) 37 losses = -F.logsigmoid(logits_diff) 38 return losses.sum() 39 40 @staticmethod 41 def forward( 42 ctx, 43 _input, 44 weight, 45 target, 46 bias=None, 47 ref_weight=None, 48 ref_bias=None, 49 ignore_index=-100, 50 beta=0.1, 51 compute_nll_loss=True, 52 compiled=True, 53 use_ref_model=True, 54 ): 55 """ 56 Fused linear layer with DPO (Direct Preference Optimization) loss. 57 Handles both the forward and backward pass of the final linear layer with DPO loss. 58 """ 59 return LigerFusedLinearPreferenceBase.forward( 60 ctx=ctx, 61 _input=_input, 62 weight=weight, 63 target=target, 64 bias=bias, 65 loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn, 66 ignore_index=ignore_index, 67 beta=beta, 68 compute_nll_loss=compute_nll_loss, 69 compiled=compiled, 70 use_ref_model=use_ref_model, 71 ref_weight=ref_weight, 72 ref_bias=ref_bias, 73 ) 74 75 @staticmethod 76 def backward(ctx, grad_output): 77 # Get gradients for _input, weight, bias, and target from the base class 78 grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] 79 # Return these gradients, followed by None for the remaining inputs 80 return *grads, None, None, None, None, None, None, None 81 82 83 class LigerFusedLinearDPOLoss(torch.nn.Module): 84 """ 85 Fused linear layer with DPO loss. 86 """ 87 88 def __init__( 89 self, 90 ignore_index: int = -100, 91 beta: float = 0.1, 92 compute_nll_loss: bool = True, 93 compiled: bool = True, 94 use_ref_model: bool = False, 95 ): 96 """ 97 Args: 98 ignore_index (int): Index to ignore in the loss. 99 beta (float): Weight for the odds ratio loss. 100 compute_nll_loss (bool): Whether to compute the NLL loss. 101 compiled (bool): Whether to use the torch compiled kernel. 102 use_ref_model (bool): Whether to use a reference model for the DPO loss. 103 """ 104 super().__init__() 105 self.ignore_index = ignore_index 106 self.beta = beta 107 self.compute_nll_loss = compute_nll_loss 108 self.compiled = compiled 109 self.use_ref_model = use_ref_model 110 111 def forward( 112 self, lin_weight, _input, target, bias=None, ref_weight=None, ref_bias=None 113 ): 114 return LigerFusedLinearDPOFunction.apply( 115 _input, 116 lin_weight, 117 target, 118 bias, 119 ref_weight, 120 ref_bias, 121 self.ignore_index, 122 self.beta, 123 self.compute_nll_loss, 124 self.compiled, 125 self.use_ref_model, 126 )