benchmark_layer_norm.py
"""Benchmark LigerLayerNorm against torch.nn.LayerNorm (speed and peak memory)."""

import torch
import triton

from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()


def _setup(input: SingleBenchmarkRunInput):
    """Build the modules and tensors shared by the speed and memory benchmarks.

    Reads N (hidden size) from ``input.x`` and M/eps/dtype from
    ``input.extra_benchmark_config``.

    Returns:
        (y_fwd, x, dy): ``y_fwd`` runs the forward pass for the selected
        provider, ``x`` is the (M, N) input with ``requires_grad`` set, and
        ``dy`` is an upstream-gradient tensor matching ``x``.

    Raises:
        ValueError: from ``y_fwd`` if the provider is not "liger" or
        "huggingface".
    """
    N = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

    x = torch.randn((M, N), dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)
        # Fail loudly: previously an unknown provider silently returned None,
        # so do_bench would time a no-op (or backward would crash with a
        # confusing AttributeError on None).
        raise ValueError(f"Unknown kernel provider: {provider!r}")

    return y_fwd, x, dy


def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark layer-norm latency (ms) for the requested provider and mode.

    Supported modes: "forward", "backward", "full" (forward + backward).

    Raises:
        ValueError: if ``input.kernel_operation_mode`` is not one of the
        supported modes (previously this fell through to a NameError).
    """
    mode = input.kernel_operation_mode
    y_fwd, x, dy = _setup(input)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )
    elif mode == "backward":
        # Run forward once outside the timed region; retain_graph lets the
        # same graph be re-used for every timed backward call.
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, quantiles=QUANTILES, grad_to_none=[x], rep=500
        )
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Benchmark peak memory (MB) of a full forward+backward pass."""
    y_fwd, x, dy = _setup(input)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "layer_norm",
        "x_name": "N",
        "x_label": "hidden size",
        # Hidden sizes 1024 .. 16384 (powers of two).
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_layer_norm,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_layer_norm,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )