"""test_rms_norm.py

Correctness tests for the Liger Triton RMSNorm kernel against pure-PyTorch
reference implementations: Llama-style (fp32 upcast), Gemma-style
(fp32 throughout, ``1 + weight`` scaling), and a base variant that stays in
the input dtype. Both the ``nn.Module`` wrapper (``LigerRMSNorm``) and the
functional / autograd-Function entry points are exercised, forward and
backward.
"""

import os

from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
import torch.nn as nn

from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.transformers.functional import liger_rms_norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()

set_seed(42)
torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) might throw the following error:
# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

if device == "cuda":
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# NOTE(review): defined but never used in this file — presumably a leftover
# from a flaky-test workaround; confirm before removing.
SLEEP_SECONDS = 0.1


class BaseRMSNorm(nn.Module):
    """Reference RMSNorm computed entirely in the input dtype (no fp32 upcast).

    Mirrors the Liger kernel's ``casting_mode="none"`` path.

    Args:
        hidden_size: size of the last (normalized) dimension.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # No upcast was performed, so .to(input_dtype) is a no-op here; kept
        # for symmetry with the Llama variant below.
        return self.weight * hidden_states.to(input_dtype)


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L112
class LlamaRMSNorm(nn.Module):
    """Llama-style RMSNorm: variance computed in fp32, result cast back to the
    input dtype BEFORE the weight multiply (casting_mode="llama")."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L122
class GemmaRMSNorm(nn.Module):
    """Gemma-style RMSNorm: the whole computation (including the weight
    multiply) runs in fp32, and the learned scale is applied as
    ``1.0 + weight`` — hence the ``offset=1.0`` used in the tests below
    (casting_mode="gemma")."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        # RMS normalization only; scaling is applied by the caller (forward).
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Weight multiply happens in fp32, unlike Llama which downcasts first.
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)


@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 128, 512),
        # weird shapes
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
        (BaseRMSNorm, 0.0, "none"),
    ],
)
@pytest.mark.parametrize(
    "in_place",
    [
        True,
        False,
    ],
)
def test_correctness(
    bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place
):
    """Compare LigerRMSNorm (Triton) against the matching PyTorch reference:
    forward output, weight gradient, and input gradient must agree within
    (atol, rtol) for every shape / dtype / casting-mode / in_place combo."""
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # Two independent leaves over the same data so each path accumulates
    # its own gradients.
    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    # do
    do = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # reference (llama or gemma)
    ref_rms = reference(hidden_size=hd).to(device).to(dtype)
    ref_o = ref_rms(h1)
    ref_o.backward(do, retain_graph=True)

    # triton
    triton_rms = (
        LigerRMSNorm(
            hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place
        )
        .to(device)
        .to(dtype)
    )
    triton_o = triton_rms(h2)
    triton_o.backward(do, retain_graph=True)

    assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
    )
    print(f"{h1.grad=}")
    print(f"{h2.grad=}")
    assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)


@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2, 8),
        # weird shapes
        (9, 7, 41),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        (torch.bfloat16, 2e-1, 2e-2),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
    ],
)
def test_correctness_functional(
    bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode
):
    """Check that the functional wrapper (liger_rms_norm) and the raw autograd
    Function (LigerRMSNormFunction.apply) produce identical forward outputs
    and input gradients."""
    # h
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    w = torch.randn(hd, device=device, dtype=dtype)

    y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
    y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    grad = torch.randn_like(y2)

    y1.backward(grad)
    y2.backward(grad)

    assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)