# benchmark/scripts/benchmark_rope.py
  1  import torch
  2  import triton
  3  from transformers.models.llama.modeling_llama import (
  4      LlamaRotaryEmbedding,
  5      apply_rotary_pos_emb,
  6  )
  7  from utils import (
  8      QUANTILES,
  9      SingleBenchmarkRunInput,
 10      SingleBenchmarkRunOutput,
 11      _test_memory,
 12      parse_benchmark_script_args,
 13      run_benchmarks,
 14  )
 15  
 16  from liger_kernel.transformers.rope import liger_rotary_pos_emb
 17  from liger_kernel.utils import infer_device
 18  
# Target device for all benchmark tensors, resolved once at import time via
# liger_kernel's helper (presumably "cuda"/"xpu" depending on the installed
# backend — determined by infer_device, not visible here).
device = infer_device()
 20  
 21  
 22  def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 23      provider = input.kernel_provider
 24      mode = input.kernel_operation_mode
 25  
 26      extra_benchmark_config = input.extra_benchmark_config
 27      num_q_heads = extra_benchmark_config["num_q_heads"]
 28      num_kv_heads = extra_benchmark_config["num_kv_heads"]
 29      dtype = extra_benchmark_config["dtype"]
 30  
 31      # x can be either hidden_size or seq_len
 32      hidden_size = (
 33          extra_benchmark_config["hidden_size"]
 34          if "hidden_size" in extra_benchmark_config
 35          else input.x
 36      )
 37      seq_len = (
 38          extra_benchmark_config["seq_len"]
 39          if "seq_len" in extra_benchmark_config
 40          else input.x
 41      )
 42  
 43      head_dim = hidden_size // num_q_heads
 44      rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
 45      q = torch.randn(
 46          (1, seq_len, num_q_heads, head_dim),
 47          device=device,
 48          requires_grad=True,
 49          dtype=dtype,
 50      ).transpose(1, 2)
 51      k = torch.randn(
 52          (1, seq_len, num_kv_heads, head_dim),
 53          device=device,
 54          requires_grad=True,
 55          dtype=dtype,
 56      ).transpose(1, 2)
 57      dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like(
 58          k, device=device
 59      )
 60      pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
 61      cos, sin = rotary_emb(k, pos_ids)
 62  
 63      def fwd():
 64          if provider == "liger":
 65              return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
 66          elif provider == "huggingface":
 67              return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
 68          else:
 69              raise ValueError(f"Invalid provider: {provider} for RoPE embedding")
 70  
 71      if mode == "forward":
 72          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 73              fwd,
 74              grad_to_none=[q, k],
 75              rep=400,
 76              quantiles=QUANTILES,
 77          )
 78      elif mode == "backward":
 79          q_out, k_out = fwd()
 80          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 81              lambda: torch.autograd.grad(
 82                  (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
 83              ),
 84              grad_to_none=[q, k],
 85              rep=400,
 86              quantiles=QUANTILES,
 87          )
 88      elif mode == "full":
 89  
 90          def full():
 91              q_out, k_out = fwd()
 92              torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)
 93  
 94          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 95              full,
 96              grad_to_none=[q, k],
 97              rep=400,
 98              quantiles=QUANTILES,
 99          )
100      return SingleBenchmarkRunOutput(
101          y_20=ms_20,
102          y_50=ms_50,
103          y_80=ms_80,
104      )
105  
106  
def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark the peak memory of a full RoPE forward + backward pass.

    Args:
        input: Benchmark run description. ``input.x`` supplies whichever of
            ``hidden_size`` / ``seq_len`` is absent from
            ``input.extra_benchmark_config`` (the swept dimension).

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th percentile memory
        figures as reported by ``_test_memory``.

    Raises:
        ValueError: If the kernel provider is unknown.
    """
    provider = input.kernel_provider

    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len: the dimension not pinned in the
    # extra config is the one being swept over.
    hidden_size = extra_benchmark_config.get("hidden_size", input.x)
    seq_len = extra_benchmark_config.get("seq_len", input.x)

    head_dim = hidden_size // num_q_heads
    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
    # Create (1, seq, heads, head_dim) and transpose to (1, heads, seq,
    # head_dim), the layout both RoPE implementations operate on.
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    # Upstream gradients; randn_like inherits device/dtype from q and k.
    dq, dk = torch.randn_like(q), torch.randn_like(k)
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        elif provider == "huggingface":
            q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            # Match bench_speed_rope: fail loudly on an unknown provider
            # instead of silently benchmarking the huggingface path.
            raise ValueError(f"Invalid provider: {provider} for RoPE embedding")
        torch.autograd.grad(
            (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
        )

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
165  
166  
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Settings shared by both sweeps.
    base_config = {
        "kernel_name": "rope",
        "kernel_providers": ["liger", "huggingface"],
        "overwrite": args.overwrite,
    }

    # Two sweeps: vary hidden size at fixed seq_len, then vary sequence
    # length at fixed hidden_size.
    sweep_configs = [
        dict(
            base_config,
            x_name="H",
            x_label="hidden size",
            x_values=[32 * (2**i) for i in range(4, 10, 2)],
            extra_benchmark_configs=[
                {
                    "dtype": torch.bfloat16,
                    "seq_len": 2048,
                    "num_q_heads": 32,
                    "num_kv_heads": 8,
                }
            ],
        ),
        dict(
            base_config,
            x_name="T",
            x_label="sequence length",
            x_values=[2**i for i in range(10, 15)],
            extra_benchmark_configs=[
                {
                    "dtype": torch.bfloat16,
                    "hidden_size": 8192,
                    "num_q_heads": 32,
                    "num_kv_heads": 8,
                }
            ],
        ),
    ]

    for sweep in sweep_configs:
        # For each sweep: latency across all operation modes, then peak
        # memory for the full forward + backward pass.
        run_benchmarks(
            bench_test_fn=bench_speed_rope,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **sweep,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_rope,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **sweep,
        )