# src/liger_kernel/chunked_loss/fused_linear_preference.py
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearPreferenceBase(torch.autograd.Function):

    @abstractmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute preference loss.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Weight for the odds ratio loss.
        """
        raise NotImplementedError("Preference loss function must be implemented.")
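
    # Illustrative sketch (hypothetical, not part of the class): an
    # ORPO-style override could look like
    #
    #     @staticmethod
    #     def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
    #         # Log-odds ratio between chosen and rejected completions.
    #         log_odds = (chosen_logps - rejected_logps) - (
    #             torch.log1p(-torch.exp(chosen_logps))
    #             - torch.log1p(-torch.exp(rejected_logps))
    #         )
    #         return beta * F.logsigmoid(log_odds).sum()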

    @staticmethod
    def chunk_forward(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        ignore_index=-100,
        compute_nll_loss=True,
    ):
        # The chunk stacks chosen rows first, rejected rows second.
        len_chosen_chunk = target_chunk.shape[0] // 2
        logits_chunk = input_chunk @ weight.t()
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        if compute_nll_loss:
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                target_chunk[:len_chosen_chunk].view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        loss_mask = target_chunk != ignore_index
        # Replace ignored targets with 0 so the gather below stays in-bounds.
        label_chunk = torch.where(loss_mask, target_chunk, 0)

        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
            -1
        )
        # Per-sequence mean log-probability over non-ignored tokens.
        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

        chosen_logps = average_log_prob[:len_chosen_chunk]
        rejected_logps = average_log_prob[len_chosen_chunk:]
        return chosen_logps, rejected_logps, chosen_nll_loss
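
    # Shape sketch (hypothetical sizes): for a chunk stacking 2 chosen and
    # 2 rejected sequences with seq_len 16 and vocab_size 32000:
    #   input_chunk:   (4, 16, hidden_size)
    #   logits_chunk:  (4, 16, 32000)
    #   chosen_logps / rejected_logps: (2,) each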

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        loss_fn=None,
        chunk_size=1,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Other possible arguments that a loss function might need.
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)

        # Bind the fixed arguments once so each chunk call only varies in
        # (input_chunk, weight, target_chunk, bias).
        loss_func_to_call = partial(
            LigerFusedLinearPreferenceBase._compute_loss,
            preference_loss_fn=loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            **loss_kwargs,
        )

        def accumulate_chunk(input_chunk, target_chunk):
            if bias is not None:
                # Differentiate the chunk loss w.r.t. input (0), weight (1)
                # and bias (3); has_aux=True also returns the aux outputs.
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                    chunk_loss,
                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1, 3), has_aux=True
                )(input_chunk, weight, target_chunk, bias)
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (
                    chunk_loss,
                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1), has_aux=True
                )(input_chunk, weight, target_chunk)
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        len_chosen = target.shape[0] // 2
        # Each chunk holds CHUNK_SIZE chosen and CHUNK_SIZE rejected rows.
        chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE))
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

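        # Example (hypothetical sizes): a batch of 8 rows stacked as
        # 4 chosen + 4 rejected with CHUNK_SIZE=1 yields chunks=4; each
        # iteration below re-stacks 1 chosen and 1 rejected row into a
        # (2, seq_len, hidden_size) chunk.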
        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            target_chunk = torch.cat(
                [chosen_target_chunk, rejected_target_chunk], dim=0
            )

            grad_input = accumulate_chunk(input_chunk, target_chunk)

            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])

        # Combine grad_chosen_inputs and grad_rejected_inputs to match the
        # original stacked layout (all chosen rows first, then all rejected).
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc
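
    # Usage sketch (hypothetical names): a concrete subclass is invoked
    # through torch.autograd.Function.apply, e.g. with the MyORPOFunction
    # sketched at the end of this module:
    #
    #     loss = MyORPOFunction.apply(
    #         hidden_states,   # (2 * B, T, H): chosen rows stacked before rejected
    #         lm_head_weight,  # (V, H)
    #         labels,          # (2 * B, T)
    #     )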

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients were fully accumulated during forward; backward only
        # rescales them by grad_output (skipped when grad_output == 1, the
        # common case for a scalar loss.backward()).
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias, None, None, None
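
    # Note (assumption about intended use): the None after grad_weight is
    # for target; the trailing Nones are placeholders that a concrete
    # subclass typically slices off before appending its own Nones to match
    # its forward signature (see the sketch at the end of this module).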

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=True,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to compute NLL loss.
            use_ref_model (bool): Whether to use a reference model for the alignment loss.
            ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
            ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        chosen_logps, rejected_logps, chosen_nll_loss = (
            LigerFusedLinearPreferenceBase.chunk_forward(
                input_chunk,
                weight,
                target_chunk,
                bias=bias,
                ignore_index=ignore_index,
                compute_nll_loss=compute_nll_loss,
            )
        )
        # Normalize by the number of non-ignored chosen tokens in the full
        # batch so that the per-chunk losses sum to a global mean.
        chosen_nll_loss = (
            chosen_nll_loss
            / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
        )

        if use_ref_model:
            # Reference model log-probs are constants w.r.t. the trained weights.
            with torch.no_grad():
                ref_chosen_logps, ref_rejected_logps, _ = (
                    LigerFusedLinearPreferenceBase.chunk_forward(
                        input_chunk,
                        ref_weight,
                        target_chunk,
                        ref_bias,
                        ignore_index=ignore_index,
                        compute_nll_loss=False,
                    )
                )
            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps

        alignment_loss = preference_loss_fn(
            chosen_logps, rejected_logps, beta=beta, **loss_kwargs
        )
        # Average over the number of chosen examples in the full batch.
        alignment_loss = alignment_loss / (full_target.shape[0] // 2)

        loss = alpha * chosen_nll_loss - alignment_loss
        return loss, (alignment_loss, chosen_logps, rejected_logps)
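

# --- Illustrative sketch (hypothetical, not part of the original module) ---
# Wiring a concrete subclass, assuming a preference_loss_fn like the
# ORPO-style sketch above; `MyORPOFunction` is an invented name:
#
#     class MyORPOFunction(LigerFusedLinearPreferenceBase):
#         @staticmethod
#         def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
#             ...  # e.g. the odds-ratio loss sketched earlier
#
#         @staticmethod
#         def forward(ctx, _input, weight, target, bias=None, beta=0.1):
#             return LigerFusedLinearPreferenceBase.forward(
#                 ctx,
#                 _input,
#                 weight,
#                 target,
#                 bias,
#                 loss_fn=MyORPOFunction.preference_loss_fn,
#                 beta=beta,
#             )
#
#         @staticmethod
#         def backward(ctx, grad_output):
#             # Four gradients from the base (input, weight, target, bias),
#             # plus None for the extra `beta` argument.
#             grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
#             return *grads, None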