rope.py
import torch
import triton
import triton.language as tl


@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    # q size: (bsz, seq_len, num_q_heads, head_dim)
    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
    # k size: (bsz, seq_len, num_kv_heads, head_dim)
    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

    # cos size: (1, seq_len, head_dim)
    # cos stride: (seq_len * head_dim, head_dim, 1)
    pid = tl.program_id(0)

    # locate the start address of the row this program instance operates on
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    #    effectively represents a 2D grid of size [bsz, seq_len] with the seq_len
    #    dimension being the fastest-changing one. Thus pid // sl gives the batch
    #    index and pid % sl gives the sequence index.
    # 2. we only need the left half of the cos and sin matrices because the right
    #    half is just a clone of the left half.
    cos_row_idx = pid % sl
    cos = cos + cos_row_idx * cos_row_stride
    sin = sin + cos_row_idx * sin_row_stride
    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # the rotation matrix is orthogonal, so the backward pass applies its
        # transpose, i.e. a rotation by -θ:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def rope_forward(q, k, cos, sin):
    # transpose back to the physical layout because Triton operates on the
    # physical storage; note: q and k are typically non-contiguous views here
    # and only match their physical layout after this transpose
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous;
    # this is a no-op if they already are
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _triton_rope[(n_row,)](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def rope_backward(dq, dk, cos, sin):
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # backward is identical to forward except for a few sign flips
    _triton_rope[(n_row,)](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is
    slightly different from the one in the original RoPE paper.

    The corresponding HuggingFace implementation is here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, see:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim)
        sin size: (1, seq_len, head_dim)
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        ctx.save_for_backward(cos, sin)
        return q, k

    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim)
        sin size: (1, seq_len, head_dim)
        """
        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        return dq, dk, None, None, None, None
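

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds cos/sin
# caches the way rotary-embedding layers typically do and checks the fused
# kernel against a plain-PyTorch rotate_half reference. The shapes, the
# 10000.0 frequency base, and the _rotate_half helper below are illustrative
# assumptions, not exports of this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def _rotate_half(x):
        # reference helper mirroring the HF Llama implementation
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    device = "cuda"  # Triton kernels require a GPU
    bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 16, 64

    q = torch.randn(bsz, n_q_head, seq_len, head_dim, device=device, requires_grad=True)
    k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device=device, requires_grad=True)

    # cos/sin for positions 0..seq_len-1; the left half is duplicated into the
    # right half, matching the layout the kernel expects
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len, device=device).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None, :, :], emb.sin()[None, :, :]

    q_out, k_out = LigerRopeFunction.apply(q, k, cos, sin)

    # plain-PyTorch reference (unsqueeze_dim=1 broadcasts over the head axis)
    q_ref = q * cos.unsqueeze(1) + _rotate_half(q) * sin.unsqueeze(1)
    k_ref = k * cos.unsqueeze(1) + _rotate_half(k) * sin.unsqueeze(1)
    print(torch.allclose(q_out, q_ref, atol=1e-4), torch.allclose(k_out, k_ref, atol=1e-4))

    # gradients flow through the custom autograd function's backward
    (q_out.sum() + k_out.sum()).backward()
    print(q.grad.shape, k.grad.shape)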