# src/liger_kernel/transformers/monkey_patch.py
import inspect
import logging
from functools import partial
from typing import Callable

import transformers
from packaging import version
from transformers import PreTrainedModel

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.gemma import (
    lce_forward_deprecated as gemma_lce_forward_deprecated,
)
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import (
    lce_forward_deprecated as gemma2_lce_forward_deprecated,
)
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import (
    lce_forward_deprecated as llama_lce_forward_deprecated,
)
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.mixtral import (
    lce_forward_deprecated as mixtral_lce_forward_deprecated,
)
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.phi3 import (
    lce_forward_deprecated as phi3_lce_forward_deprecated,
)
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.model.qwen2 import (
    lce_forward_deprecated as qwen2_lce_forward_deprecated,
)
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import (
    LigerBlockSparseTop2MLP,
    LigerPhi3SwiGLUMLP,
    LigerSwiGLUMLP,
)

transformer_version = version.parse(transformers.__version__)

logger = logging.getLogger(__name__)
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
TRANSFORMER_DEPRECATION_WARNING = "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/34191"


def _bind_method_to_module(module, method_name: str, new_method: Callable):
    # Binds a new method to a module instance so that self is passed as the first argument
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)

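# Illustrative sketch (not part of the library API; names below are hypothetical):
# `_bind_method_to_module` relies on the descriptor protocol -- `new_method.__get__(module, cls)`
# returns a bound method, so the patched `forward` receives the instance as `self`, and only
# that one instance is affected:
#
#   import torch.nn as nn
#
#   def doubled_forward(self, x):
#       # `self` is the patched nn.Linear instance
#       return 2.0 * nn.Linear.forward(self, x)
#
#   layer = nn.Linear(4, 4)
#   _bind_method_to_module(layer, "forward", doubled_forward)
#   # nn.Module.__call__ finds the bound method in layer.__dict__ before the class attribute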

def _patch_rms_norm_module(
    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
):
    module.offset = offset
    module.casting_mode = casting_mode
    module.variance_epsilon = (
        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    )
    module.in_place = in_place
    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)


def _patch_layer_norm_module(module, eps=1e-6):
    module.variance_epsilon = (
        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    )
    module.hidden_size = module.normalized_shape
    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)

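# Illustrative sketch (hypothetical shapes): the patch helpers keep the learned weights
# of an existing norm instance and only swap in the Liger forward, reading eps from
# whichever attribute the module already defines:
#
#   import torch.nn as nn
#
#   ln = nn.LayerNorm(4096, eps=1e-5)
#   _patch_layer_norm_module(ln)
#   assert ln.variance_epsilon == 1e-5  # picked up from ln.eps via the getattr chain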

def apply_liger_kernel_to_llama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.llama import modeling_llama
    from transformers.models.llama.modeling_llama import LlamaModel

    if rope:
        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_llama.LlamaRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama.LlamaMLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

        # get the base model from the model instance
        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

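# Example usage (a sketch; the checkpoint path is a placeholder): patching before the
# model is constructed swaps the classes in modeling_llama itself, so any Llama model
# created afterwards picks up the Liger kernels.
#
#   from transformers import AutoModelForCausalLM
#
#   apply_liger_kernel_to_llama(fused_linear_cross_entropy=True)
#   model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")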

def apply_liger_kernel_to_mllama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
    NOTE: MLlama is not available in transformers<4.45.0

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mllama import modeling_mllama
    from transformers.models.mllama.modeling_mllama import (
        MllamaForCausalLM,
        MllamaForConditionalGeneration,
        MllamaTextModel,
        MllamaVisionModel,
    )

    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
    from liger_kernel.transformers.model.mllama import (
        lce_forward_deprecated as mllama_lce_forward_deprecated,
    )

    if rope:
        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if layer_norm:
        modeling_mllama.nn.LayerNorm = LigerLayerNorm
    if rms_norm:
        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, MllamaForConditionalGeneration):
            language_model: MllamaForCausalLM = model.language_model
            vision_model: MllamaVisionModel = model.vision_model
            text_model: MllamaTextModel = language_model.model
        elif isinstance(model, MllamaForCausalLM):
            text_model = model.model
            vision_model = None
        elif isinstance(model, MllamaTextModel):
            text_model = model
            vision_model = None
        else:
            raise ValueError(f"Unsupported Mllama model type: {type(model)}")

        if text_model:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _bind_method_to_module(
                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                    )
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

        if vision_model:
            _patch_layer_norm_module(vision_model.layernorm_pre)
            _patch_layer_norm_module(vision_model.layernorm_post)

            for layer in vision_model.transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)

            for layer in vision_model.global_transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)

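# Example usage (a sketch; the checkpoint path is a placeholder): passing an already-loaded
# model patches module instances in place, covering both the text decoder's RMSNorm/SwiGLU
# modules and the vision tower's LayerNorms.
#
#   from transformers import MllamaForConditionalGeneration
#
#   model = MllamaForConditionalGeneration.from_pretrained("path/to/mllama-checkpoint")
#   apply_liger_kernel_to_mllama(model=model)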

def apply_liger_kernel_to_mistral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mistral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral.modeling_mistral import MistralModel

    if rope:
        modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mistral.MistralRMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
    if swiglu:
        modeling_mistral.MistralMLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_mixtral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mixtral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mixtral import modeling_mixtral
    from transformers.models.mixtral.modeling_mixtral import MixtralModel

    if rope:
        modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
    if swiglu:
        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                for expert in decoder_layer.block_sparse_moe.experts:
                    _bind_method_to_module(
                        expert, "forward", LigerBlockSparseTop2MLP.forward
                    )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_gemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma
    (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2`) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaModel

    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
    LigerRMSNormForGemma = partial(
        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
    )
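    # Gemma's RMSNorm computes (1 + weight) * normed_x with weights initialized to zeros,
    # hence offset=1.0 and init_fn="zeros"; casting_mode="gemma" mirrors its float32 upcast.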
    _patch_rms_norm_module_for_gemma = partial(
        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
    )

    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_gemma2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma2
    (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.gemma2 import modeling_gemma2
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

    LigerRMSNormForGemma2 = partial(
        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
    )
    _patch_rms_norm_module_for_gemma2 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    if rope:
        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprecated
    if geglu:
        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma2(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.post_attention_layernorm
                )
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.pre_feedforward_layernorm
                )
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.post_feedforward_layernorm
                )


def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_qwen2_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not available in transformers<4.45.0

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2_vl import modeling_qwen2_vl
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

    from liger_kernel.transformers.model.qwen2_vl import (
        lce_forward as qwen2_vl_lce_forward,
    )

    if rope:
        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
            liger_multimodal_rotary_pos_emb
        )
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

        if hasattr(model, "visual"):
            # Patch Qwen2VisionTransformerPretrainedModel
            for vision_block in model.visual.blocks:
                if layer_norm:
                    _patch_layer_norm_module(vision_block.norm1)
                    _patch_layer_norm_module(vision_block.norm2)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.phi3 import modeling_phi3
    from transformers.models.phi3.modeling_phi3 import Phi3Model

    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "llama": apply_liger_kernel_to_llama,
    "mllama": apply_liger_kernel_to_mllama,
    "mllama_text_model": apply_liger_kernel_to_mllama,
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    "phi3": apply_liger_kernel_to_phi3,
}


def _apply_liger_kernel(model_type: str, **kwargs) -> None:
    """
    Applies Liger kernels based on the specified model type. The custom
    kernels for the specified model type will be applied with the provided
    keyword arguments, otherwise the default configuration will be used.

    ** Note: Calling _apply_liger_kernel() after model initialization
    will not be able to fully patch models. This must be called before model initialization.
    If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead.

    Args:
        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
          and specified in the model's config.json
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
    """
    if not model_type:
        logger.info("Model type was not provided. No Liger kernels will be applied.")
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
    apply_fn_signature = inspect.signature(apply_fn)

    # Filter out the keyword arguments that are not supported by the apply function
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in apply_fn_signature.parameters
    }

    logger.info(
        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    # Assume this is invoked pre-model initialization, so we only need to patch transformers code
    apply_fn(**applicable_kwargs)

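# Example usage (a sketch; the kwargs are illustrative): dispatching by the model_type
# string found in a checkpoint's config.json, before any weights are loaded. Unsupported
# kwargs are filtered against the target function's signature rather than raising.
#
#   _apply_liger_kernel("llama", rms_norm=True, swiglu=False)
#   _apply_liger_kernel("llama", rms_norm=True, nonexistent_flag=True)  # flag is ignored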

def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
    """
    Applies Liger kernels to the provided model instance.

    Args:
        - model: the model instance to apply Liger kernels to
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
    """
    model_type = getattr(model, "config", None) and getattr(
        model.config, "model_type", None
    )

    if not model_type:
        logger.info(
            "Model type could not be determined from model config. No Liger kernels will be applied."
        )
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN.keys():
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]

    apply_fn_signature = inspect.signature(apply_fn)

    # Filter out the keyword arguments that are not supported by the apply function
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in apply_fn_signature.parameters
    }
    logger.info(
        f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    apply_fn(model=model, **applicable_kwargs)
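

# Example usage (a sketch; the checkpoint path is a placeholder): patching after the
# weights are loaded, e.g. from a trainer integration that only receives the instance.
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint")
#   _apply_liger_kernel_to_instance(model, rope=True, rms_norm=True)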