mllama.py
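"""Liger Kernel forward passes for Mllama (the Llama 3.2 Vision text model).

Both functions below are intended as drop-in replacements for
``MllamaForCausalLM.forward``: when training with ``labels``, they swap the
standard materialize-logits-then-``CrossEntropyLoss`` path for
``LigerFusedLinearCrossEntropyLoss``, which fuses the ``lm_head`` projection
with the cross-entropy computation and never allocates the full
``(batch, seq_len, vocab_size)`` logits tensor. ``lce_forward_deprecated``
preserves the older loss handling, while ``lce_forward`` supports the newer
``loss_function`` / ``num_items_in_batch`` hooks.
"""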
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.LongTensor] = None,
    cross_attention_mask: Optional[torch.LongTensor] = None,
    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy of the Mllama forward pass, with the torch cross-entropy loss replaced by
    the Liger fused linear cross-entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            for that token alone saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MllamaForCausalLM

    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")

    >>> prompt = "If I had to write a haiku, it would be:"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(result)
    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
    I love the idea of snowflakes gently falling, each one
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Training path: compute the loss directly from the hidden states with
        # the fused linear + cross-entropy kernel, never materializing logits.
        kept_hidden_states = hidden_states[:, -num_logits_to_keep:, :]

        # Shift so that tokens < n predict n
        shift_hidden_states = kept_hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.LongTensor] = None,
    cross_attention_mask: Optional[torch.LongTensor] = None,
    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            for that token alone saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MllamaForCausalLM

    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")

    >>> prompt = "If I had to write a haiku, it would be:"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(result)
    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
    I love the idea of snowflakes gently falling, each one
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    if self.training and (labels is not None):
        # In training mode, don't materialize logits: do the same thing as
        # ForCausalLMLoss, but with the Liger fused linear cross-entropy kernel.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # When the trainer passes `num_items_in_batch` (gradient-accumulation-aware
        # loss normalization), sum per-token losses and normalize manually.
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]
    else:
        # In inference mode, materialize the logits.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
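

# ---------------------------------------------------------------------------
# Sanity-check sketch (an added illustration, not part of the upstream file).
# The training branches above rely on the fused loss being numerically
# equivalent to projecting through `lm_head` and applying torch's
# CrossEntropyLoss; this standalone check demonstrates that on random data.
# The shapes are arbitrary, and a CUDA device is assumed since the Liger
# kernels are Triton-based.
# ---------------------------------------------------------------------------
def _check_flce_matches_torch_ce(
    num_tokens: int = 32, hidden_size: int = 64, vocab_size: int = 128
) -> None:
    hidden = torch.randn(num_tokens, hidden_size, device="cuda")
    weight = torch.randn(vocab_size, hidden_size, device="cuda")
    labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

    # Same call convention as above: (lm_head weight, flattened hidden states, labels).
    fused = LigerFusedLinearCrossEntropyLoss()(weight, hidden, labels)
    reference = CrossEntropyLoss()(hidden @ weight.T, labels)
    torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-4)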
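

# ---------------------------------------------------------------------------
# Usage sketch (an added illustration, not part of the upstream file). One
# way to wire in a patched forward like `lce_forward` is to rebind the
# class-level `forward` before the model is instantiated, so every instance
# picks up the Liger path; this mirrors the monkey-patching style Liger
# Kernel uses, though its actual public entry point may differ. The
# checkpoint name comes from the docstring example above, and a CUDA device
# with `liger_kernel` installed is assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from transformers.models.mllama import modeling_mllama

    # Patch the class, not an instance, so `self` is bound automatically.
    modeling_mllama.MllamaForCausalLM.forward = lce_forward

    model = modeling_mllama.MllamaForCausalLM.from_pretrained(
        "Llama-3.2-11B-Vision", torch_dtype=torch.bfloat16
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")

    inputs = tokenizer(
        "If I had to write a haiku, it would be:", return_tensors="pt"
    ).to("cuda")

    # With the model in training mode and `labels` provided, the loss comes
    # from LigerFusedLinearCrossEntropyLoss and no full logits tensor is built.
    model.train()
    out = model(input_ids=inputs.input_ids, labels=inputs.input_ids)
    print(out.loss)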