benchmark/scripts/benchmark_qwen2vl_mrope.py
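
# Benchmarks the Liger fused Qwen2-VL multimodal RoPE (M-RoPE) kernel against the
# Hugging Face reference implementation, reporting latency and peak memory while
# sweeping either hidden size or sequence length.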
import torch
import triton
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLRotaryEmbedding,
    apply_multimodal_rotary_pos_emb,
)
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = (
        extra_benchmark_config["hidden_size"]
        if "hidden_size" in extra_benchmark_config
        else input.x
    )
    seq_len = (
        extra_benchmark_config["seq_len"]
        if "seq_len" in extra_benchmark_config
        else input.x
    )

    head_dim = hidden_size // num_q_heads
    rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
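    # q and k are created as (batch, seq_len, num_heads, head_dim) and transposed to
    # (batch, num_heads, seq_len, head_dim), the layout both implementations operate on.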
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq = torch.randn_like(q, device=device, dtype=dtype)
    dk = torch.randn_like(k, device=device)
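    # M-RoPE carries three position-id streams (temporal, height, width), hence the
    # leading dimension of 3 in pos_ids: shape (3, batch, seq_len).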
    pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
    cos, sin = rotary_emb(k, pos_ids)

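    # mrope_section splits the rotary half of head_dim into temporal/height/width
    # chunks that sum to head_dim // 2 (e.g. [16, 24, 24] for head_dim = 128,
    # matching the Qwen2-VL default).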
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]

    def fwd():
        if provider == "liger":
            return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        elif provider == "huggingface":
            return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        else:
            raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")

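    # do_bench returns one latency per requested quantile; the names below assume
    # QUANTILES orders them as (0.5, 0.2, 0.8), i.e. median, p20, p80.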
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad(
                (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
            ),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
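    # Same setup as the speed benchmark above; only the memory of a full
    # forward + backward pass is measured via _test_memory.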
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = (
        extra_benchmark_config["hidden_size"]
        if "hidden_size" in extra_benchmark_config
        else input.x
    )
    seq_len = (
        extra_benchmark_config["seq_len"]
        if "seq_len" in extra_benchmark_config
        else input.x
    )

    head_dim = hidden_size // num_q_heads
    rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq = torch.randn_like(q, device=device, dtype=dtype)
    dk = torch.randn_like(k, device=device)
    pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
    cos, sin = rotary_emb(k, pos_ids)

    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]

    def full():
        if provider == "liger":
            q_out, k_out = liger_multimodal_rotary_pos_emb(
                q, k, cos, sin, mrope_section
            )
        else:
            q_out, k_out = apply_multimodal_rotary_pos_emb(
                q, k, cos, sin, mrope_section
            )
        torch.autograd.grad(
            (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
        )

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

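    # Sweep hidden size over 512, 2048, and 8192 with sequence length fixed at 2048.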
    common_configs_varying_hidden_size = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

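    # Sweep sequence length over 1024-16384 (powers of two) with hidden size fixed at 8192.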
    common_configs_varying_seq_len = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )