from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import MLLAMA_INPUTS_DOCSTRING
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

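# NOTE (editor's sketch, not part of the upstream file): `lce_forward` and
# `lce_forward_deprecated` below are drop-in replacements for
# `MllamaForCausalLM.forward`. A typical way to wire one in (the exact entry
# point exposed by liger_kernel is an assumption here) is to overwrite the
# class attribute before building the model, for example:
#
#     from transformers.models.mllama import modeling_mllama
#     from liger_kernel.transformers.model.mllama import lce_forward
#
#     modeling_mllama.MllamaForCausalLM.forward = lce_forward
#
# Both functions preserve the Hugging Face forward signature, so the patched
# model keeps the same call interface.
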
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.LongTensor] = None,
    cross_attention_mask: Optional[torch.LongTensor] = None,
    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the Mllama forward, with the torch cross entropy replaced by the
    Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
    Returns:
    Example:
    ```python
    >>> from transformers import AutoTokenizer, MllamaForCausalLM
    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
    >>> prompt = "If I had to write a haiku, it would be:"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(result)
    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
    I love the idea of snowflakes gently falling, each one
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

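    # Training path: the full [batch, seq_len, vocab_size] logits tensor is never
    # materialized. LigerFusedLinearCrossEntropyLoss fuses the lm_head projection
    # with the cross entropy computation directly on the hidden states.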
    if self.training and (labels is not None):
        kept_hidden_states = hidden_states[:, -num_logits_to_keep:, :]

        shift_hidden_states = kept_hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

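# NOTE (editor's comment): unlike the deprecated variant above, `lce_forward`
# uses `self.loss_function` and the `num_items_in_batch` entry that recent
# transformers versions pass through `**loss_kwargs`; the minimum supported
# transformers version is an assumption and is not stated in this file.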
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cross_attention_states: Optional[torch.LongTensor] = None,
    cross_attention_mask: Optional[torch.LongTensor] = None,
    full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MllamaForCausalLM

    >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
    >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")

    >>> prompt = "If I had to write a haiku, it would be:"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
    >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    >>> print(result)
    If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
    I love the idea of snowflakes gently falling, each one
    ```
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

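        # When the trainer supplies `num_items_in_batch` (gradient-accumulation-aware
        # loss normalization in recent transformers), take an unreduced sum and divide
        # by the true token count, mirroring Hugging Face's ForCausalLMLoss; otherwise
        # fall back to a plain mean over the shifted tokens.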
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # if in inference mode materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
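

# Usage sketch (editor's illustration; the checkpoint name and the instance-level
# binding are assumptions, not part of this module):
#
#     from transformers import AutoTokenizer, MllamaForCausalLM
#
#     model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
#     model.forward = lce_forward.__get__(model)  # bind the Liger forward to this instance
#     model.train()
#
#     tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
#     batch = tokenizer("If I had to write a haiku, it would be:", return_tensors="pt")
#     out = model(**batch, labels=batch["input_ids"])
#     out.loss.backward()  # loss computed via FLCE; out.logits is None on this path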