# benchmark/scripts/benchmark_layer_norm.py
  1  import torch
  2  import triton
  3  from utils import (
  4      QUANTILES,
  5      SingleBenchmarkRunInput,
  6      SingleBenchmarkRunOutput,
  7      _test_memory,
  8      parse_benchmark_script_args,
  9      run_benchmarks,
 10  )
 11  
 12  from liger_kernel.transformers.layer_norm import LigerLayerNorm
 13  from liger_kernel.utils import infer_device
 14  
# Resolve the accelerator device once at import time (e.g. "cuda" / "xpu" —
# exact values depend on liger_kernel.utils.infer_device); all benchmark
# tensors and modules below are placed on it.
device = infer_device()
 16  
 17  
 18  def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 19      N = input.x
 20      provider = input.kernel_provider
 21      mode = input.kernel_operation_mode
 22      extra_benchmark_config = input.extra_benchmark_config
 23      M = extra_benchmark_config["M"]
 24      eps = extra_benchmark_config["eps"]
 25      dtype = extra_benchmark_config["dtype"]
 26  
 27      x_shape = (M, N)
 28      triton_ln = LigerLayerNorm(hidden_size=N).to(device)
 29      torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
 30  
 31      x = torch.randn(x_shape, dtype=dtype, device=device)
 32      dy = torch.randn_like(x)
 33      x.requires_grad_(True)
 34  
 35      def y_fwd():
 36          if provider == "liger":
 37              return triton_ln(x)
 38          if provider == "huggingface":
 39              return torch_ln(x)
 40  
 41      if mode == "forward":
 42          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 43              y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
 44          )
 45      elif mode == "backward":
 46          y = y_fwd()
 47          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 48              lambda: y.backward(dy, retain_graph=True),
 49              quantiles=QUANTILES,
 50              grad_to_none=[x],
 51              rep=500,
 52          )
 53      elif mode == "full":
 54  
 55          def full():
 56              y = y_fwd()
 57              y.backward(dy, retain_graph=True)
 58  
 59          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 60              full, quantiles=QUANTILES, grad_to_none=[x], rep=500
 61          )
 62  
 63      return SingleBenchmarkRunOutput(
 64          y_20=ms_20,
 65          y_50=ms_50,
 66          y_80=ms_80,
 67      )
 68  
 69  
 70  def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 71      N = input.x
 72      provider = input.kernel_provider
 73      dtype = input.extra_benchmark_config["dtype"]
 74      M = input.extra_benchmark_config["M"]
 75      eps = input.extra_benchmark_config["eps"]
 76  
 77      x_shape = (M, N)
 78  
 79      triton_ln = LigerLayerNorm(hidden_size=N).to(device)
 80      torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)
 81  
 82      x = torch.randn(x_shape, dtype=dtype, device=device)
 83      dy = torch.randn_like(x)
 84      x.requires_grad_(True)
 85  
 86      def y_fwd():
 87          if provider == "liger":
 88              return triton_ln(x)
 89          if provider == "huggingface":
 90              return torch_ln(x)
 91  
 92      def full():
 93          y = y_fwd()
 94          y.backward(dy, retain_graph=True)
 95  
 96      mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
 97      return SingleBenchmarkRunOutput(
 98          y_20=mem_20,
 99          y_50=mem_50,
100          y_80=mem_80,
101      )
102  
103  
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Shared configuration for both the speed and the memory benchmark:
    # sweep hidden size N over 1K..16K at fixed M=4096, fp32, eps=1e-6.
    common_configs = dict(
        kernel_name="layer_norm",
        x_name="N",
        x_label="hidden size",
        x_values=[2**i for i in range(10, 15)],
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
        overwrite=args.overwrite,
    )

    # Latency benchmark: forward-only and full (forward + backward) passes.
    run_benchmarks(
        bench_test_fn=bench_speed_layer_norm,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    # Peak-memory benchmark: full pass only.
    run_benchmarks(
        bench_test_fn=bench_memory_layer_norm,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )