benchmark_kl_div.py
"""Benchmark LigerKLDIVLoss against torch.nn.KLDivLoss for speed and peak memory."""

import torch
import torch.nn as nn
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device

device = infer_device()

# Sweep vocab sizes V = 2**S .. 2**(E - 1), i.e. 4096 .. 131072.
S, E = 12, 18


def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    reduction = "batchmean"
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)

    # KLDivLoss convention: log-probabilities as input, probabilities as target.
    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
        dim=-1
    )
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        # The provider selects which kernel implementation is timed.
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        y = fwd()

        # Time only the backward pass; retain_graph lets it be replayed.
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, quantiles=QUANTILES, rep=100
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    reduction = "batchmean"
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)

    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
        dim=-1
    )
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    def full():
        y = fwd()
        y.backward(retain_graph=True)

    # Measure peak memory over a full forward + backward pass.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    common_args = {
        "kernel_name": "kl_div",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(S, E)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 8, "T": 512}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_memory_kldiv,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )

    run_benchmarks(
        bench_test_fn=bench_speed_kldiv,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )
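
# Usage sketch (assumptions: parse_benchmark_script_args in utils.py defines an
# --overwrite flag, as implied by the args.overwrite reference above; where and
# how results are stored is handled by run_benchmarks in utils.py):
#
#   python benchmark_kl_div.py              # run memory and speed benchmarks
#   python benchmark_kl_div.py --overwrite  # replace any previously saved results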