benchmark_swiglu.py
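"""Benchmarks Liger Kernel's LigerSwiGLUMLP against Hugging Face's LlamaMLP
(the SwiGLU feed-forward block), measuring speed and peak memory across a
sweep of sequence lengths."""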
import torch
import triton
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the SwiGLU MLP forward, backward, or full (forward + backward) pass."""
    seq_len = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    bsz = extra_benchmark_config["B"]
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]
    intermediate_size = extra_benchmark_config["intermediate_size"]
    hidden_act = extra_benchmark_config["hidden_act"]

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if provider == "liger":
        layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
    elif provider == "huggingface":
        layer = LlamaMLP(config=llama_config).to(device).to(dtype)
    else:
        raise ValueError(f"Invalid provider: {provider} for SwiGLU")

    def fwd():
        return layer(x)

    if mode == "forward":
        # grad_to_none resets x.grad between reps so gradients do not accumulate;
        # timings are unpacked in the order the quantiles appear in QUANTILES
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        # retain_graph=True lets the same graph be backpropagated on every rep
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(do, retain_graph=True),
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )
    else:

        def full():
            y = fwd()
            y.backward(torch.randn_like(y), retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure peak memory of the SwiGLU MLP for the requested operation mode."""
    seq_len = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    bsz = extra_benchmark_config["B"]
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]
    intermediate_size = extra_benchmark_config["intermediate_size"]
    hidden_act = extra_benchmark_config["hidden_act"]

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if provider == "liger":
        layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
    elif provider == "huggingface":
        layer = LlamaMLP(config=llama_config).to(device).to(dtype)
    else:
        raise ValueError(f"Invalid provider: {provider} for SwiGLU")

    def fwd():
        return layer(x)

    def full():
        y = fwd()
        y.backward(torch.randn_like(y), retain_graph=True)

    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        mem_50, mem_20, mem_80 = _test_memory(
            lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES
        )
    else:
        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Llama-7B-style MLP: hidden_size 4096, intermediate_size 11008, SiLU gate
    common_configs = {
        "kernel_name": "swiglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 14)],  # 1024 to 8192
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "B": 4,
                "hidden_size": 4096,
                "dtype": torch.bfloat16,
                "intermediate_size": 11008,
                "hidden_act": "silu",
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_swiglu,
        kernel_operation_modes=["forward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_swiglu,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )