import os
from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
import torch.nn as nn

from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from liger_kernel.transformers.functional import liger_rms_norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()

set_seed(42)
torch.use_deterministic_algorithms(True)

# Setting torch.use_deterministic_algorithms(True) alone can raise the following error:
# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

if device == "cuda":
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

SLEEP_SECONDS = 0.1


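# All reference modules below implement y = w * x / sqrt(mean(x^2, dim=-1) + eps);
# they differ only in the dtype used for the computation and in how the weight is
# applied, which is what the `casting_mode` and `offset` arguments of LigerRMSNorm
# select in the tests further down.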
class BaseRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # the cast is a no-op here (no upcast occurred); kept to mirror LlamaRMSNorm
        return self.weight * hidden_states.to(input_dtype)


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L112
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L122
class GemmaRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)


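# The parametrizations below pair each reference with the LigerRMSNorm settings
# expected to reproduce it:
#   LlamaRMSNorm -> offset=0.0, casting_mode="llama" (compute in fp32, cast back)
#   GemmaRMSNorm -> offset=1.0, casting_mode="gemma" (weight applied as 1.0 + w)
#   BaseRMSNorm  -> offset=0.0, casting_mode="none"  (compute in the input dtype)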
@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 128, 512),
        # weird shapes
        (5, 123, 123),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        pytest.param(
            torch.bfloat16,
            2e-1,
            2e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
        (BaseRMSNorm, 0.0, "none"),
    ],
)
@pytest.mark.parametrize(
    "in_place",
    [
        True,
        False,
    ],
)
def test_correctness(
    bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode, in_place
):
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    # upstream gradient
    do = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    # reference implementation (llama, gemma, or base)
    ref_rms = reference(hidden_size=hd).to(device).to(dtype)
    ref_o = ref_rms(h1)
    ref_o.backward(do, retain_graph=True)

    # Liger (Triton) implementation
    triton_rms = (
        LigerRMSNorm(
            hidden_size=hd, offset=offset, casting_mode=casting_mode, in_place=in_place
        )
        .to(device)
        .to(dtype)
    )
    triton_o = triton_rms(h2)
    triton_o.backward(do, retain_graph=True)

    assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
    assert_verbose_allclose(
        ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
    )
    print(f"{h1.grad=}")
    print(f"{h2.grad=}")
    assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)


@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2, 8),
        # weird shapes
        (9, 7, 41),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
        (torch.bfloat16, 2e-1, 2e-2),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
    ],
)
def test_correctness_functional(
    bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode
):
    # inputs (`reference` is unused in the body; it only keeps the
    # offset/casting_mode pairs consistent with test_correctness)
    _tensor = torch.randn(bs, sl, hd, device=device, dtype=dtype)

    h1 = _tensor.clone().requires_grad_(True)
    h2 = _tensor.clone().requires_grad_(True)

    w = torch.randn(hd, device=device, dtype=dtype)

    # the functional wrapper and the raw autograd Function should agree
    y1 = liger_rms_norm(X=h1, W=w, eps=1e-6, offset=offset, casting_mode=casting_mode)
    y2 = LigerRMSNormFunction.apply(h2, w, 1e-6, offset, casting_mode)

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    grad = torch.randn_like(y2)

    y1.backward(grad)
    y2.backward(grad)

    assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)
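

if __name__ == "__main__":
    # Minimal smoke run outside pytest -- a sketch, assuming only the APIs the
    # tests above already exercise (the LigerRMSNorm module and the
    # liger_rms_norm functional wrapper); shapes and settings are illustrative.
    x = torch.randn(2, 128, 512, device=device, requires_grad=True)
    rms = LigerRMSNorm(
        hidden_size=512, offset=0.0, casting_mode="llama", in_place=False
    ).to(device)
    y = rms(x)
    y.backward(torch.ones_like(y))
    w = torch.randn(512, device=device)
    y_fn = liger_rms_norm(X=x.detach(), W=w, eps=1e-6, offset=0.0, casting_mode="llama")
    print(f"module out: {tuple(y.shape)}, functional out: {tuple(y_fn.shape)}")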