# benchmark_qwen2vl_mrope.py
import torch
import triton
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLRotaryEmbedding,
    apply_multimodal_rotary_pos_emb,
)
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()


def _setup_mrope_inputs(input: SingleBenchmarkRunInput):
    """Build the tensors shared by the speed and memory M-RoPE benchmarks.

    ``input.x`` stands in for whichever of ``hidden_size`` / ``seq_len`` is
    not pinned in ``extra_benchmark_config`` (the benchmark sweeps one axis
    while fixing the other).

    Returns:
        Tuple ``(q, k, dq, dk, cos, sin, mrope_section)`` where ``q``/``k``
        are (1, heads, seq_len, head_dim) inputs with ``requires_grad=True``,
        ``dq``/``dk`` are matching upstream gradients, ``cos``/``sin`` come
        from ``Qwen2VLRotaryEmbedding``, and ``mrope_section`` is the
        [temporal, height, width] split of the rotary dimensions.
    """
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len; the other one is pinned.
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)

    head_dim = hidden_size // num_q_heads
    rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    # randn_like inherits device and dtype from its prototype tensor, so
    # dq/dk always match q/k exactly (the original passed them redundantly,
    # and inconsistently omitted dtype for dk).
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    # Three position-id streams (temporal / height / width) stacked on dim 0,
    # shape (3, batch=1, seq_len) as expected by Qwen2VLRotaryEmbedding.
    pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
    cos, sin = rotary_emb(k, pos_ids)

    # The three sections sum to head_dim // 2 (rotary half-dim):
    # [t, h, w] with h == w and t taking the remainder.
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]
    return q, k, dq, dk, cos, sin, mrope_section


def bench_speed_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Benchmark latency (ms) of multimodal RoPE for the requested provider.

    Supports ``forward``, ``backward`` and ``full`` (forward + backward)
    operation modes; quantile timings come from ``triton.testing.do_bench``.

    Raises:
        ValueError: on an unknown kernel provider or operation mode.
    """
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    q, k, dq, dk, cos, sin, mrope_section = _setup_mrope_inputs(input)

    def fwd():
        if provider == "liger":
            return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        elif provider == "huggingface":
            return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        else:
            raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # Build the graph once; retain it so each timed iteration can re-run
        # backward without re-doing the forward pass.
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad(
                (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
            ),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    else:
        # Previously an unknown mode fell through and raised a confusing
        # NameError on the unbound ms_* variables.
        raise ValueError(f"Invalid mode: {mode} for M-RoPE benchmark")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory (MB) of a full forward+backward M-RoPE pass.

    Raises:
        ValueError: on an unknown kernel provider.
    """
    provider = input.kernel_provider
    q, k, dq, dk, cos, sin, mrope_section = _setup_mrope_inputs(input)

    def full():
        if provider == "liger":
            q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        elif provider == "huggingface":
            q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        else:
            # Consistent with the speed benchmark; the original silently
            # treated any unknown provider as "huggingface".
            raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")
        torch.autograd.grad(
            (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
        )

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Sweep hidden size with sequence length fixed at 2048.
    common_configs_varying_hidden_size = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    # Sweep sequence length with hidden size fixed at 8192.
    common_configs_varying_seq_len = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )