import inspect
from inspect import signature
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch
import transformers
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel

from liger_kernel.transformers import (
    LigerBlockSparseTop2MLP,
    LigerGEGLUMLP,
    LigerPhi3SwiGLUMLP,
    LigerRMSNorm,
    LigerSwiGLUMLP,
    monkey_patch,
)
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import (
    MODEL_TYPE_TO_APPLY_LIGER_FN,
    _apply_liger_kernel,
    _apply_liger_kernel_to_instance,
)
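
# NOTE: Most checks below compare inspect.getsource(...) of forward methods
# rather than using isinstance checks, because applying Liger kernels to an
# existing instance rebinds the forward methods on the module objects instead
# of replacing the module classes themselves.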


# Check if optional modules are available
def is_mllama_available():
    try:
        import transformers.models.mllama  # noqa: F401

        return True
    except ImportError:
        return False


def is_qwen2_vl_available():
    try:
        import transformers.models.qwen2_vl  # noqa: F401

        return True
    except ImportError:
        return False


def test_import_from_root():
    try:
        from liger_kernel.transformers import (  # noqa: F401
            AutoLigerKernelForCausalLM,
            apply_liger_kernel_to_gemma,
            apply_liger_kernel_to_gemma2,
            apply_liger_kernel_to_llama,
            apply_liger_kernel_to_mistral,
            apply_liger_kernel_to_mixtral,
            apply_liger_kernel_to_mllama,
            apply_liger_kernel_to_phi3,
            apply_liger_kernel_to_qwen2,
            apply_liger_kernel_to_qwen2_vl,
        )
    except Exception:
        pytest.fail("Importing kernel patch functions from the package root failed")


def test_apply_liger_kernel_no_supported_model_type():
    # Test that calling _apply_liger_kernel with an unsupported model type is a no-op
    mock_mistral = Mock()

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"mistral": mock_mistral}):
        _apply_liger_kernel("foobar")
        MODEL_TYPE_TO_APPLY_LIGER_FN["mistral"].assert_not_called()


def test_apply_liger_kernel_only_supported_model_type_called():
    # Test that the Liger kernel is applied only to the specified model type
    mock_gemma = Mock()
    mock_llama = Mock()
    mock_mistral = Mock()

    with patch.dict(
        MODEL_TYPE_TO_APPLY_LIGER_FN,
        {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral},
    ):
        _apply_liger_kernel("llama")
        mock_llama.assert_called_once()
        mock_gemma.assert_not_called()
        mock_mistral.assert_not_called()


def test_apply_liger_kernel_only_passes_valid_kwargs():
    # Test that keyword args that are not valid for the apply_liger_* function are not passed
    mock_llama = Mock()

    def dummy_apply_liger_kernel_to_llama(
        rope=False,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        rms_norm=True,
        swiglu=True,
    ):
        pass

    apply_liger_kernel_to_llama_sig = signature(dummy_apply_liger_kernel_to_llama)

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
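        # Giving the mock a __signature__ makes inspect.signature() report the
        # dummy's parameters, so _apply_liger_kernel can drop unknown kwargs.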
        mock_llama.__signature__ = apply_liger_kernel_to_llama_sig
        _apply_liger_kernel(
            "llama",
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
            foobar=True,
            barbaz=False,
        )
        mock_llama.assert_called_once()
        mock_llama.assert_called_once_with(
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
        )


def test_apply_liger_kernel_to_instance_no_supported_model_type():
    # Test that calling _apply_liger_kernel_to_instance with an unsupported model type is a no-op
    mock_mistral = Mock()
    mock_unknown_model = MagicMock(spec=PreTrainedModel)
    mock_unknown_model.config = {"model_type": "foobar"}

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"mistral": mock_mistral}):
        _apply_liger_kernel_to_instance(model=mock_unknown_model)
        MODEL_TYPE_TO_APPLY_LIGER_FN["mistral"].assert_not_called()


def test_apply_liger_kernel_to_instance_only_supported_model_type_called():
    # Test that the Liger kernel is applied only to the specified model type
    mock_gemma = Mock()
    mock_llama = Mock()
    mock_mistral = Mock()

    mock_llama_model_instance = MagicMock(spec=PreTrainedModel)
    mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig)
    mock_llama_model_instance.config.model_type = "llama"

    with patch.dict(
        MODEL_TYPE_TO_APPLY_LIGER_FN,
        {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral},
    ):
        _apply_liger_kernel_to_instance(model=mock_llama_model_instance)
        mock_llama.assert_called_once()
        mock_gemma.assert_not_called()
        mock_mistral.assert_not_called()


def test_apply_liger_kernel_to_instance_only_passes_valid_kwargs():
    # Test that keyword args that are not valid for the apply_liger_* function are not passed
    mock_llama = Mock()

    mock_llama_model_instance = MagicMock(spec=PreTrainedModel)
    mock_llama_model_instance.config = MagicMock(spec=PretrainedConfig)
    mock_llama_model_instance.config.model_type = "llama"

    def dummy_apply_liger_kernel_to_llama(
        rope=False,
        cross_entropy=False,
        fused_linear_cross_entropy=True,
        rms_norm=True,
        swiglu=True,
        model=None,
    ):
        pass

    apply_liger_kernel_to_llama_sig = signature(dummy_apply_liger_kernel_to_llama)

    with patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}):
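        # As above, expose the dummy signature so only valid kwargs pass through.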
        mock_llama.__signature__ = apply_liger_kernel_to_llama_sig
        _apply_liger_kernel_to_instance(
            model=mock_llama_model_instance,
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
            foobar=True,
            barbaz=False,
        )
        mock_llama.assert_called_once()
        mock_llama.assert_called_once_with(
            model=mock_llama_model_instance,
            rope=False,
            fused_linear_cross_entropy=False,
            cross_entropy=True,
        )


def test_patching_apis_match_auto_mapping():
    # Test that all of the patching APIs present also have a corresponding entry in the auto mapping
    patching_functions = [
        func
        for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
        if name.startswith("apply_liger_kernel_to_")
    ]

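    # Set equality checks both directions: every public apply_liger_kernel_to_*
    # function is registered, and the mapping contains no stale entries.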
    assert set(patching_functions) == set(MODEL_TYPE_TO_APPLY_LIGER_FN.values())


def test_patching_apis_support_patching_model_instance():
    # Test that all the patching APIs present support passing in
    # model (PreTrainedModel) as an argument, indicating that they support
    # patching after model creation
    patching_functions = [
        func
        for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
        if name.startswith("apply_liger_kernel_to_")
    ]

    for func in patching_functions:
        sig = inspect.signature(func)
        # Ensure 'model' is in the parameters
        assert (
            "model" in sig.parameters
        ), f"{func.__name__} does not have 'model' as an argument. All patching methods must support patching an existing model instance."


def test_apply_liger_kernel_to_instance_for_llama():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.llama.modeling_llama"):
        # Instantiate a dummy model
        config = transformers.models.llama.configuration_llama.LlamaConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        # Ensure that the model patched with Liger modules can work properly
        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available")
def test_apply_liger_kernel_to_instance_for_mllama_for_conditional_generation():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mllama.modeling_mllama"):
        from transformers.models.mllama.modeling_mllama import (
            MllamaForConditionalGeneration,
        )

        # Instantiate a dummy model
        config = transformers.models.mllama.configuration_mllama.MllamaConfig(
            torch_dtype=torch.bfloat16,
            text_config=transformers.models.mllama.configuration_mllama.MllamaTextConfig(
                rms_norm_eps=1e-5,
                hidden_size=32,
                intermediate_size=64,
                hidden_act="silu",
                num_hidden_layers=2,
                rope_scaling=dict(
                    factor=8.0,
                    high_freq_factor=4.0,
                    low_freq_factor=1.0,
                    original_max_position_embeddings=8192,
                    rope_type="llama3",
                ),
            ),
            vision_config=transformers.models.mllama.configuration_mllama.MllamaVisionConfig(
                rms_norm_eps=1e-5,
                hidden_size=32,
                intermediate_size=64,
                hidden_act="gelu",
                num_hidden_layers=2,
                vision_output_dim=64,
            ),
        )
        dummy_model_instance = MllamaForConditionalGeneration._from_config(config)

        assert isinstance(dummy_model_instance, MllamaForConditionalGeneration)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.language_model.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.language_model.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_pre.forward
        ) != inspect.getsource(LigerLayerNorm.forward)
        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_post.forward
        ) != inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.global_transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerLayerNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.language_model.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.language_model.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_pre.forward
        ) == inspect.getsource(LigerLayerNorm.forward)
        assert inspect.getsource(
            dummy_model_instance.vision_model.layernorm_post.forward
        ) == inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
        for layer in dummy_model_instance.vision_model.global_transformer.layers:
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerLayerNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_mllama_available(), reason="mllama module not available")
def test_apply_liger_kernel_to_instance_for_mllama_for_causal_lm():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mllama.modeling_mllama"):
        from transformers.models.mllama.modeling_mllama import MllamaForCausalLM

        # Instantiate a dummy model
        config = transformers.models.mllama.configuration_mllama.MllamaTextConfig(
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
            rope_scaling=dict(
                factor=8.0,
                high_freq_factor=4.0,
                low_freq_factor=1.0,
                original_max_position_embeddings=8192,
                rope_type="llama3",
            ),
        )

        dummy_model_instance = MllamaForCausalLM._from_config(config)

        assert isinstance(dummy_model_instance, MllamaForCausalLM)

        # Check that model instance variables are not yet patched with Liger modules
        assert not isinstance(dummy_model_instance.model.norm, LigerRMSNorm)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_mistral():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mistral.modeling_mistral"):
        # Instantiate a dummy model
        config = transformers.models.mistral.configuration_mistral.MistralConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_mixtral():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.mixtral.modeling_mixtral"):
        # Instantiate a dummy model
        config = transformers.models.mixtral.configuration_mixtral.MixtralConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
            num_local_experts=3,
            num_experts_per_tok=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            for expert in layer.block_sparse_moe.experts:
                assert inspect.getsource(expert.forward) != inspect.getsource(
                    LigerBlockSparseTop2MLP.forward
                )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            for expert in layer.block_sparse_moe.experts:
                assert inspect.getsource(expert.forward) == inspect.getsource(
                    LigerBlockSparseTop2MLP.forward
                )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_gemma():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma.modeling_gemma"):
        # Instantiate a dummy model
        config = transformers.models.gemma.configuration_gemma.GemmaConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_gemma2():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.gemma2.modeling_gemma2"):
        # Instantiate a dummy model
        config = transformers.models.gemma2.configuration_gemma2.Gemma2Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.pre_feedforward_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_feedforward_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerGEGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.pre_feedforward_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_feedforward_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_qwen2():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.qwen2.modeling_qwen2"):
        # Instantiate a dummy model
        config = transformers.models.qwen2.configuration_qwen2.Qwen2Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


@pytest.mark.skipif(not is_qwen2_vl_available(), reason="qwen2_vl module not available")
def test_apply_liger_kernel_to_instance_for_qwen2_vl():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.qwen2_vl.modeling_qwen2_vl"):
        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
            Qwen2VLForConditionalGeneration,
        )

        # Instantiate a dummy model
        config = transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=48,
            embed_dim=16,
            hidden_act="silu",
            num_hidden_layers=2,
            num_attention_heads=2,
            max_position_embeddings=128,
            vocab_size=1000,
            vision_config={
                "depth": 4,
                "embed_dim": 128,
                "num_heads": 8,
                "hidden_size": 1024,
            },
        )
        dummy_model_instance = Qwen2VLForConditionalGeneration._from_config(config)

        assert isinstance(dummy_model_instance, Qwen2VLForConditionalGeneration)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
        for vision_block in dummy_model_instance.visual.blocks:
            assert inspect.getsource(vision_block.norm1.forward) != inspect.getsource(
                LigerLayerNorm.forward
            )
            assert inspect.getsource(vision_block.norm2.forward) != inspect.getsource(
                LigerLayerNorm.forward
            )

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerSwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
        for vision_block in dummy_model_instance.visual.blocks:
            assert inspect.getsource(vision_block.norm1.forward) == inspect.getsource(
                LigerLayerNorm.forward
            )
            assert inspect.getsource(vision_block.norm2.forward) == inspect.getsource(
                LigerLayerNorm.forward
            )

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")


def test_apply_liger_kernel_to_instance_for_phi3():
    # Ensure any monkey patching is cleaned up for subsequent tests
    with patch("transformers.models.phi3.modeling_phi3"):
        # Instantiate a dummy model
        config = transformers.models.phi3.configuration_phi3.Phi3Config(
            torch_dtype=torch.bfloat16,
            rms_norm_eps=1e-5,
            hidden_size=32,
            intermediate_size=64,
            hidden_act="silu",
            num_hidden_layers=2,
        )
        dummy_model_instance = AutoModelForCausalLM.from_config(config)

        # Check that model instance variables are not yet patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) != inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) != inspect.getsource(
                LigerPhi3SwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) != inspect.getsource(LigerRMSNorm.forward)

        # Test applying kernels to the model instance
        _apply_liger_kernel_to_instance(model=dummy_model_instance)

        # Check that the model's instance variables were correctly patched with Liger modules
        assert inspect.getsource(
            dummy_model_instance.model.norm.forward
        ) == inspect.getsource(LigerRMSNorm.forward)
        for layer in dummy_model_instance.model.layers:
            assert inspect.getsource(layer.mlp.forward) == inspect.getsource(
                LigerPhi3SwiGLUMLP.forward
            )
            assert inspect.getsource(
                layer.input_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)
            assert inspect.getsource(
                layer.post_attention_layernorm.forward
            ) == inspect.getsource(LigerRMSNorm.forward)

        try:
            print(dummy_model_instance)
        except Exception as e:
            pytest.fail(f"An exception occurred while printing the patched model: {type(e).__name__} - {e}")