monkey_patch.py
import inspect
import logging
from functools import partial
from typing import Callable

import transformers
from packaging import version
from transformers import PreTrainedModel

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.model.gemma import (
    lce_forward_deprecated as gemma_lce_forward_deprecated,
)
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import (
    lce_forward_deprecated as gemma2_lce_forward_deprecated,
)
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import (
    lce_forward_deprecated as llama_lce_forward_deprecated,
)
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.mixtral import (
    lce_forward_deprecated as mixtral_lce_forward_deprecated,
)
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.phi3 import (
    lce_forward_deprecated as phi3_lce_forward_deprecated,
)
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.model.qwen2 import (
    lce_forward_deprecated as qwen2_lce_forward_deprecated,
)
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import (
    LigerBlockSparseTop2MLP,
    LigerPhi3SwiGLUMLP,
    LigerSwiGLUMLP,
)

transformer_version = version.parse(transformers.__version__)

logger = logging.getLogger(__name__)
SUPPORTED_TRANSFORMER_VERSION = "4.46.1"
TRANSFORMER_DEPRECATION_WARNING = (
    "Support for transformers versions < 4.46.1 will soon be discontinued due to issues with "
    "incorrect gradient accumulation. \n Please consider upgrading to avoid potential issues. "
    "See details: https://github.com/huggingface/transformers/pull/34191"
)


def _bind_method_to_module(module, method_name: str, new_method: Callable):
    # Binds a new method to a module instance so that self is passed as the first argument
    module.__dict__[method_name] = new_method.__get__(module, module.__class__)


def _patch_rms_norm_module(
    module, offset=0.0, eps=1e-6, casting_mode="llama", in_place=True
):
    module.offset = offset
    module.casting_mode = casting_mode
    module.variance_epsilon = (
        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    )
    module.in_place = in_place
    _bind_method_to_module(module, "forward", LigerRMSNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerRMSNorm.extra_repr)


def _patch_layer_norm_module(module, eps=1e-6):
    module.variance_epsilon = (
        getattr(module, "variance_epsilon", None) or getattr(module, "eps", None) or eps
    )
    module.hidden_size = module.normalized_shape
    _bind_method_to_module(module, "forward", LigerLayerNorm.forward)
    _bind_method_to_module(module, "extra_repr", LigerLayerNorm.extra_repr)
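
# Illustrative note (a sketch, not part of the module's API): the helpers above
# rely on Python's descriptor protocol. `new_method.__get__(module, module.__class__)`
# returns a bound method whose `self` is the module instance, and assigning it into
# `module.__dict__` shadows the class-level method for that single instance only.
# Assuming an already-instantiated RMSNorm module `norm`:
#
#     _bind_method_to_module(norm, "forward", LigerRMSNorm.forward)
#     norm.forward(hidden_states)  # now runs LigerRMSNorm.forward(norm, hidden_states)
#
# Other instances of the same class keep their original forward.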

def apply_liger_kernel_to_llama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Llama models (2 and 3)

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.llama import modeling_llama
    from transformers.models.llama.modeling_llama import LlamaModel

    if rope:
        modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_llama.LlamaRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_llama.LlamaMLP = LigerSwiGLUMLP

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)

        # get the base model from the model instance
        base_model: LlamaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
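
# Example usage (a minimal sketch; the checkpoint name is illustrative and not a
# dependency of this module). Calling the function before instantiation swaps the
# classes in `modeling_llama`, so a subsequently constructed model picks up the
# Liger kernels:
#
#     from transformers import AutoModelForCausalLM
#
#     apply_liger_kernel_to_llama(rope=True, rms_norm=True, swiglu=True)
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
#
# For a model that is already loaded, pass it in so instance-level modules are
# patched as well:
#
#     apply_liger_kernel_to_llama(model=model)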

def apply_liger_kernel_to_mllama(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    layer_norm: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace MLlama models.
    NOTE: MLlama is not available in transformers<4.45.0

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mllama import modeling_mllama
    from transformers.models.mllama.modeling_mllama import (
        MllamaForCausalLM,
        MllamaForConditionalGeneration,
        MllamaTextModel,
        MllamaVisionModel,
    )

    from liger_kernel.transformers.model.mllama import lce_forward as mllama_lce_forward
    from liger_kernel.transformers.model.mllama import (
        lce_forward_deprecated as mllama_lce_forward_deprecated,
    )

    if rope:
        modeling_mllama.apply_rotary_pos_emb = liger_rotary_pos_emb
    if layer_norm:
        modeling_mllama.nn.LayerNorm = LigerLayerNorm
    if rms_norm:
        modeling_mllama.MllamaTextRMSNorm = LigerRMSNorm
    if swiglu:
        modeling_mllama.MllamaTextMLP = LigerSwiGLUMLP
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mllama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mllama.MllamaForCausalLM.forward = mllama_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        if isinstance(model, MllamaForConditionalGeneration):
            language_model: MllamaForCausalLM = model.language_model
            vision_model: MllamaVisionModel = model.vision_model
            text_model: MllamaTextModel = language_model.model
        elif isinstance(model, MllamaForCausalLM):
            text_model = model.model
            vision_model = None
        elif isinstance(model, MllamaTextModel):
            text_model = model
            vision_model = None
        else:
            raise ValueError(f"Unsupported Mllama model type: {type(model)}")

        if text_model:
            if rms_norm:
                _patch_rms_norm_module(text_model.norm)
            for decoder_layer in text_model.layers:
                if swiglu:
                    _bind_method_to_module(
                        decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                    )
                if rms_norm:
                    _patch_rms_norm_module(decoder_layer.input_layernorm)
                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

        if vision_model:
            if layer_norm:
                _patch_layer_norm_module(vision_model.layernorm_pre)
                _patch_layer_norm_module(vision_model.layernorm_post)

            for layer in vision_model.transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)

            for layer in vision_model.global_transformer.layers:
                if layer_norm:
                    _patch_layer_norm_module(layer.input_layernorm)
                    _patch_layer_norm_module(layer.post_attention_layernorm)

def apply_liger_kernel_to_mistral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mistral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral.modeling_mistral import MistralModel

    if rope:
        modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mistral.MistralRMSNorm = LigerRMSNorm
    if cross_entropy:
        modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
    if swiglu:
        modeling_mistral.MistralMLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MistralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_mixtral(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Mixtral models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """

    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.mixtral import modeling_mixtral
    from transformers.models.mixtral.modeling_mixtral import MixtralModel

    if rope:
        modeling_mixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_mixtral.MixtralRMSNorm = LigerRMSNorm
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mixtral.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_mixtral.MixtralForCausalLM.forward = mixtral_lce_forward_deprecated
    if swiglu:
        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: MixtralModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                for expert in decoder_layer.block_sparse_moe.experts:
                    _bind_method_to_module(
                        expert, "forward", LigerBlockSparseTop2MLP.forward
                    )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_gemma(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma
    (Gemma 1 and 1.1 supported, for Gemma2 please use `apply_liger_kernel_to_gemma2`) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.gemma import modeling_gemma
    from transformers.models.gemma.modeling_gemma import GemmaModel

    # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
    LigerRMSNormForGemma = partial(
        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
    )
    _patch_rms_norm_module_for_gemma = partial(
        _patch_rms_norm_module, casting_mode="gemma", offset=1.0
    )

    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: GemmaModel = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module_for_gemma(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_gemma2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    geglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Gemma2
    (for Gemma1 please use `apply_liger_kernel_to_gemma`) to make GPU go burrr.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.gemma2 import modeling_gemma2
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

    LigerRMSNormForGemma2 = partial(
        LigerRMSNorm, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False
    )
    _patch_rms_norm_module_for_gemma2 = partial(
        _patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
    )

    if rope:
        modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
        modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward_deprecated
    if geglu:
        modeling_gemma2.Gemma2MLP = LigerGEGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Gemma2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module_for_gemma2(base_model.norm)

        for decoder_layer in base_model.layers:
            if geglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerGEGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module_for_gemma2(decoder_layer.input_layernorm)
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.post_attention_layernorm
                )
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.pre_feedforward_layernorm
                )
                _patch_rms_norm_module_for_gemma2(
                    decoder_layer.post_feedforward_layernorm
                )

def apply_liger_kernel_to_qwen2(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2 models

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2 import modeling_qwen2
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model

    if rope:
        modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm

    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss

    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward_deprecated

    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen2Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_qwen2_vl(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    layer_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Qwen2-VL models.
    NOTE: Qwen2-VL is not available in transformers<4.45.0

    Args:
        rope (bool): Whether to apply Liger's multimodal rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.qwen2_vl import modeling_qwen2_vl
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel

    from liger_kernel.transformers.model.qwen2_vl import (
        lce_forward as qwen2_vl_lce_forward,
    )

    if rope:
        modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
            liger_multimodal_rotary_pos_emb
        )
    if rms_norm:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
        modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
    if layer_norm:
        modeling_qwen2_vl.LayerNorm = LigerLayerNorm
    if cross_entropy:
        modeling_qwen2_vl.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = qwen2_vl_lce_forward
    if swiglu:
        modeling_qwen2_vl.Qwen2MLP = LigerSwiGLUMLP

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Qwen2VLModel = getattr(model, model.base_model_prefix, model)

        if hasattr(model, "visual"):
            # Patch Qwen2VisionTransformerPretrainedModel
            for vision_block in model.visual.blocks:
                if layer_norm:
                    _patch_layer_norm_module(vision_block.norm1)
                    _patch_layer_norm_module(vision_block.norm2)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)
        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)

def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
        loaded. Default is None.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.phi3 import modeling_phi3
    from transformers.models.phi3.modeling_phi3 import Phi3Model

    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            from transformers.loss.loss_utils import nn

            nn.functional.cross_entropy = liger_cross_entropy
        else:
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
        else:  # if version < 4.46.1
            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated

    if model is not None:
        # The model instance already exists, so we need to additionally patch the
        # instance variables that reference already-instantiated modules

        # get the base model from the model instance
        base_model: Phi3Model = getattr(model, model.base_model_prefix, model)

        if rms_norm:
            _patch_rms_norm_module(base_model.norm)

        for decoder_layer in base_model.layers:
            if swiglu:
                _bind_method_to_module(
                    decoder_layer.mlp, "forward", LigerPhi3SwiGLUMLP.forward
                )
            if rms_norm:
                _patch_rms_norm_module(decoder_layer.input_layernorm)
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "llama": apply_liger_kernel_to_llama,
    "mllama": apply_liger_kernel_to_mllama,
    "mllama_text_model": apply_liger_kernel_to_mllama,
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    "qwen2": apply_liger_kernel_to_qwen2,
    "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
    "phi3": apply_liger_kernel_to_phi3,
}


def _apply_liger_kernel(model_type: str, **kwargs) -> None:
    """
    Applies Liger kernels based on the specified model type. The custom
    kernels for the specified model type will be applied with the provided
    keyword arguments, otherwise the default configuration will be used.

    ** Note: Calling _apply_liger_kernel() after model initialization
    will not be able to fully patch models. This must be called before model initialization.
    If the model has already been instantiated, use _apply_liger_kernel_to_instance() instead.

    Args:
        - model_type: the model types as defined in transformers/models/auto/modeling_auto.py
          and specified in the model's config.json
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
    """
    if not model_type:
        logger.info("Model type was not provided. No Liger kernels will be applied.")
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
    apply_fn_signature = inspect.signature(apply_fn)

    # Filter out the keyword arguments that are not supported by the apply function
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in apply_fn_signature.parameters
    }

    logger.info(
        f"Applying Liger kernels for model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    # Assume this is invoked pre-model initialization, so we only need to patch transformers code
    apply_fn(**applicable_kwargs)
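
# Illustrative call (a sketch; `_apply_liger_kernel` is a private helper, normally
# driven by the integration layer rather than user code). Keyword arguments not
# accepted by the resolved apply_liger_kernel_to_* function are filtered out:
#
#     _apply_liger_kernel("llama", rope=False, fused_linear_cross_entropy=True)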

def _apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
    """
    Applies Liger kernels to the provided model instance.

    Args:
        - model: the model instance to apply Liger kernels to
        - kwargs: keyword arguments that are passed to the corresponding apply_liger_kernel_to_* function.
    """
    model_type = getattr(model, "config", None) and getattr(
        model.config, "model_type", None
    )

    if not model_type:
        logger.info(
            "Model type could not be determined from model config. No Liger kernels will be applied."
        )
        return

    if model_type not in MODEL_TYPE_TO_APPLY_LIGER_FN:
        logger.info(
            f"There are currently no Liger kernels supported for model type: {model_type}."
        )
        return

    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]

    apply_fn_signature = inspect.signature(apply_fn)

    # Filter out the keyword arguments that are not supported by the apply function
    applicable_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in apply_fn_signature.parameters
    }
    logger.info(
        f"Applying Liger kernels to model instance with model type: {model_type} with kwargs: {applicable_kwargs}"
    )

    apply_fn(model=model, **applicable_kwargs)
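
# Illustrative instance-level flow (a sketch; the helper is private and the
# checkpoint name is only an example). The model type is read from the model's
# config and dispatched through MODEL_TYPE_TO_APPLY_LIGER_FN:
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")
#     _apply_liger_kernel_to_instance(model=model, fused_linear_cross_entropy=False)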