test_monkey_patch.py
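"""Tests for liger_kernel.transformers.monkey_patch.

These tests verify that the apply_liger_kernel_to_* patching APIs replace the
forward methods of Hugging Face modules (RMSNorm, LayerNorm, and the MLP
variants) with their Liger kernel counterparts, both through the model-type
mapping and on already-instantiated model instances.
"""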
import inspect
from inspect import signature
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch
import transformers
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

from liger_kernel.transformers import (
    LigerBlockSparseTop2MLP,
    LigerGEGLUMLP,
    LigerPhi3SwiGLUMLP,
    LigerRMSNorm,
    LigerSwiGLUMLP,
    monkey_patch,
)
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import (
    MODEL_TYPE_TO_APPLY_LIGER_FN,
    _apply_liger_kernel,
    _apply_liger_kernel_to_instance,
)


# Check if optional modules are available
def is_mllama_available():
    try:
        import transformers.models.mllama  # noqa: F401

        return True
    except ImportError:
        return False


def is_qwen2_vl_available():
    try:
        import transformers.models.qwen2_vl  # noqa: F401

        return True
    except ImportError:
        return False


def test_import_from_root():
    try:
        from liger_kernel.transformers import (  # noqa: F401
            AutoLigerKernelForCausalLM,
            apply_liger_kernel_to_gemma,
            apply_liger_kernel_to_gemma2,
            apply_liger_kernel_to_llama,
            apply_liger_kernel_to_mistral,
            apply_liger_kernel_to_mixtral,
            apply_liger_kernel_to_mllama,
            apply_liger_kernel_to_phi3,
            apply_liger_kernel_to_qwen2,
            apply_liger_kernel_to_qwen2_vl,
        )
    except Exception:
        pytest.fail("Importing kernel patch functions from the package root failed")


def test_apply_liger_kernel_no_supported_model_type():
    # Test that calling _apply_liger_kernel with an unsupported model type is a no-op
    mock_mistral = Mock()

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"mistral": mock_mistral}):
        _apply_liger_kernel("foobar")
        MODEL_TYPE_TO_APPLY_LIGER_FN["mistral"].assert_not_called()


def test_apply_liger_kernel_only_supported_model_type_called():
    # Test that the Liger kernel is applied only to the specified model type
    mock_gemma = Mock()
    mock_llama = Mock()
    mock_mistral = Mock()

    with patch.dict(
        MODEL_TYPE_TO_APPLY_LIGER_FN,
        {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral},
    ):
        _apply_liger_kernel("llama")
        mock_llama.assert_called_once()
        mock_gemma.assert_not_called()
        mock_mistral.assert_not_called()


def test_apply_liger_kernel_only_passes_valid_kwargs():
    # Test that keyword args that are not valid for the apply_liger_* function are not passed
    mock_llama = Mock()

    def dummy_apply_liger_kernel_to_llama(
        rope=False,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        rms_norm=True,
        swiglu=True,
    ):
        pass

    apply_liger_kernel_to_llama_sig = signature(dummy_apply_liger_kernel_to_llama)

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
        mock_llama.__signature__ = apply_liger_kernel_to_llama_sig
        _apply_liger_kernel(
            "llama",
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
            foobar=True,
            barbaz=False,
        )
        mock_llama.assert_called_once()
        mock_llama.assert_called_once_with(
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
        )


def test_apply_liger_kernel_to_instance_no_supported_model_type():
    # Test that calling _apply_liger_kernel_to_instance with an unsupported model type is a no-op
    mock_mistral = Mock()
    mock_unknown_model = MagicMock(spec=PreTrainedModel)
    mock_unknown_model.config = {"model_type": "foobar"}

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"mistral": mock_mistral}):
        _apply_liger_kernel_to_instance(model=mock_unknown_model)
        MODEL_TYPE_TO_APPLY_LIGER_FN["mistral"].assert_not_called()


def test_apply_liger_kernel_to_instance_only_supported_model_type_called():
    # Test that the Liger kernel is applied only to the specified model type
    mock_gemma = Mock()
    mock_llama = Mock()
    mock_mistral = Mock()

    mock_llama_model_instance = MagicMock(spec=PreTrainedModel)
    mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig)
    mock_llama_model_instance.config.model_type = "llama"

    with patch.dict(
        MODEL_TYPE_TO_APPLY_LIGER_FN,
        {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral},
    ):
        _apply_liger_kernel_to_instance(model=mock_llama_model_instance)
        mock_llama.assert_called_once()
        mock_gemma.assert_not_called()
        mock_mistral.assert_not_called()


def test_apply_liger_kernel_to_instance_only_passes_valid_kwargs():
    # Test that keyword args that are not valid for the apply_liger_* function are not passed
    mock_llama = Mock()

    mock_llama_model_instance = MagicMock(spec=PreTrainedModel)
    mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig)
    mock_llama_model_instance.config.model_type = "llama"

    def dummy_apply_liger_kernel_to_llama(
        rope=False,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        rms_norm=True,
        swiglu=True,
        model=None,
    ):
        pass

    apply_liger_kernel_to_llama_sig = signature(dummy_apply_liger_kernel_to_llama)

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
        mock_llama.__signature__ = apply_liger_kernel_to_llama_sig
        _apply_liger_kernel_to_instance(
            model=mock_llama_model_instance,
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
            foobar=True,
            barbaz=False,
        )
        mock_llama.assert_called_once()
        mock_llama.assert_called_once_with(
            model=mock_llama_model_instance,
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
        )


def test_patching_apis_match_auto_mapping():
    # Test that all of the patching APIs also have a corresponding entry in the auto mapping
    patching_functions = [
        func
        for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
        if name.startswith("apply_liger_kernel_to_")
    ]

    assert set(patching_functions) == set(MODEL_TYPE_TO_APPLY_LIGER_FN.values())


def test_patching_apis_support_patching_model_instance():
    # Test that all the patching APIs support passing in a model (PreTrainedModel)
    # as an argument, indicating that they support patching after model creation
    patching_functions = [
        func
        for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
        if name.startswith("apply_liger_kernel_to_")
    ]

    for func in patching_functions:
        sig = inspect.signature(func)
        # Ensure 'model' is in the parameters
        assert (
            "model" in sig.parameters
        ), f"{func.__name__} does not have 'model' as an argument. All patching methods must support patching an existing model instance."
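

# The instance-patching tests below all follow the same pattern: build a tiny
# config, instantiate a dummy model, use inspect.getsource to assert that the
# relevant forward methods do not yet match the Liger implementations, call
# _apply_liger_kernel_to_instance, and then assert that they do.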
def test_apply_liger_kernel_to_instance_for_llama():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.llama.modeling_llama"):
        # Instantiate a dummy model
        config = transformers.models.llama.configuration_llama.LlamaConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available")
def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mllama.modeling_mllama"):
        from transformers.models.mllama.modeling_mllama import (
            MllamaForConditionalGeneration,
        )

        # Instantiate a dummy model
        config = transformers.models.mllama.configuration_mllama.MllamaConfig(
            torch_dtype=torch.bfloat16,
            text_config=transformers.models.mllama.configuration_mllama.MllamaTextConfig(
                rms_norm_eps=1e-5,
                hidden_size=32,
                intermediate_size=64,
                hidden_act="silu",
                num_hidden_layers=2,
                rope_scaling=dict(
                    factor=8.0,
                    high_freq_factor=4.0,
                    low_freq_factor=1.0,
                    original_max_position_embeddings=8192,
                    rope_type="llama3",
                ),
            ),
            vision_config=transformers.models.mllama.configuration_mllama.MllamaVisionConfig(
                rms_norm_eps=1e-5,
                hidden_size=32,
                intermediate_size=64,
                hidden_act="gelu",
                num_hidden_layers=2,
                vision_output_dim=64,
            ),
        )
        dummy_model_instance = MllamaForConditionalGeneration._from_config(config)

        assert isinstance(dummy_model_instance, MllamaForConditionalGeneration)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.language_model.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.language_model.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_pre.forward
        ) != inspect.getsource(LigerLayerNorm.forward)
        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_post.forward
        ) != inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.global_transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.language_model.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.language_model.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_pre.forward
        ) == inspect.getsource(LigerLayerNorm.forward)
        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_post.forward
        ) == inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.global_transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available")
def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mllama.modeling_mllama"):
        from transformers.models.mllama.modeling_mllama import MllamaForCausalLM

        # Instantiate a dummy model
        config = transformers.models.mllama.configuration_mllama.MllamaTextConfig(
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
            rope_scaling=dict(
                factor=8.0,
                high_freq_factor=4.0,
                low_freq_factor=1.0,
                original_max_position_embeddings=8192,
                rope_type="llama3",
            ),
        )

        dummy_model_instance = MllamaForCausalLM._from_config(config)

        assert isinstance(dummy_model_instance, MllamaForCausalLM)

        # Check that model instance variables are not yet patched with Liger modules
        assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_mistral():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mistral.modeling_mistral"):
        # Instantiate a dummy model
        config = transformers.models.mistral.configuration_mistral.MistralConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_mixtral():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mixtral.modeling_mixtral"):
        # Instantiate a dummy model
        config = transformers.models.mixtral.configuration_mixtral.MixtralConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
            num_local_experts=3,
            num_experts_per_tok=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            for expert in layer.block_sparse_moe.experts:
                assert inspect.getsource(expert.forward) != inspect.getsource(
                    LigerBlockSparseTop2MLP.forward
                )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            for expert in layer.block_sparse_moe.experts:
                assert inspect.getsource(expert.forward) == inspect.getsource(
                    LigerBlockSparseTop2MLP.forward
                )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_gemma():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma.modeling_gemma"):
        # Instantiate a dummy model
        config = transformers.models.gemma.configuration_gemma.GemmaConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_gemma2():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma2.modeling_gemma2"):
        # Instantiate a dummy model
        config = transformers.models.gemma2.configuration_gemma2.Gemma2Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.pre_feedforward_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_feedforward_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.pre_feedforward_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_feedforward_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_qwen2():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.qwen2.modeling_qwen2"):
        # Instantiate a dummy model
        config = transformers.models.qwen2.configuration_qwen2.Qwen2Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available")
def test_apply_liger_kernel_to_instance_for_qwen2_vl():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"):
        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
            Qwen2VLForConditionalGeneration,
        )

        # Instantiate a dummy model
        config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=48,
            embed_dim=16,
            hidden_act="silu",
            num_hidden_layers=2,
            num_attention_heads=2,
            max_position_embeddings=128,
            vocab_size=1000,
            vision_config={
                "depth": 4,
                "embed_dim": 128,
                "num_heads": 8,
                "hidden_size": 1024,
            },
        )
        dummy_model_instance = Qwen2VLForConditionalGeneration._from_config(config)

        assert isinstance(dummy_model_instance, Qwen2VLForConditionalGeneration)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
        for vision_block in dummy_model_instance.visual.blocks:
            assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(
                LigerLayerNorm.forward
            )
            assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(
                LigerLayerNorm.forward
            )

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
        for vision_block in dummy_model_instance.visual.blocks:
            assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(
                LigerLayerNorm.forward
            )
            assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(
                LigerLayerNorm.forward
            )

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_phi3():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.phi3.modeling_phi3"):
        # Instantiate a dummy model
        config = transformers.models.phi3.configuration_phi3.Phi3Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerPhi3SwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerPhi3SwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred in extra_expr: {type(e).__name__} - {e}")