# benchmark_embedding.py
import torch
import triton
from torch.nn import Embedding
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.experimental.embedding import LigerEmbedding
from liger_kernel.utils import infer_device

# Benchmarks LigerEmbedding against torch.nn.Embedding (eager and torch.compile)
# for speed (forward / forward+backward) and memory (forward+backward).

device = infer_device()

# NOTE: For torch compile, we will just use default inductor settings. No further customization
# is needed.


def _build_embedding_fns(input: SingleBenchmarkRunInput):
    """Build the forward and forward+backward closures for one benchmark run.

    Only the module for ``input.kernel_provider`` is instantiated:
    "liger" uses ``LigerEmbedding``, "torch_compile" wraps ``torch.nn.Embedding``
    in ``torch.compile``, and any other provider (e.g. "huggingface") uses plain
    ``torch.nn.Embedding``. Building only the module under test keeps the
    weights of the other providers out of device memory during measurement.

    Args:
        input: benchmark run description; ``input.x`` is the number of
            embeddings V, and ``extra_benchmark_config`` supplies B, T, D, dtype.

    Returns:
        A ``(fwd, full)`` pair of zero-argument callables. ``fwd`` runs the
        embedding lookup on a random ``(B, T)`` index tensor and returns the
        output; ``full`` additionally back-propagates a random upstream gradient.
    """
    V = input.x  # number of embeddings (vocabulary size), first arg of Embedding
    provider = input.kernel_provider
    config = input.extra_benchmark_config
    B = config["B"]
    T = config["T"]
    D = config["D"]  # embedding dimension
    dtype = config["dtype"]

    if provider == "liger":
        emb = LigerEmbedding(V, D).to(device).to(dtype)
    elif provider == "torch_compile":
        # torch.compile is lazy: compilation happens on the first call, which
        # falls inside the benchmark's warmup.
        emb = torch.compile(Embedding(V, D).to(device).to(dtype))
    else:
        emb = Embedding(V, D).to(device).to(dtype)

    input_ids = torch.randint(0, V, (B, T), device=device)

    def fwd():
        return emb(input_ids)

    def full():
        output = fwd()
        output.backward(torch.randn_like(output))

    return fwd, full


def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the embedding for one provider / config with ``triton.testing.do_bench``.

    ``input.kernel_operation_mode`` selects what is timed: "forward" times the
    lookup only, "full" times forward + backward.

    Returns:
        The 20th/50th/80th percentile timings (reported by ``do_bench``).

    Raises:
        ValueError: if the operation mode is neither "forward" nor "full".
    """
    mode = input.kernel_operation_mode
    # Validate before allocating anything; previously an unknown mode fell
    # through both branches and crashed with an UnboundLocalError.
    if mode not in ("forward", "full"):
        raise ValueError(f"Unsupported kernel_operation_mode for embedding: {mode!r}")

    fwd, full = _build_embedding_fns(input)
    fn = fwd if mode == "forward" else full

    ms_50, ms_20, ms_80 = triton.testing.do_bench(fn, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of a forward + backward embedding pass for one provider / config.

    The operation mode is not consulted: the forward + backward closure is
    always the one measured, via the shared ``_test_memory`` helper.

    Returns:
        The 20th/50th/80th percentile measurements reported by ``_test_memory``.
    """
    _, full = _build_embedding_fns(input)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "embedding",
        "x_name": "V",
        # V is the number of embeddings (first argument of Embedding(V, D)),
        # i.e. the vocabulary size; D in the configs below is the embedding dim.
        "x_label": "vocab size",
        "x_values": [2**i for i in range(10, 18)],
        "kernel_providers": ["liger", "huggingface", "torch_compile"],
        "extra_benchmark_configs": [
            # BERT
            {"B": 32, "T": 512, "D": 768, "dtype": torch.float32},
            # Llama
            {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32},
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_embedding,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_embedding,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )