rope.py
 1  from liger_kernel.ops.rope import LigerRopeFunction
 2  
 3  
 4  def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 5      """
 6      Applies Rotary Positional Embedding (RoPE) operation to query and key states.
 7  
 8      Args:
 9          q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
10          k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
11          cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
12          sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
13          position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
14          unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
15  
16      Returns:
17          Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
18      """
19  
20      return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)