mistral.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import (
    _CONFIG_FOR_DOC,
    MISTRAL_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy of Mistral's forward that replaces the torch cross entropy loss with the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MistralForCausalLM

    >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # During training, skip materializing the logits and compute the loss
        # directly from the hidden states via the fused linear + cross entropy kernel.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


# Note: gradient accumulation is not fixed for Mistral as of transformers 4.46.1
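
# --- Usage sketch (not part of the original file) ----------------------------
# A minimal, hedged illustration of how `lce_forward` could be wired in:
# because it takes `self` as its first argument, it can be monkey-patched onto
# `MistralForCausalLM`, and the Liger fused-loss path is then taken whenever
# the model is in training mode and `labels` are passed. The model id reuses
# the docstring example above; everything else (the prompt, the `batch`
# variable) is an illustrative placeholder, not part of this module's API.
if __name__ == "__main__":
    from transformers import AutoTokenizer, MistralForCausalLM

    # Patch the forward on the class so instances pick it up as a bound method.
    MistralForCausalLM.forward = lce_forward

    model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

    model.train()  # the fused path only runs when self.training is True
    batch = tokenizer("Hey, are you conscious?", return_tensors="pt")

    # With labels provided in training mode, the loss comes from
    # LigerFusedLinearCrossEntropyLoss and logits are not materialized.
    outputs = model(**batch, labels=batch["input_ids"])
    outputs.loss.backward()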