# src/liger_kernel/ops/experimental/embedding.py

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous


# Forward: for each (flattened) index, copy the corresponding row of the
# embedding table into the output, one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile
# per program.
@triton.jit
def embedding_forward_kernel(
    embeddings_ptr,
    indices_ptr,
    output_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Launch grid is 2D: axis 0 tiles the flattened indices (M), axis 1
    # tiles the embedding dimension (N).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offsets_m < n_elements
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
    mask_n = offsets_n < embedding_dim

    # Gather: row `index` of the table starts at index * embedding_dim.
    embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
    embeddings = tl.load(
        embeddings_ptr + embedding_offsets,
        mask=mask_m[:, None] & mask_n[None, :],
        other=0.0,
    )

    # Write the gathered tile to the row-major output.
    output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
    tl.store(
        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
    )


# Backward: scatter-add rows of grad_output into grad_weight at the rows
# selected by indices.
@triton.jit
def embedding_backward_kernel(
    grad_output_ptr,
    grad_weight_ptr,
    indices_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offsets_m < n_elements
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
    mask_n = offsets_n < embedding_dim

    grad_output = tl.load(
        grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
        mask=mask_m[:, None] & mask_n[None, :],
        other=0.0,
    )

    grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]

    # The same row may be selected by several indices, so the accumulation
    # must be atomic; this is the kernel analogue of
    # grad_weight.index_add_(0, indices, grad_output).
    tl.atomic_add(
        grad_weight_ptr + grad_weight_offsets,
        grad_output,
        mask=mask_m[:, None] & mask_n[None, :],
    )


class LigerEmbeddingFunction(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
        # Flatten indices so the kernel sees a 1D list of rows to gather.
        ori_shape = indices.shape
        indices = indices.view(-1)
        output = torch.empty(
            indices.shape[0],
            embeddings.shape[1],
            device=indices.device,
            dtype=embeddings.dtype,
        )

        n_elements = indices.numel()
        embedding_dim = embeddings.shape[1]

        # Power-of-two tile sizes (tl.arange requires them), capped at 128.
        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
        grid = (
            triton.cdiv(n_elements, BLOCK_SIZE_M),
            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
        )

        embedding_forward_kernel[grid](
            embeddings,
            indices,
            output,
            n_elements,
            embedding_dim=embedding_dim,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        ctx.save_for_backward(indices, embeddings)

        # Restore the original index shape, with the embedding dim appended.
        return output.view(*ori_shape, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        indices, embedding_table = ctx.saved_tensors
        # Flatten upstream grads to (n_indices, embedding_dim) to match the
        # kernel's row-major layout.
        grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])

        grad_weight = torch.zeros_like(embedding_table)

        n_elements = indices.numel()
        embedding_dim = embedding_table.shape[1]

        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
        grid = (
            triton.cdiv(n_elements, BLOCK_SIZE_M),
            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
        )

        embedding_backward_kernel[grid](
            grad_output,
            grad_weight,
            indices,
            n_elements,
            embedding_dim=embedding_dim,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        # No gradient flows to the integer indices.
        return grad_weight, None
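

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (not part of the original module): it compares
# the Triton path against torch.nn.functional.embedding on a small random
# problem. The shapes and tolerance below are illustrative assumptions, and
# a CUDA device is assumed since Triton kernels target GPUs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    num_embeddings, embedding_dim = 64, 32
    weight_ref = torch.randn(
        num_embeddings, embedding_dim, device="cuda", requires_grad=True
    )
    weight_liger = weight_ref.detach().clone().requires_grad_(True)
    # Sampling with replacement makes repeated indices likely, which
    # exercises the atomic scatter-add in the backward kernel.
    indices = torch.randint(0, num_embeddings, (4, 8), device="cuda")

    out_ref = torch.nn.functional.embedding(indices, weight_ref)
    out_liger = LigerEmbeddingFunction.apply(weight_liger, indices)
    assert torch.allclose(out_ref, out_liger)

    out_ref.backward(torch.ones_like(out_ref))
    out_liger.backward(torch.ones_like(out_liger))
    assert torch.allclose(weight_ref.grad, weight_liger.grad, atol=1e-5)
    print("forward/backward match torch.nn.functional.embedding")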