# src/liger_kernel/ops/rms_norm.py
  1  """
  2  This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
  3  See the original Unsloth repository at https://github.com/unslothai/unsloth.
  4  
  5  The following line
  6  https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
  7  is based on code from Unsloth, located at:
  8  https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
  9  
 10  Modifications made by Yanning Chen, 2024.
 11  """
 12  
 13  import math
 14  import operator
 15  
 16  import torch
 17  import triton
 18  import triton.language as tl
 19  
 20  from liger_kernel.ops.utils import (
 21      calculate_settings,
 22      compare_version,
 23      ensure_contiguous,
 24      torch_to_triton_dtype,
 25  )
 26  
 27  if compare_version("triton", operator.ge, "3.0.0"):
 28      try:
 29          # typical import path with dispatch available
 30          from triton.language.extra.libdevice import rsqrt
 31      except ModuleNotFoundError:
 32          # for working with NGC containers
 33          from triton.language.extra.cuda.libdevice import rsqrt
 34  else:
 35      from triton.language.math import rsqrt
 36  
 37  
# Sentinel codes passed to the kernels as the `casting_mode` constexpr argument,
# so the dtype-handling branches can be compiled out per mode.
_CASTING_MODE_NONE = tl.constexpr(-1)  # compute everything in the input's original dtype
_CASTING_MODE_LLAMA = tl.constexpr(0)  # only the inverse RMS is computed in fp32 (HF Llama)
_CASTING_MODE_GEMMA = tl.constexpr(1)  # all computation in fp32, output cast back (HF Gemma)
 41  
 42  
@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,  # output, same shape as X
    Y_row_stride,
    X_ptr,  # input, viewed as (n_rows, n_cols)
    X_row_stride,
    W_ptr,  # weight, (n_cols,)
    W_row_stride,
    RSTD_ptr,  # per-row inverse RMS cache, (n_rows,)
    RSTD_row_stride,
    n_cols,
    eps,
    offset,  # constant added to W before scaling (e.g. 1.0 for Gemma)
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    BLOCK_SIZE: tl.constexpr,
):
    """
    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    One program instance normalizes one row. BLOCK_SIZE must cover the full
    row (masked past n_cols), so the row sum is done in a single pass.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance the row pointers to this program's row.
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    # Masked lanes load 0, which contributes nothing to the sum of squares.
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    X_row_dtype = X_row.dtype
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    if casting_mode == _CASTING_MODE_NONE:
        # Keep the whole computation in the input dtype, including the scalars.
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
    rstd = rsqrt(mean_square + eps)

    # We can save time by caching rms with minimal memory overhead
    # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(RSTD_ptr, rstd)

    X_row = X_row * rstd

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    Y_row = X_row * (offset + W_row)

    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
113  
114  
@triton.jit
def _rms_norm_backward_kernel(
    dY_ptr,  # upstream gradient, (n_rows, n_cols)
    dY_row_stride,
    dX_ptr,  # output gradient w.r.t. X (may alias dY when in_place)
    dX_row_stride,
    X_ptr,  # input saved from the forward pass
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,  # weight, (n_cols,)
    W_row_stride,
    RSTD_ptr,  # per-row inverse RMS cached by the forward pass
    RSTD_row_stride,
    dW_ptr,  # partial dW accumulators, one row per program
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    rows_per_program: tl.constexpr,
    casting_mode: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
    dw = sum(dy * (x / RMS)). summation over BxT dimension

    Each program processes a contiguous chunk of `rows_per_program` rows,
    accumulating its own partial dW row; the host sums the partials afterwards.
    """

    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # The last program may own fewer than rows_per_program rows.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Partial dW is accumulated in fp32 for numerical stability.
    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    dY_ptr += row_start * dY_row_stride
    dX_ptr += row_start * dX_row_stride

    X_ptr += row_start * X_row_stride
    RSTD_ptr += row_start

    # (W + offset) is loop-invariant; load and shift it once.
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
    W_row = W_row + offset

    for _ in range(row_start, row_end):
        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

        # Get cached rms
        rstd_row = tl.load(RSTD_ptr)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes, mirroring
        # where the forward pass performed its casts.
        if casting_mode == _CASTING_MODE_LLAMA:
            m = (dY_row * W_row).to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            m = dY_row * W_row
        else:
            m = dY_row * W_row

        dX_row = rstd_row * m

        dX_row += (rstd_row) * (
            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
        )

        # calculate the gradient of W
        if casting_mode == _CASTING_MODE_LLAMA:
            # Llama multiplies dY with the normalized input in the original dtype.
            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
        else:
            # here X_row is already in fp32 (see previous if block)
            dW_row += dY_row * (X_row * rstd_row)

        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

        # Advance to the next row owned by this program.
        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
199  
200  
# Public string names for the casting modes, mapped to the integer codes the
# kernels receive as a constexpr argument.
_str_to_casting_mode = {
    "llama": _CASTING_MODE_LLAMA.value,
    "gemma": _CASTING_MODE_GEMMA.value,
    "none": _CASTING_MODE_NONE.value,
}
206  
207  
def rms_norm_forward(X, W, eps, offset, casting_mode):
    """
    Launch the RMSNorm forward kernel over a 2D view of `X`.

    Accepts `casting_mode` either as a string key of `_str_to_casting_mode`
    or as an already-resolved integer code. Returns everything the backward
    pass needs: (Y, X_2d, RSTD, BLOCK_SIZE, num_warps, casting_mode_int).
    """
    # Resolve the casting mode to its integer code, validating either form.
    if isinstance(casting_mode, int):
        assert (
            casting_mode in _str_to_casting_mode.values()
        ), f"Invalid casting mode: {casting_mode}"
    else:
        assert (
            casting_mode in _str_to_casting_mode
        ), f"Invalid casting mode: {casting_mode}"
        casting_mode = _str_to_casting_mode[casting_mode]

    input_shape = X.shape
    hidden_dim = input_shape[-1]
    X = X.view(-1, hidden_dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # Cache rstd for each row so the backward pass can reuse it.
    # Llama and Gemma modes compute/store rstd in fp32.
    if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value):
        rstd_dtype = torch.float32
    else:
        rstd_dtype = X.dtype
    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

    # Check constraints.
    assert (
        X.shape[1] == W.shape[0]
    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

    # One program per row.
    _rms_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        offset,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return Y.view(*input_shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
257  
258  
def rms_norm_backward(
    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
):
    """
    Launch the RMSNorm backward kernel.

    Args:
        dY: upstream gradient, (B, T, H) or (BxT, H).
        X: the 2D contiguous input view saved by the forward pass.
        W: weight tensor, (H,).
        RSTD: per-row inverse RMS cached by the forward pass.
        offset: constant added to W before scaling (e.g. 1.0 for Gemma).
        casting_mode: integer casting-mode code (see `_str_to_casting_mode`).
        BLOCK_SIZE, num_warps: launch settings reused from the forward pass.
        in_place: if truthy, reuse dY's storage for dX to save memory.

    Returns:
        (dX, dW): gradients w.r.t. X (same shape as dY) and W.

    Raises:
        RuntimeError: if the feature dimension does not fit in one block.
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # The single-pass kernel needs the whole row in one block; validate
    # before allocating any scratch memory. (Message previously said
    # "layer norm", which misidentified this operator.)
    if n_cols > BLOCK_SIZE:
        raise RuntimeError("This RMSNorm doesn't support feature dim >= 64KB.")

    # Launch one program per SM; each accumulates a partial dW over its rows.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count

    # Partial dW accumulators are fp32 for numerical stability especially.
    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    if in_place:
        # Overwrite dY with dX to avoid an extra allocation.
        dX = dY
    else:
        dX = torch.zeros_like(dY)

    _rms_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        torch_to_triton_dtype[X.dtype],
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        n_rows,
        n_cols,
        offset,
        rows_per_program,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    dX = dX.view(*shape)
    # Reduce the per-SM partial gradients and cast back to the weight dtype.
    dW = _dW.sum(dim=0).to(W.dtype)

    return dX, dW
312  
313  
class LigerRMSNormFunction(torch.autograd.Function):
    """
    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
    weight tensor `W`, with an optional offset and casting mode.

    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.

    In addition, different models cast their inputs at different places during RMSNorm computation. For
    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
    support the following casting modes (they match HuggingFace Transformers' implementations):
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.

    `in_place` option means whether to in_place modify dY to store dX. This is default to `True` to save memory. However, under certain cases, it can produce incorrect inputs.
        For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place.
        Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,)

        Saves the 2D input view, weight, and cached inverse RMS for backward,
        along with the launch settings so the backward kernel matches.
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
            X, W, eps, offset, casting_mode
        )
        # Non-tensor state goes on ctx directly; tensors via save_for_backward.
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """
        Y: (B, T, H) or (BxT, H)

        Returns gradients for (X, W) and None for the non-tensor forward
        arguments (eps, offset, casting_mode, in_place).
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_backward(
            dY,
            X,
            W,
            RSTD,
            ctx.offset,
            ctx.casting_mode,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        return dX, dW, None, None, None, None