# src/liger_kernel/transformers/model/gemma2.py
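"""Liger Kernel forward passes for Gemma2 causal language modeling.

The two ``lce_forward*`` functions below are intended as drop-in replacements
for ``Gemma2ForCausalLM.forward``. In training mode with labels, they skip
materializing the full vocabulary-sized logits tensor and instead compute the
loss with ``LigerFusedLinearCrossEntropyLoss`` directly from the last hidden
states and the ``lm_head`` weights.
"""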
import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma2.modeling_gemma2 import (
    _CONFIG_FOR_DOC,
    GEMMA2_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

logger = logging.getLogger(__name__)


def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
    >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")
    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # Fused linear + cross-entropy: the loss is computed directly from the
        # hidden states and the lm_head weight, without materializing logits.
        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping
        )
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


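# `lce_forward` below follows the newer Hugging Face interface (delegating the
# eval-time loss to `self.loss_function` and accepting `num_logits_to_keep` and
# extra `**loss_kwargs`); `lce_forward_deprecated` above is presumably kept for
# older transformers releases that predate that interface. Both share the same
# training-time trick: fused linear + cross-entropy over shifted hidden states.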
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

    >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # When the caller passes `num_items_in_batch` (e.g. for gradient
        # accumulation), sum the per-token losses and normalize by that count;
        # otherwise fall back to a plain mean reduction.
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping,
            reduction=reduction,
        )

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # otherwise (inference/eval, or no labels), materialize logits
        # Only materialize logits for the last `num_logits_to_keep` tokens (0 keeps all).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if self.config.final_logit_softcapping is not None:
            # Final logit soft-capping: softcap * tanh(logits / softcap) bounds
            # the logits to (-softcap, softcap).
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
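

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the upstream module). It shows one
# way these forwards could be monkey-patched onto Hugging Face's
# `Gemma2ForCausalLM` so that training uses the fused linear + cross-entropy
# path above. liger_kernel ships its own patching helpers, so treat the direct
# assignment below as an assumption for demonstration; the Liger kernels also
# expect a CUDA device.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM

    # Swap in the Liger fused-linear-cross-entropy forward.
    Gemma2ForCausalLM.forward = lce_forward

    model = Gemma2ForCausalLM.from_pretrained(
        "google/gemma-2-9b", attn_implementation="eager"
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

    batch = tokenizer("What is your favorite condiment?", return_tensors="pt").to("cuda")
    model.train()
    # With `labels` supplied in training mode, the patched forward returns the
    # loss without materializing the full (batch, seq_len, vocab_size) logits.
    out = model(**batch, labels=batch["input_ids"])
    print(out.loss)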