mixtral.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.mixtral.modeling_mixtral import (
    _CONFIG_FOR_DOC,
    MIXTRAL_INPUTS_DOCSTRING,
    load_balancing_loss_func,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    r"""
    Copy-paste of Mixtral's forward from transformers v4.44.2, with torch cross entropy replaced by Liger fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MixtralForCausalLM

    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
67 ```""" 68 69 output_attentions = ( 70 output_attentions 71 if output_attentions is not None 72 else self.config.output_attentions 73 ) 74 output_router_logits = ( 75 output_router_logits 76 if output_router_logits is not None 77 else self.config.output_router_logits 78 ) 79 80 output_hidden_states = ( 81 output_hidden_states 82 if output_hidden_states is not None 83 else self.config.output_hidden_states 84 ) 85 return_dict = ( 86 return_dict if return_dict is not None else self.config.use_return_dict 87 ) 88 89 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 90 outputs = self.model( 91 input_ids=input_ids, 92 attention_mask=attention_mask, 93 position_ids=position_ids, 94 past_key_values=past_key_values, 95 inputs_embeds=inputs_embeds, 96 use_cache=use_cache, 97 output_attentions=output_attentions, 98 output_hidden_states=output_hidden_states, 99 output_router_logits=output_router_logits, 100 return_dict=return_dict, 101 cache_position=cache_position, 102 ) 103 104 hidden_states = outputs[0] 105 logits = self.lm_head(hidden_states) 106 107 loss = None 108 if self.training and (labels is not None): 109 shift_hidden_states = hidden_states[..., :-1, :].contiguous() 110 shift_labels = labels[..., 1:].contiguous() 111 # Flatten the tokens 112 shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) 113 shift_labels = shift_labels.view(-1) 114 115 lce = LigerFusedLinearCrossEntropyLoss() 116 loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) 117 elif labels is not None: 118 # Upcast to float if we need to compute the loss to avoid potential precision issues 119 logits = logits.float() 120 # Shift so that tokens < n predict n 121 shift_logits = logits[..., :-1, :].contiguous() 122 shift_labels = labels[..., 1:].contiguous() 123 # Flatten the tokens 124 shift_logits = shift_logits.view(-1, self.config.vocab_size) 125 shift_labels = shift_labels.view(-1) 126 # Enable model parallelism 127 shift_labels = shift_labels.to(shift_logits.device) 128 129 loss_fct = CrossEntropyLoss() 130 loss = loss_fct(logits.weight, shift_labels) 131 132 aux_loss = None 133 if output_router_logits: 134 aux_loss = load_balancing_loss_func( 135 outputs.router_logits if return_dict else outputs[-1], 136 self.num_experts, 137 self.num_experts_per_tok, 138 attention_mask, 139 ) 140 if labels is not None: 141 loss += self.router_aux_loss_coef * aux_loss.to( 142 loss.device 143 ) # make sure to reside in the same device 144 145 if not return_dict: 146 output = (logits,) + outputs[1:] 147 if output_router_logits: 148 output = (aux_loss,) + output 149 return (loss,) + output if loss is not None else output 150 151 return MoeCausalLMOutputWithPast( 152 loss=loss, 153 aux_loss=aux_loss, 154 logits=logits, 155 past_key_values=outputs.past_key_values, 156 hidden_states=outputs.hidden_states, 157 attentions=outputs.attentions, 158 router_logits=outputs.router_logits, 159 ) 160 161 162 @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) 163 @replace_return_docstrings( 164 output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC 165 ) 166 # Ignore copy 167 def lce_forward( 168 self, 169 input_ids: torch.LongTensor = None, 170 attention_mask: Optional[torch.Tensor] = None, 171 position_ids: Optional[torch.LongTensor] = None, 172 past_key_values: Optional[List[torch.FloatTensor]] = None, 173 inputs_embeds: Optional[torch.FloatTensor] = None, 174 labels: Optional[torch.LongTensor] = None, 175 use_cache: Optional[bool] = None, 176 
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token can save memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MixtralForCausalLM

    >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
213 ```""" 214 215 output_attentions = ( 216 output_attentions 217 if output_attentions is not None 218 else self.config.output_attentions 219 ) 220 output_router_logits = ( 221 output_router_logits 222 if output_router_logits is not None 223 else self.config.output_router_logits 224 ) 225 226 output_hidden_states = ( 227 output_hidden_states 228 if output_hidden_states is not None 229 else self.config.output_hidden_states 230 ) 231 return_dict = ( 232 return_dict if return_dict is not None else self.config.use_return_dict 233 ) 234 235 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 236 outputs = self.model( 237 input_ids=input_ids, 238 attention_mask=attention_mask, 239 position_ids=position_ids, 240 past_key_values=past_key_values, 241 inputs_embeds=inputs_embeds, 242 use_cache=use_cache, 243 output_attentions=output_attentions, 244 output_hidden_states=output_hidden_states, 245 output_router_logits=output_router_logits, 246 return_dict=return_dict, 247 cache_position=cache_position, 248 ) 249 250 hidden_states = outputs[0] 251 252 logits = None 253 loss = None 254 # if in training mode, don't materialize logits 255 if self.training and (labels is not None): 256 # We do the same thing as ForCausalLMLoss but using Liger FLCE 257 258 shift_hidden_states = hidden_states[..., :-1, :].contiguous() 259 shift_labels = labels[..., 1:].contiguous() 260 261 # flatten tokens 262 shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) 263 shift_labels = shift_labels.view(-1) 264 265 reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean" 266 lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction) 267 268 loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) 269 if reduction == "sum": 270 loss /= loss_kwargs["num_items_in_batch"] 271 272 else: # if in inference mode materialize logits 273 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 274 if labels is not None: 275 loss = self.loss_function( 276 logits=logits, 277 labels=labels, 278 vocab_size=self.config.vocab_size, 279 **loss_kwargs, 280 ) 281 282 aux_loss = None 283 if output_router_logits: 284 aux_loss = load_balancing_loss_func( 285 outputs.router_logits if return_dict else outputs[-1], 286 self.num_experts, 287 self.num_experts_per_tok, 288 attention_mask, 289 ) 290 if labels is not None: 291 loss += self.router_aux_loss_coef * aux_loss.to( 292 loss.device 293 ) # make sure to reside in the same device 294 295 if not return_dict: 296 output = (logits,) + outputs[1:] 297 if output_router_logits: 298 output = (aux_loss,) + output 299 return (loss,) + output if loss is not None else output 300 301 return MoeCausalLMOutputWithPast( 302 loss=loss, 303 aux_loss=aux_loss, 304 logits=logits, 305 past_key_values=outputs.past_key_values, 306 hidden_states=outputs.hidden_states, 307 attentions=outputs.attentions, 308 router_logits=outputs.router_logits, 309 )