phi3.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.phi3.modeling_phi3 import (
    _CONFIG_FOR_DOC,
    PHI3_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy of the Phi-3 forward from transformers v4.44.2, with the torch cross entropy loss
    replaced by the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
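

# Illustrative reference only: this helper is not called by the patched forwards in this
# file, and its name is hypothetical (not part of the Liger or transformers APIs). It
# spells out what `LigerFusedLinearCrossEntropyLoss` computes in the training branch
# above, i.e. the same cross entropy you would get by materializing logits through
# `lm_head` and applying `CrossEntropyLoss`, except that Liger's fused kernel never
# allocates the full (num_tokens, vocab_size) logits tensor.
def _naive_linear_cross_entropy_reference(
    lm_head_weight: torch.Tensor,  # (vocab_size, hidden_size)
    shift_hidden_states: torch.Tensor,  # (num_tokens, hidden_size)
    shift_labels: torch.Tensor,  # (num_tokens,)
) -> torch.Tensor:
    # Materialize logits explicitly, then apply standard cross entropy with the usual
    # -100 ignore index and mean reduction: the memory-hungry path the fused loss avoids.
    logits = shift_hidden_states @ lm_head_weight.t()
    return CrossEntropyLoss(ignore_index=-100)(logits.float(), shift_labels)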


@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
            only for that token saves memory, which becomes significant for long sequences or large vocabulary sizes.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
    ```"""

    from transformers.models.phi3.modeling_phi3 import logging

    logger = logging.get_logger(__name__)

    if (
        use_cache
        and self.config.rope_scaling
        and cache_position is not None
        and cache_position[0] == self.config.original_max_position_embeddings
    ):
        logger.warning(
            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
        )

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
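    # Why avoid materializing logits during training: the logits tensor has shape
    # (batch, seq_len, vocab_size). As an illustrative estimate, a hypothetical batch of
    # 4 sequences of 4096 tokens with Phi-3-mini's ~32k vocabulary (32064) costs roughly
    # 4 * 4096 * 32064 * 4 bytes, about 2.1 GB in fp32, before any loss is computed.
    # The training branch below skips that allocation via the fused loss.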
    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        # The HF Trainer passes `num_items_in_batch` so the loss can be normalized
        # correctly under gradient accumulation: sum per micro-batch, then divide by
        # the total number of label tokens in the full batch.
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # if in inference mode materialize logits
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )
    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
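

# Usage sketch, not part of the patched forwards above: one way these functions are
# applied is by monkey-patching `Phi3ForCausalLM.forward` so that training runs take
# the Liger fused linear cross entropy path. The model id and inputs below are just
# placeholders, and Liger's Triton kernels assume a CUDA device.
if __name__ == "__main__":
    from transformers import AutoTokenizer, Phi3ForCausalLM

    # Rebind the class-level forward; `self` is bound when it is called on an instance.
    Phi3ForCausalLM.forward = lce_forward

    model = Phi3ForCausalLM.from_pretrained(
        "microsoft/phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    batch = tokenizer("This is an example script .", return_tensors="pt").to("cuda")
    model.train()  # the fused-loss branch is taken only in training mode with labels
    out = model(input_ids=batch.input_ids, labels=batch.input_ids)
    print(out.loss)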