gemma.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
    _CONFIG_FOR_DOC,
    GEMMA_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the `transformers.models.gemma.modeling_gemma` causal LM forward,
    with the loss computation replaced by Liger fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that hidden states at position i predict token i + 1
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # If in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Flatten the tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # If in inference mode, materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
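

# Usage sketch (not part of the original module): one way to route Gemma training
# through the fused-linear-cross-entropy forward above is to rebind the class-level
# forward before instantiating the model. The patching style below is illustrative
# only; Liger-Kernel also ships its own patch helpers (e.g.
# `apply_liger_kernel_to_gemma`), and the model/tokenizer names here are assumptions.
if __name__ == "__main__":
    from transformers import AutoTokenizer
    from transformers.models.gemma import modeling_gemma

    # Rebind forward so that training-mode calls with labels take the Liger FLCE path
    # and skip materializing the full (batch, seq_len, vocab_size) logits tensor.
    modeling_gemma.GemmaForCausalLM.forward = lce_forward

    model = modeling_gemma.GemmaForCausalLM.from_pretrained(
        "google/gemma-7b", torch_dtype=torch.bfloat16
    ).to("cuda")  # the Liger kernels are Triton-based, so a CUDA device is required
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    batch = tokenizer("What is your favorite condiment?", return_tensors="pt").to(
        model.device
    )
    model.train()
    out = model(**batch, labels=batch["input_ids"])
    print(out.loss)  # scalar loss computed by LigerFusedLinearCrossEntropyLoss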