# cpo_loss.py
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
    """
    Autograd function that fuses the final linear projection with the CPO
    (Contrastive Preference Optimization, https://arxiv.org/abs/2401.08417)
    preference loss. The chunked forward/backward mechanics live in
    LigerFusedLinearPreferenceBase; this class only supplies the CPO-specific
    preference term and the argument plumbing.
    """

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute the CPO preference term: the mean log-sigmoid of the scaled
        log-probability margin between chosen and rejected completions.

        NOTE(review): an earlier docstring labeled this an "odds-ratio loss";
        the odds-ratio formulation belongs to ORPO, whereas this expression is
        the CPO/DPO-style log-sigmoid margin. The value returned here is
        positive log-sigmoid (i.e. a log-likelihood, not its negation) —
        presumably the sign convention is applied by the base class; confirm
        against LigerFusedLinearPreferenceBase.

        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Temperature scaling the log-probability margin.

        Returns:
            torch.Tensor: Scalar preference term, averaged over the batch.
        """
        logits = beta * (chosen_logps - rejected_logps)
        loss = F.logsigmoid(logits).mean()
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=True,
        compiled=True,
    ):
        """
        Fused linear layer with CPO (Contrastive Preference Optimization) loss.
        Handles both the forward and backward pass of the final linear layer with CPO loss.
        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.

        Args:
            ctx: Autograd context used by the base class to stash tensors for backward.
            _input (torch.Tensor): Hidden states fed to the final linear layer.
            weight (torch.Tensor): Weight of the final linear layer.
            target (torch.Tensor): Token targets; positions equal to ignore_index are excluded.
            bias (torch.Tensor, optional): Bias of the final linear layer.
            ignore_index (int): Target value to ignore in the loss.
            beta (float): Temperature for the log-sigmoid margin in preference_loss_fn.
            alpha (float): Weight forwarded to the base class (combines the NLL term with the preference term there).
            compute_nll_loss (bool): Whether the base class also computes an NLL term.
            compiled (bool): Whether the base class uses its compiled code path.
        """
        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Get gradients for _input, weight, target, and bias from the base class.
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Pad with None for the five non-tensor forward args
        # (ignore_index, beta, alpha, compute_nll_loss, compiled).
        return *grads, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature for the log-sigmoid preference margin.
            alpha (float): Weight combining the NLL term with the preference term.
            compute_nll_loss (bool): Whether to also compute an NLL term.
            compiled (bool): Whether to use the compiled code path in the fused kernel.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        # Thin wrapper dispatching to the fused autograd function; note the
        # (weight, input) -> (input, weight) argument reordering.
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )