fused_linear_preference.py
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearPreferenceBase(torch.autograd.Function):

    @abstractmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute preference loss.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Weight for the odds ratio loss.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        ignore_index=-100,
        compute_nll_loss=True,
    ):
        # The chunk stacks chosen rows first, then rejected rows.
        len_chosen_chunk = target_chunk.shape[0] // 2
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        if compute_nll_loss:
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                target_chunk[:len_chosen_chunk].view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        loss_mask = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask, target_chunk, 0)

        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

        chosen_logps = average_log_prob[:len_chosen_chunk]
        rejected_logps = average_log_prob[len_chosen_chunk:]
        return chosen_logps, rejected_logps, chosen_nll_loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        loss_fn=None,
        chunk_size=1,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        **loss_kwargs,
    ):
        """
        Fused linear layer forward pass with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Other possible arguments that a loss function might need.
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)

        loss_func_to_call = partial(
            LigerFusedLinearPreferenceBase._compute_loss,
            preference_loss_fn=loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            **loss_kwargs,
        )

        # Compute the per-chunk loss and its gradients during the forward pass
        # via torch.func.grad_and_value, so backward only rescales the cached
        # gradients instead of re-materializing the full logits.
        def accumulate_chunk(input_chunk, target_chunk):
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                    chunk_loss,
                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
                )(input_chunk, weight, target_chunk, bias)
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (
                    chunk_loss,
                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1), has_aux=True
                )(input_chunk, weight, target_chunk)
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        len_chosen = target.shape[0] // 2
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            target_chunk = torch.cat(
                [chosen_target_chunk, rejected_target_chunk], dim=0
            )

            grad_input = accumulate_chunk(input_chunk, target_chunk)

            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])

        # Reassemble input gradients to match the stacked [chosen; rejected] layout
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        # Rescale the cached gradients only when grad_output is not the
        # trivial 1.0 produced by calling .backward() on a scalar loss.
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        # NOTE: concrete subclasses are expected to slice the first four
        # gradients and pad with None to match their own forward signatures.
        return grad_input, grad_weight, None, grad_bias, None, None, None

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        chosen_logps, rejected_logps, chosen_nll_loss = (
            LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk,
                weight,
                target_chunk,
                bias=bias,
                ignore_index=ignore_index,
                compute_nll_loss=compute_nll_loss,
            )
        )
        chosen_nll_loss = (
            chosen_nll_loss
            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
        )

        if use_ref_model:
            with torch.no_grad():
                ref_chosen_logps, ref_rejected_logps, _ = (
                    LigerFusedLinearPreferenceBase.chunk_forward(
                        input_chunk,
                        ref_weight,
                        target_chunk,
                        ref_bias,
                        ignore_index=ignore_index,
                        compute_nll_loss=False,
                    )
                )
            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

        alignment_loss = preference_loss_fn(
            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
        )
        alignment_loss = alignment_loss / (full_target.shape[0] // 2)

        loss = alpha * chosen_nll_loss - alignment_loss
        return loss, (alignment_loss, chosen_logps, rejected_logps)
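

# --- Illustrative usage sketch (an assumption, not part of the upstream file) ---
# A minimal subclass showing how a concrete preference loss plugs into the base
# class. The class name and the sigmoid (CPO-style) loss below are hypothetical
# examples for illustration; the real losses live in their own chunked-loss
# modules.
class ExampleSigmoidPreferenceLoss(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Sum of log-sigmoid over the scaled log-prob margin. The base class
        # *subtracts* the returned value in
        #     loss = alpha * chosen_nll_loss - alignment_loss
        # so returning the log-sigmoid sum yields the usual
        # -logsigmoid(beta * margin) preference loss.
        return F.logsigmoid(beta * (chosen_logps - rejected_logps)).sum()

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
    ):
        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias=bias,
            loss_fn=ExampleSigmoidPreferenceLoss.preference_loss_fn,
            ignore_index=ignore_index,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Take the (grad_input, grad_weight, None, grad_bias) slice from the
        # base backward, then pad with None for this forward's four extra
        # non-tensor args (ignore_index, beta, compute_nll_loss, compiled).
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        return *grads, None, None, None, None


# Example call (shapes assumed): the batch stacks chosen rows first, then
# rejected rows, e.g. hidden: (2 * B, T, H), labels: (2 * B, T).
#     loss = ExampleSigmoidPreferenceLoss.apply(hidden, lm_head_weight, labels)
#     loss.backward()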