benchmark_rope.py
"""Speed and memory benchmarks for RoPE (rotary position embeddings),
comparing the Liger Triton kernel against the Hugging Face Llama reference
implementation across hidden sizes and sequence lengths."""

import torch
import triton
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
)
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x is the swept variable: it stands in for whichever of hidden_size or
    # seq_len is not pinned in the extra benchmark config.
    hidden_size = (
        extra_benchmark_config["hidden_size"]
        if "hidden_size" in extra_benchmark_config
        else input.x
    )
    seq_len = (
        extra_benchmark_config["seq_len"]
        if "seq_len" in extra_benchmark_config
        else input.x
    )

    head_dim = hidden_size // num_q_heads
    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
    # q/k end up in the (batch, num_heads, seq_len, head_dim) layout that both
    # implementations expect.
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    # Upstream gradients for the backward pass; randn_like already inherits
    # device and dtype from its input.
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    # cos/sin are precomputed once so only the rotation itself is timed.
    cos, sin = rotary_emb(k, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        elif provider == "huggingface":
            return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            raise ValueError(f"Invalid provider: {provider} for RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad(
                (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
            ),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    else:
        raise ValueError(f"Invalid mode: {mode} for RoPE benchmark")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
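
# --- Optional parity check (added sketch, not part of the original benchmark) ---
# Before trusting the timings it can help to confirm that the two providers
# agree numerically. This helper mirrors the tensor setup of the benchmarks
# above; the helper name, default sizes, and the bfloat16 tolerances are
# assumptions, not values taken from the benchmark configs.
def _sanity_check_rope_parity(seq_len=128, num_q_heads=32, num_kv_heads=8, head_dim=64):
    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
    # Same (batch, num_heads, seq_len, head_dim) layout the benchmarks use
    # after transpose(1, 2).
    q = torch.randn((1, num_q_heads, seq_len, head_dim), device=device, dtype=torch.bfloat16)
    k = torch.randn((1, num_kv_heads, seq_len, head_dim), device=device, dtype=torch.bfloat16)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    q_ref, k_ref = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
    q_liger, k_liger = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)

    # Loose tolerances chosen for bfloat16; tighten them for fp32 inputs.
    torch.testing.assert_close(q_liger, q_ref, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(k_liger, k_ref, rtol=1e-2, atol=1e-2)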
def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x is the swept variable: it stands in for whichever of hidden_size or
    # seq_len is not pinned in the extra benchmark config.
    hidden_size = (
        extra_benchmark_config["hidden_size"]
        if "hidden_size" in extra_benchmark_config
        else input.x
    )
    seq_len = (
        extra_benchmark_config["seq_len"]
        if "seq_len" in extra_benchmark_config
        else input.x
    )

    head_dim = hidden_size // num_q_heads
    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    # Peak memory is measured over a full forward + backward pass.
    def full():
        if provider == "liger":
            q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        torch.autograd.grad(
            (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
        )

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Sweep hidden size with the sequence length fixed at 2048.
    common_configs_varying_hidden_size = {
        "kernel_name": "rope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],  # 512, 2048, 8192
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    # Sweep sequence length with the hidden size fixed at 8192.
    common_configs_varying_seq_len = {
        "kernel_name": "rope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],  # 1024 to 16384
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )
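
# Example invocation (a sketch: the exact flags come from
# parse_benchmark_script_args in the local utils module, so they may differ;
# --overwrite is read above and presumably replaces previously saved results):
#   python benchmark_rope.py --overwrite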