gemma.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
    _CONFIG_FOR_DOC,
    GEMMA_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""

    Copy of `transformers.models.gemma.modeling_gemma` `GemmaForCausalLM.forward`, with the
    cross entropy loss replaced by the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
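        # Shift so that tokens < n predict n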
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

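        # LigerFusedLinearCrossEntropyLoss fuses the lm_head projection with the
        # cross entropy computation, so the full (batch * seq_len, vocab_size)
        # logits tensor is never materialized: it consumes the lm_head weight and
        # the flattened hidden states directly.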
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # in training mode, don't materialize the logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

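        # If the trainer passes num_items_in_batch (e.g. under gradient accumulation),
        # use a summed loss and divide by that token count so the result matches the
        # mean over the full effective batch; otherwise fall back to a per-token mean.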
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # in inference mode, materialize the logits
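        # Slicing with -num_logits_to_keep keeps only the trailing positions; when
        # num_logits_to_keep == 0 the slice [:, -0:, :] covers the full sequence,
        # matching the "calculate logits for all input_ids" special case above.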
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
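

# Minimal usage sketch (an assumption, not part of the API shown above): these
# forwards are intended to replace `GemmaForCausalLM.forward` so training uses the
# fused linear cross entropy path. Rebinding the method directly is one way to
# apply it; Liger Kernel also ships convenience patch helpers for this purpose.
if __name__ == "__main__":
    from transformers.models.gemma.modeling_gemma import GemmaForCausalLM

    # After this assignment, any GemmaForCausalLM instance routes its forward
    # pass through lce_forward defined above.
    GemmaForCausalLM.forward = lce_forward
    print("Patched GemmaForCausalLM.forward with Liger fused linear cross entropy.")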