llama.py
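"""Llama forward passes patched to use Liger's fused linear cross entropy (FLCE) loss.

`lce_forward` (and the deprecated variant `lce_forward_deprecated`) mirror the Hugging Face
`LlamaForCausalLM.forward`, but during training they skip materializing the full logits tensor
and instead feed the shifted hidden states and the `lm_head` weight directly into
`LigerFusedLinearCrossEntropyLoss`.
"""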
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
    LLAMA_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

if TYPE_CHECKING:
    from transformers.cache_utils import Cache


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the Llama forward pass, with the torch cross entropy loss replaced by
    Liger fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # fused linear + cross entropy: the full logits tensor is never materialized
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

    else:
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(
                self.vocab_size // self.config.pretraining_tp, dim=0
            )
            logits = [
                F.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    if self.config.pretraining_tp > 1:
        raise Exception("Liger Kernel does not support pretraining_tp!")

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            # normalize the summed loss by the item count supplied by the trainer,
            # so gradient accumulation yields the same scale as a plain mean
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # if in inference mode, materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
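

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): one way to use
# the FLCE forward above is to patch it onto `LlamaForCausalLM` before the
# model is instantiated. The checkpoint name and the tiny training step below
# are assumptions for demonstration only; the fused kernel expects a CUDA device.
#
#     from transformers import AutoTokenizer, LlamaForCausalLM
#
#     LlamaForCausalLM.forward = lce_forward  # route forward through Liger FLCE
#
#     model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").cuda()
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#
#     batch = tokenizer("Hey, are you conscious?", return_tensors="pt").to("cuda")
#     model.train()
#     out = model(**batch, labels=batch["input_ids"])  # loss via fused kernel, logits is None
#     out.loss.backward()
# ---------------------------------------------------------------------------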