/ src / liger_kernel / ops / rope.py
rope.py
  1  import torch
  2  import triton
  3  import triton.language as tl
  4  
  5  
@triton.jit
def _triton_rope(
    q_ptr,
    q_row_stride,
    k_ptr,
    k_row_stride,
    cos,
    cos_row_stride,
    sin,
    sin_row_stride,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    """Apply the rotary positional embedding to q and k in place.

    Each program instance handles one (batch, sequence-position) token: it
    loads the left/right halves of every head of q and k, rotates them by the
    per-position cos/sin values, and stores the result back over the inputs.
    When BACKWARD_PASS is set, the sin terms change sign, applying the inverse
    (transposed) rotation — the gradient of the forward pass.

    NOTE(review): BLOCK_SIZE is accepted but never referenced in this body —
    presumably kept for launch-configuration compatibility; confirm before
    removing.
    """
    # q size: (bsz, seq_len, num_q_heads, head_dim)
    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
    # k size: (bsz, seq_len, num_kv_heads, head_dim)
    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

    # cos size: (1, seq_len, head_dim)
    # stride: (seq_len * head_dim, head_dim, 1)
    pid = tl.program_id(0)

    # locate start address of this token's q and k rows
    q_ptr = q_ptr + pid * q_row_stride
    k_ptr = k_ptr + pid * k_row_stride

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
    # and pid % sl to get the sequence index.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    # a clone of the left half.
    cos_row_idx = pid % (sl)
    cos = cos + cos_row_idx * cos_row_stride
    sin = sin + cos_row_idx * sin_row_stride
    # head_dim is padded up to a power of two (pad_hd) for tl.arange; the mask
    # keeps loads/stores within the real head_dim // 2 extent.
    cos_offsets = tl.arange(0, pad_hd // 2)
    cos_mask = cos_offsets < hd // 2
    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head: offsets index [head, dim_within_half] tiles
    first_half_q_offsets = (
        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    first_half_k_offsets = (
        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    )
    # masks trim the power-of-two padding on both the head and dim axes
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
        tl.arange(0, pad_hd // 2)[None, :] < hd // 2
    )
    # compute in sin/cos dtype so mixed-precision inputs don't lose the rotation
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
        sin_row.dtype
    )

    # right half of the head: same tiles shifted by half the head dim
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
        sin_row.dtype
    )
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
        sin_row.dtype
    )

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        # i.e. the same rotation with sin negated (the inverse rotation).
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
118  
119  
def rope_forward(q, k, cos, sin):
    """Run the RoPE forward kernel on q and k.

    Args:
        q: query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k: key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos, sin: rotation tables of shape (1, seq_len, head_dim).

    Returns:
        Tuple of (rotated q, rotated k, contiguous cos, contiguous sin), with
        q and k back in their original (bsz, n_head, seq_len, head_dim) layout.
    """
    # Triton addresses physical storage, so bring heads next to head_dim:
    # (bsz, n_head, seq_len, hd) -> (bsz, seq_len, n_head, hd).
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    # tl.arange needs power-of-two extents; the kernel masks off the padding.
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    block_size = max(pad_n_q_head, pad_n_kv_head)

    # One program instance per (batch, sequence-position) token.
    grid = (batch_size * seq_len,)

    # No-op for already-contiguous tensors; after the transpose above this
    # materializes fresh buffers the kernel can rotate in place.
    q, k = q.contiguous(), k.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()

    _triton_rope[grid](
        q,
        q.stride(1),
        k,
        k.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=block_size,
        BACKWARD_PASS=False,
    )
    # Restore the caller-facing (bsz, n_head, seq_len, head_dim) layout.
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin
163  
164  
def rope_backward(dq, dk, cos, sin):
    """Run the RoPE backward kernel (inverse rotation) on dq and dk.

    Args:
        dq: gradient w.r.t. q, shape (bsz, n_q_head, seq_len, head_dim).
        dk: gradient w.r.t. k, shape (bsz, n_kv_head, seq_len, head_dim).
        cos, sin: the same rotation tables used in the forward pass.

    Returns:
        Tuple of (rotated dq, rotated dk) in the original
        (bsz, n_head, seq_len, head_dim) layout.
    """
    # Match the forward pass: rotate in the (bsz, seq_len, n_head, hd) layout.
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]

    # tl.arange needs power-of-two extents; the kernel masks off the padding.
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    block_size = max(pad_n_q_head, pad_n_kv_head)

    # One program instance per (batch, sequence-position) token.
    grid = (batch_size * seq_len,)

    # Ensure contiguous buffers the kernel can mutate in place.
    dq, dk = dq.contiguous(), dk.contiguous()

    # Backward is the same kernel with the sin terms negated (BACKWARD_PASS).
    _triton_rope[grid](
        dq,
        dq.stride(1),
        dk,
        dk.stride(1),
        cos,
        cos.stride(-2),
        sin,
        sin.stride(-2),
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        BLOCK_SIZE=block_size,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)
204  
205  
class LigerRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
    than the original RoPE paper.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

    For more details about the rotation matrix used here, please refer to:
    https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim)
        sin size: (1, seq_len, head_dim)

        Returns the rotated (q, k) pair. position_ids and unsqueeze_dim are
        accepted for HuggingFace signature compatibility but unused here.
        """
        q, k, cos, sin = rope_forward(q, k, cos, sin)
        # Only cos/sin are needed to apply the inverse rotation in backward.
        ctx.save_for_backward(cos, sin)
        return q, k

    # Fix: backward must be a @staticmethod like forward — autograd invokes it
    # on the class, and the decorator was missing on this method.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (1, seq_len, head_dim)
        sin size: (1, seq_len, head_dim)
        """
        cos, sin = ctx.saved_tensors
        dq, dk = rope_backward(dq, dk, cos, sin)
        # One gradient slot per forward input:
        # q, k, cos, sin, position_ids, unsqueeze_dim.
        return dq, dk, None, None, None, None