rope.py
from liger_kernel.ops.rope import LigerRopeFunction


def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply Rotary Positional Embedding (RoPE) to query and key states.

    This is a thin functional wrapper. All of its arguments are handed
    unchanged, in the order received, to ``LigerRopeFunction.apply``. That
    call performs the actual computation; nothing is validated or transformed
    here.

    Args:
        q (torch.Tensor): Query states, documented as shape
            (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key states, documented as shape
            (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine table, documented as shape
            (1, seq_len, head_dim).
        sin (torch.Tensor): Sine table, documented as shape
            (1, seq_len, head_dim).
        position_ids (torch.Tensor, optional): Position ids. Defaults to None.
            Forwarded as-is; whether it is consulted depends on
            ``LigerRopeFunction`` (not visible here).
        unsqueeze_dim (int, optional): Dimension along which cos/sin are
            meant to be unsqueezed. Defaults to 1. It is forwarded as-is.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Whatever ``LigerRopeFunction.apply``
        returns, documented as the rotated query and key tensors.
    """
    # Collect the arguments once, then pass them positionally in their original order.
    rope_inputs = (q, k, cos, sin, position_ids, unsqueeze_dim)
    return LigerRopeFunction.apply(*rope_inputs)