benchmark/scripts/benchmark_embedding.py
benchmark_embedding.py
  1  import torch
  2  import triton
  3  from torch.nn import Embedding
  4  from utils import (
  5      QUANTILES,
  6      SingleBenchmarkRunInput,
  7      SingleBenchmarkRunOutput,
  8      _test_memory,
  9      parse_benchmark_script_args,
 10      run_benchmarks,
 11  )
 12  
 13  from liger_kernel.transformers.experimental.embedding import LigerEmbedding
 14  from liger_kernel.utils import infer_device
 15  
# Device on which every benchmark in this script places its modules and
# input ids; resolved once at import time via liger_kernel's infer_device().
device = infer_device()

# NOTE: torch.compile is used with the default Inductor settings; no further
# customization is needed for this benchmark.
 20  
 21  
 22  def bench_speed_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 23      V = input.x
 24      provider = input.kernel_provider
 25      mode = input.kernel_operation_mode
 26  
 27      B = input.extra_benchmark_config["B"]
 28      T = input.extra_benchmark_config["T"]
 29      D = input.extra_benchmark_config["D"]
 30      dtype = input.extra_benchmark_config["dtype"]
 31  
 32      torch_emb = Embedding(V, D).to(device).to(dtype)
 33      liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
 34      torch_compile_emb = torch.compile(torch_emb)
 35  
 36      input_ids = torch.randint(0, V, (B, T), device=device)
 37  
 38      def fwd():
 39          if provider == "liger":
 40              return liger_emb(input_ids)
 41          elif provider == "torch_compile":
 42              return torch_compile_emb(input_ids)
 43          else:
 44              return torch_emb(input_ids)
 45  
 46      def full():
 47          output = fwd()
 48          output.backward(torch.randn_like(output))
 49  
 50      if mode == "forward":
 51          ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
 52      elif mode == "full":
 53          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 54              full, quantiles=QUANTILES, rep=100
 55          )
 56      return SingleBenchmarkRunOutput(
 57          y_20=ms_20,
 58          y_50=ms_50,
 59          y_80=ms_80,
 60      )
 61  
 62  
 63  def bench_memory_embedding(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 64      V = input.x
 65      provider = input.kernel_provider
 66  
 67      B = input.extra_benchmark_config["B"]
 68      T = input.extra_benchmark_config["T"]
 69      D = input.extra_benchmark_config["D"]
 70      dtype = input.extra_benchmark_config["dtype"]
 71  
 72      torch_emb = Embedding(V, D).to(device).to(dtype)
 73      liger_emb = LigerEmbedding(V, D).to(device).to(dtype)
 74      torch_compile_emb = torch.compile(torch_emb)
 75  
 76      input_ids = torch.randint(0, V, (B, T), device=device)
 77  
 78      def fwd():
 79          if provider == "liger":
 80              return liger_emb(input_ids)
 81          elif provider == "torch_compile":
 82              return torch_compile_emb(input_ids)
 83          else:
 84              return torch_emb(input_ids)
 85  
 86      def full():
 87          output = fwd()
 88          output.backward(torch.randn_like(output))
 89  
 90      mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
 91      return SingleBenchmarkRunOutput(
 92          y_20=mem_20,
 93          y_50=mem_50,
 94          y_80=mem_80,
 95      )
 96  
 97  
 98  if __name__ == "__main__":
 99      args = parse_benchmark_script_args()
100  
101      common_configs = {
102          "kernel_name": "embedding",
103          "x_name": "V",
104          "x_label": "embedding dimension",
105          "x_values": [2**i for i in range(10, 18)],
106          "kernel_providers": ["liger", "huggingface", "torch_compile"],
107          "extra_benchmark_configs": [
108              # BERT
109              {"B": 32, "T": 512, "D": 768, "dtype": torch.float32},
110              # Llama
111              {"B": 8, "T": 2048, "D": 4096, "dtype": torch.float32},
112          ],
113          "overwrite": args.overwrite,
114      }
115  
116      run_benchmarks(
117          bench_test_fn=bench_speed_embedding,
118          kernel_operation_modes=["forward", "full"],
119          metric_name="speed",
120          metric_unit="ms",
121          **common_configs
122      )
123      run_benchmarks(
124          bench_test_fn=bench_memory_embedding,
125          kernel_operation_modes=["full"],
126          metric_name="memory",
127          metric_unit="MB",
128          **common_configs
129      )