# src/liger_kernel/transformers/model/mixtral.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
    _CONFIG_FOR_DOC,
    MIXTRAL_INPUTS_DOCSTRING,
    load_balancing_loss_func,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

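# Note (added commentary): in training mode, both forwards below compute the loss by
# passing the shifted last hidden states together with `lm_head.weight` to
# LigerFusedLinearCrossEntropyLoss, which fuses the output projection with the
# cross-entropy computation so the full (batch, seq_len, vocab_size) logits tensor is
# never materialized for the loss; `lce_forward` additionally skips the `lm_head`
# projection entirely while training.
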
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    r"""
    Copy-paste of Mixtral's forward from transformers v4.44.2, with torch cross entropy
    replaced by Liger fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MixtralForCausalLM

    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if self.training and (labels is not None):
        # Pass the last hidden states and the lm_head weight directly to the fused
        # linear + cross-entropy loss instead of materializing shifted logits.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    elif labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits, shift_labels)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits if return_dict else outputs[-1],
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(
                loss.device
            )  # make sure the aux loss is on the same device as the loss

    if not return_dict:
        output = (logits,) + outputs[1:]
        if output_router_logits:
            output = (aux_loss,) + output
        return (loss,) + output if loss is not None else output

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )

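# `lce_forward` below follows the newer transformers signature: it adds
# `num_logits_to_keep` and `**loss_kwargs` (e.g. `num_items_in_batch` supplied by the
# HF Trainer) and, when training, defers to the fused loss without ever running
# `lm_head` over the full sequence.
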
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
# Ignore copy
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MixtralForCausalLM

    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # If in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # When the trainer supplies num_items_in_batch (gradient accumulation),
        # sum the per-token losses and normalize by the global token count;
        # otherwise fall back to a plain mean over this micro-batch.
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # if in inference mode, materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits if return_dict else outputs[-1],
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(
                loss.device
            )  # make sure the aux loss is on the same device as the loss

    if not return_dict:
        output = (logits,) + outputs[1:]
        if output_router_logits:
            output = (aux_loss,) + output
        return (loss,) + output if loss is not None else output

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
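

# Usage sketch (illustrative commentary, not part of the upstream file): these
# forwards are intended to replace `MixtralForCausalLM.forward` by monkey-patching
# before the model is used, roughly along the lines of:
#
#     from transformers.models.mixtral import modeling_mixtral
#     from liger_kernel.transformers.model.mixtral import lce_forward
#
#     modeling_mixtral.MixtralForCausalLM.forward = lce_forward
#
# Whether `lce_forward` or `lce_forward_deprecated` is the right patch depends on
# the installed transformers version (the deprecated variant targets v4.44.x).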