phi3.py
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.phi3.modeling_phi3 import (
    _CONFIG_FOR_DOC,
    PHI3_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
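# Note: `lce_forward` / `lce_forward_deprecated` are drop-in replacements for
# `Phi3ForCausalLM.forward`. A minimal usage sketch (one possible way to wire
# this up; Liger-Kernel also provides patching helpers) is to monkey patch the
# model class:
#
#     from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM
#     Phi3ForCausalLM.forward = lce_forward
#
# so that training computes the loss through the fused linear cross-entropy
# kernel without materializing the full logits tensor.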


@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward_deprecated(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the Phi-3 forward from transformers v4.44.2, with the torch
    cross entropy replaced by the Liger fused linear cross entropy loss.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)
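        # LigerFusedLinearCrossEntropyLoss fuses the lm_head projection with the
        # cross-entropy computation: it takes the projection weight and the
        # flattened hidden states directly, so the full
        # (batch_size * seq_len, vocab_size) logits tensor is never materialized.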
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
    ```"""

    from transformers.models.phi3.modeling_phi3 import logging

    logger = logging.get_logger(__name__)

    if (
        use_cache
        and self.config.rope_scaling
        and cache_position is not None
        and cache_position[0] == self.config.original_max_position_embeddings
    ):
        logger.warning(
            f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
        )

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if self.training and (labels is not None):
        # We do the same thing as ForCausalLMLoss but using Liger FLCE

        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)
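        # Newer Hugging Face Trainer versions pass `num_items_in_batch` so the
        # loss can be normalized by the global token count (correct averaging
        # under gradient accumulation). In that case sum the per-token losses
        # here and divide below; otherwise fall back to a per-batch mean.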
        reduction = "sum" if "num_items_in_batch" in loss_kwargs else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)

        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
        if reduction == "sum":
            loss /= loss_kwargs["num_items_in_batch"]

    else:  # if in inference mode materialize logits
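        # With the default `num_logits_to_keep == 0`, the `-0:` slice keeps the
        # whole sequence, so logits are computed for every position; generation
        # typically passes 1 to compute only the last position's logits.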
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **loss_kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )