llama.py
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
    LLAMA_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

if TYPE_CHECKING:
    from transformers.cache_utils import Cache


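# NOTE: LigerFusedLinearCrossEntropyLoss fuses the `lm_head` projection with the
# cross entropy loss, so the full (batch * seq_len, vocab_size) logits tensor is
# never materialized during training. Both forwards below call it with the
# lm_head weight, the flattened hidden states, and the flattened shifted labels.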
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the upstream Llama forward, with the torch cross entropy loss replaced by the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens to (batch * seq_len, hidden_size) and (batch * seq_len,)
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # Fused lm_head projection + cross entropy; the full logits tensor is
        # never materialized
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


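# In contrast to `lce_forward_deprecated`, the forward below follows the newer
# Hugging Face interface: it accepts `num_logits_to_keep` and `**loss_kwargs`
# (e.g. `num_items_in_batch` from the Trainer) and delegates the eval-time loss
# to `self.loss_function`. Training still uses the fused linear cross entropy.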
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    if self.config.pretraining_tp > 1:
        raise Exception("Liger Kernel does not support pretraining_tp!!")

    logits = None
    loss = None
    # If in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # When the Trainer supplies num_items_in_batch (gradient accumulation),
        # sum the per-token losses and normalize manually to match ForCausalLMLoss
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # In inference mode, materialize the logits
        # num_logits_to_keep == 0 keeps every position, since "-0:" is a full slice
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
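

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the original module.
# These forwards are written to replace the stock `LlamaForCausalLM.forward`,
# so training computes its loss through the fused linear cross entropy path
# above. The direct rebinding below is an assumption for illustration only;
# the project's own patching entry point (if any) is not shown in this file,
# and the fused kernel is Triton-based, so a CUDA device is assumed.
if __name__ == "__main__":
    from transformers import AutoTokenizer, LlamaForCausalLM

    # Swap in the Liger forward before using the model.
    LlamaForCausalLM.forward = lce_forward

    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    batch = tokenizer("Hey, are you conscious?", return_tensors="pt").to("cuda")

    # With labels provided and the model in train mode, `lce_forward` computes
    # the loss with LigerFusedLinearCrossEntropyLoss and never materializes the
    # full logits tensor (out.logits is None on this path).
    model.train()
    out = model(**batch, labels=batch["input_ids"])
    print(out.loss)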