gemma2.py
import logging
from typing import Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma2.modeling_gemma2 import (
    _CONFIG_FOR_DOC,
    GEMMA2_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

logger = logging.getLogger(__name__)


def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that hidden states at position i predict token i + 1
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping
        )
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states)
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[HybridCache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""

    if self.training and self.config._attn_implementation != "eager":
        logger.warning_once(
            "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
            f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
        )
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # If in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(
            softcap=self.config.final_logit_softcapping,
            reduction=reduction,
        )

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # If in inference mode, materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
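
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module): one way to
# wire `lce_forward` into a Hugging Face Gemma2 model is to monkey-patch the
# class-level forward before instantiating the model. Liger-Kernel ships its
# own patching helpers; treat the snippet below as a hedged sketch of the idea,
# not the library's official API.
#
#   from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM
#
#   # Swap in the fused linear + cross-entropy forward so that training with
#   # labels never materializes the full (seq_len, vocab_size) logits tensor.
#   Gemma2ForCausalLM.forward = lce_forward
#
#   model = Gemma2ForCausalLM.from_pretrained(
#       "google/gemma-2-9b", attn_implementation="eager"
#   )
#   model.train()  # the Liger FLCE path is taken only when training with labels
# ---------------------------------------------------------------------------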