src/liger_kernel/transformers/model/mistral.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import (
    _CONFIG_FOR_DOC,
    MISTRAL_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of Mistral's forward, but with the torch cross entropy replaced by the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MistralForCausalLM

    >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
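        # Training path: logits are never materialized. The fused linear + cross
        # entropy loss computes the loss directly from the hidden states and the
        # lm_head weight, avoiding the full (batch * seq_len, vocab_size) logits
        # allocation.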
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

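        # Positions with shift_labels == -100 are ignored, matching the default
        # ignore_index of torch.nn.CrossEntropyLoss.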
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
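        # Fallback path: materialize logits with lm_head and, if labels are given,
        # compute the standard CrossEntropyLoss as in the upstream Hugging Face code.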
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


# Note: the gradient accumulation fix is not applied to Mistral as of transformers 4.46.1.
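
# Minimal usage sketch (assumption, not part of this module): Liger Kernel applies
# this forward by monkey patching the Hugging Face model class, e.g. through
# liger_kernel.transformers.apply_liger_kernel_to_mistral(). Roughly:
#
#     from transformers.models.mistral import modeling_mistral
#     from liger_kernel.transformers.model.mistral import lce_forward
#
#     modeling_mistral.MistralForCausalLM.forward = lce_forward
#     model = modeling_mistral.MistralForCausalLM.from_pretrained(
#         "mistralai/Mistral-7B-v0.1"
#     )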