"""
This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
See the original Unsloth repository at https://github.com/unslothai/unsloth.

The following line
https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
is based on code from Unsloth, located at:
https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22

Modifications made by Yanning Chen, 2024.
"""

import math
import operator

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import (
    calculate_settings,
    compare_version,
    ensure_contiguous,
    torch_to_triton_dtype,
)

if compare_version("triton", operator.ge, "3.0.0"):
    try:
        # typical import path with dispatch available
        from triton.language.extra.libdevice import rsqrt
    except ModuleNotFoundError:
        # for working with NGC containers
        from triton.language.extra.cuda.libdevice import rsqrt
else:
    from triton.language.math import rsqrt


# Sentinel values selecting the fp32-casting strategy inside the kernels.
# They are tl.constexpr so the `if casting_mode == ...` branches in the
# kernels are resolved at compile time and dead branches are eliminated.
_CASTING_MODE_NONE = tl.constexpr(-1)
_CASTING_MODE_LLAMA = tl.constexpr(0)
_CASTING_MODE_GEMMA = tl.constexpr(1)


@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
    BLOCK_SIZE: tl.constexpr,
):
    """
    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

    One program instance normalizes one row of X; the cached inverse RMS for
    the row is stored into RSTD for reuse in the backward pass.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    RSTD_ptr += row_idx * RSTD_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    X_row_dtype = X_row.dtype
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    # On Llama, only rstd is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    if casting_mode == _CASTING_MODE_NONE:
        eps = eps.to(X_row_dtype)
        offset = offset.to(X_row_dtype)

    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
    rstd = rsqrt(mean_square + eps)

    # We can save time by caching rms with minimal memory overhead
    # because rms is much smaller compared to X_row, as rms is for each row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(RSTD_ptr, rstd)

    X_row = X_row * rstd

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    Y_row = X_row * (offset + W_row)

    if casting_mode == _CASTING_MODE_GEMMA:
        Y_row = Y_row.to(X_row_dtype)

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@triton.jit
def _rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    rows_per_program: tl.constexpr,
    casting_mode: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
    dw = sum(dy * (x / RMS)). summation over BxT dimension

    Each program instance processes `rows_per_program` consecutive rows and
    accumulates a partial dW in fp32; partial results are summed on the host.
    """

    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    dY_ptr += row_start * dY_row_stride
    dX_ptr += row_start * dX_row_stride

    X_ptr += row_start * X_row_stride
    RSTD_ptr += row_start

    # W (+ offset) is row-invariant, so load it once outside the loop.
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
    W_row = W_row + offset

    for _ in range(row_start, row_end):
        dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
        X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

        # Get cached rms
        rstd_row = tl.load(RSTD_ptr)

        X_row = X_row.to(tl.float32)

        # Different backward graphs for different casting modes
        if casting_mode == _CASTING_MODE_LLAMA:
            m = (dY_row * W_row).to(tl.float32)

        elif casting_mode == _CASTING_MODE_GEMMA:
            dY_row = dY_row.to(tl.float32)
            m = dY_row * W_row
        else:
            m = dY_row * W_row

        dX_row = rstd_row * m

        dX_row += (rstd_row) * (
            -(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
        )

        # calculate the gradient of W
        if casting_mode == _CASTING_MODE_LLAMA:
            dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
        else:
            # here X_row is already in fp32 (see previous if block)
            dW_row += dY_row * (X_row * rstd_row)

        tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

        dY_ptr += dY_row_stride
        dX_ptr += dX_row_stride
        X_ptr += X_row_stride
        RSTD_ptr += RSTD_row_stride

    # One partial dW row per program; reduced with .sum(dim=0) on the host.
    tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


# Maps the user-facing string casting mode to the kernel constexpr value.
_str_to_casting_mode = {
    "llama": _CASTING_MODE_LLAMA.value,
    "gemma": _CASTING_MODE_GEMMA.value,
    "none": _CASTING_MODE_NONE.value,
}


def rms_norm_forward(X, W, eps, offset, casting_mode):
    """Launch the forward kernel on a (B*T, H) view of X.

    Returns (Y, X_2d, RSTD, BLOCK_SIZE, num_warps, casting_mode) so the
    backward pass can reuse the cached inverse RMS and launch settings.
    """
    # Accept either the string name or the already-resolved int constant.
    if not isinstance(casting_mode, int):
        assert (
            casting_mode in _str_to_casting_mode
        ), f"Invalid casting mode: {casting_mode}"
        casting_mode = _str_to_casting_mode[casting_mode]
    else:
        assert (
            casting_mode in _str_to_casting_mode.values()
        ), f"Invalid casting mode: {casting_mode}"

    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    # RSTD is to cache rstd for each row
    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
    rstd_dtype = (
        torch.float32
        if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
        else X.dtype
    )
    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

    # Check constraints.
    assert (
        X.shape[1] == W.shape[0]
    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

    # One program per row.
    _rms_norm_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        offset,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


def rms_norm_backward(
    dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place
):
    """Launch the backward kernel; returns (dX, dW).

    When `in_place` is truthy, dY is reused as the dX buffer to save memory
    (dY must then not be needed afterwards by the caller).
    """
    shape = dY.shape
    dim = shape[-1]
    dY = dY.view(-1, dim)
    n_rows, n_cols = dY.shape

    # One program per SM (or XPU subslice); each handles a block of rows.
    sm_count = 1
    if X.device.type == "cuda":
        sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    elif X.device.type == "xpu":
        sm_count = torch.xpu.get_device_properties(X.device).gpu_subslice_count

    # fp32 for numerical stability especially.
    _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

    if n_cols > BLOCK_SIZE:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    rows_per_program = math.ceil(n_rows / sm_count)
    grid = (sm_count,)

    if in_place:
        dX = dY
    else:
        dX = torch.zeros_like(dY)

    _rms_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        torch_to_triton_dtype[X.dtype],
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        n_rows,
        n_cols,
        offset,
        rows_per_program,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    dX = dX.view(*shape)
    # Reduce the per-program partial dW rows and cast back to W's dtype.
    dW = _dW.sum(dim=0).to(W.dtype)

    return dX, dW


class LigerRMSNormFunction(torch.autograd.Function):
    """
    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
    weight tensor `W`, with an optional offset and casting mode.

    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.

    In addition, different models cast their inputs at different places during RMSNorm computation. For
    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
    support the following casting modes (they match HuggingFace Transformers' implementations):
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.

    `in_place` option means whether to in_place modify dY to store dX. This is default to `True` to save memory. However, under certain cases, it can produce incorrect inputs.
    For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place.
    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,)
        """
        Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
            X, W, eps, offset, casting_mode
        )
        # Stash everything backward needs: launch settings on ctx, tensors
        # via save_for_backward (X is the flattened 2-D view).
        ctx.offset = offset
        ctx.casting_mode = casting_mode
        ctx.in_place = in_place
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """
        Y: (B, T, H) or (BxT, H)
        """
        X, W, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_backward(
            dY,
            X,
            W,
            RSTD,
            ctx.offset,
            ctx.casting_mode,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
            ctx.in_place,
        )
        # Gradients only for X and W; eps/offset/casting_mode/in_place are
        # non-tensor arguments.
        return dX, dW, None, None, None, None