# embedding.py
import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import ensure_contiguous


@triton.jit
def embedding_forward_kernel(
    embeddings_ptr,
    indices_ptr,
    output_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Gather rows of an embedding table: output[m, n] = embeddings[indices[m], n].

    Launched on a 2D grid: axis 0 tiles the flattened indices (length
    ``n_elements``), axis 1 tiles the embedding dimension. Each program
    copies one BLOCK_SIZE_M x BLOCK_SIZE_N tile. Both ``embeddings_ptr``
    (row-major, width ``embedding_dim``) and ``output_ptr`` are assumed
    contiguous — enforced by the ``@ensure_contiguous`` caller below.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Top-left corner of this program's tile.
    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offsets_m < n_elements
    # Out-of-range lanes read index 0 ("other=0"); they are harmless because
    # the corresponding output stores are masked off below.
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
    mask_n = offsets_n < embedding_dim

    # Row-major addressing into the (num_embeddings, embedding_dim) table.
    # NOTE(review): offset arithmetic inherits the dtype of ``indices``; for
    # very large tables this presumably relies on int64 indices — confirm.
    embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
    embeddings = tl.load(
        embeddings_ptr + embedding_offsets,
        mask=mask_m[:, None] & mask_n[None, :],
        other=0.0,
    )

    output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
    tl.store(
        output_ptr + output_offsets, embeddings, mask=mask_m[:, None] & mask_n[None, :]
    )


@triton.jit
def embedding_backward_kernel(
    grad_output_ptr,
    grad_weight_ptr,
    indices_ptr,
    n_elements,
    embedding_dim: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Scatter-add gradients into the weight: grad_weight[indices[m], n] += grad_output[m, n].

    Mirrors the forward tiling (axis 0 over flattened indices, axis 1 over the
    embedding dimension). Duplicate indices across programs are handled by
    ``tl.atomic_add``; ``grad_weight_ptr`` must be zero-initialized by the
    caller (done via ``torch.zeros_like`` in ``backward`` below).
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    start_m = pid_m * BLOCK_SIZE_M
    start_n = pid_n * BLOCK_SIZE_N
    offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offsets_m < n_elements
    # Masked lanes read index 0 but contribute nothing: their atomic adds are
    # masked out below.
    indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)
    offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
    mask_n = offsets_n < embedding_dim

    grad_output = tl.load(
        grad_output_ptr + offsets_m[:, None] * embedding_dim + offsets_n[None, :],
        mask=mask_m[:, None] & mask_n[None, :],
        other=0.0,
    )

    grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]

    # NOTE(review): floating-point atomic adds make the accumulation order —
    # and thus the low-order bits of grad_weight for repeated indices —
    # run-dependent; presumably acceptable for training, but not bitwise
    # reproducible.
    tl.atomic_add(
        grad_weight_ptr + grad_weight_offsets,
        grad_output,
        mask=mask_m[:, None] & mask_n[None, :],
    )


class LigerEmbeddingFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton embedding gather/scatter-add kernels.

    Forward gathers rows of a 2D weight ``embeddings`` at ``indices`` (any
    shape); backward scatter-adds ``grad_output`` back into a dense
    ``grad_weight`` of the same shape as the weight.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
        """Return ``embeddings[indices]`` with shape ``(*indices.shape, embedding_dim)``.

        ``embeddings`` is the 2D weight table; ``indices`` is an integer tensor
        of arbitrary shape, flattened for the kernel launch and restored on
        return.
        """
        ori_shape = indices.shape
        # Kernel works on a flat index vector; original shape is restored below.
        indices = indices.view(-1)
        output = torch.empty(
            indices.shape[0],
            embeddings.shape[1],
            device=indices.device,
            dtype=embeddings.dtype,
        )

        n_elements = indices.numel()
        embedding_dim = embeddings.shape[1]

        # NOTE(review): both block sizes are derived from embedding_dim;
        # BLOCK_SIZE_M tiles the n_elements axis, so basing it on embedding_dim
        # looks like a copy-paste — performance-only, results are unaffected.
        # Confirm against upstream before changing.
        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
        grid = (
            triton.cdiv(n_elements, BLOCK_SIZE_M),
            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
        )

        embedding_forward_kernel[grid](
            embeddings,
            indices,
            output,
            n_elements,
            embedding_dim=embedding_dim,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        # ``embeddings`` is saved only to recover shape/dtype/device in
        # backward (torch.zeros_like); its values are not re-read there.
        ctx.save_for_backward(indices, embeddings)

        # Restore the caller's index shape, appending the embedding dim.
        return output.view(*ori_shape, -1)

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``(grad_weight, None)``: dense gradient w.r.t. the weight, none for indices."""
        indices, embedding_table = ctx.saved_tensors
        # Flatten grad_output to (n_elements, embedding_dim) to mirror the
        # flattened indices used in forward.
        grad_output = grad_output.contiguous().view(-1, embedding_table.shape[1])

        # Kernel accumulates with atomic adds, so the buffer must start at zero.
        grad_weight = torch.zeros_like(embedding_table)

        n_elements = indices.numel()
        embedding_dim = embedding_table.shape[1]

        # Same tiling (and same BLOCK_SIZE_M caveat) as in forward.
        BLOCK_SIZE_M = triton.next_power_of_2(min(128, embedding_dim))
        BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
        grid = (
            triton.cdiv(n_elements, BLOCK_SIZE_M),
            triton.cdiv(embedding_dim, BLOCK_SIZE_N),
        )

        embedding_backward_kernel[grid](
            grad_output,
            grad_weight,
            indices,
            n_elements,
            embedding_dim=embedding_dim,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )

        # No gradient for the integer ``indices`` argument.
        return grad_weight, None