test_transformers.py
1 import pytest 2 3 4 def test_import_from_root(): 5 try: 6 from liger_kernel.transformers import ( # noqa: F401 7 LigerBlockSparseTop2MLP, 8 LigerCrossEntropyLoss, 9 LigerFusedLinearCrossEntropyLoss, 10 LigerGEGLUMLP, 11 LigerLayerNorm, 12 LigerPhi3SwiGLUMLP, 13 LigerRMSNorm, 14 LigerSwiGLUMLP, 15 liger_rotary_pos_emb, 16 ) 17 except Exception: 18 pytest.fail("Import kernels from root fails")