benchmark_cross_entropy.py
import torch
import triton
from torch.nn import CrossEntropyLoss
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.utils import infer_device

device = infer_device()


def _setup_cross_entropy(input: SingleBenchmarkRunInput):
    """Build the logits, targets and loss closure shared by both benchmarks.

    Reads the vocab size from ``input.x`` and ``B``/``T`` from
    ``input.extra_benchmark_config``.

    Returns:
        ``(_input, fwd)``. ``_input`` is a ``(B * T, V)`` leaf tensor with
        ``requires_grad=True`` on ``device``. ``fwd()`` returns the loss from
        ``LigerCrossEntropyLoss`` when the provider is ``"liger"``, and from
        ``torch.nn.CrossEntropyLoss`` for any other provider.
    """
    torch_ce = CrossEntropyLoss()
    liger_ce = LigerCrossEntropyLoss()

    V = input.x
    provider = input.kernel_provider
    B = input.extra_benchmark_config["B"]
    T = input.extra_benchmark_config["T"]

    _input = torch.randn(B * T, V, requires_grad=True, device=device)
    # One class index in [0, V) per row of _input.
    target = torch.randint(V, (B * T, 1), device=device).squeeze(1)

    # Every provider other than "liger" (e.g. "huggingface") is measured with
    # torch's reference implementation.
    loss_fn = liger_ce if provider == "liger" else torch_ce

    def fwd():
        return loss_fn(_input, target)

    return _input, fwd


def bench_memory_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Measure memory (as reported by ``_test_memory``) of one forward + backward pass."""
    _input, fwd = _setup_cross_entropy(input)

    def full():
        # Clear the gradient left by the previous repetition. Otherwise autograd
        # accumulates into the existing .grad, which keeps an extra
        # logits-sized buffer alive and inflates every measurement after the first.
        _input.grad = None
        y = fwd()
        y.backward()

    # NOTE(review): this unpacking assumes QUANTILES is ordered (median, low, high).
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


def bench_speed_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Time the cross-entropy loss in ms for ``input.kernel_operation_mode``.

    Supported modes are ``"forward"``, ``"backward"`` and ``"full"`` (forward + backward).

    Raises:
        ValueError: if the operation mode is not one of the supported modes.
    """
    mode = input.kernel_operation_mode
    _input, fwd = _setup_cross_entropy(input)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd, rep=100, quantiles=QUANTILES
        )
    elif mode == "backward":
        y = fwd()
        # retain_graph lets the same graph be replayed on every repetition.
        # grad_to_none resets _input.grad so accumulation is not timed.
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        # Reset the gradient before each repetition, as in "backward" mode, so
        # the timing excludes a gradient-accumulation add.
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, grad_to_none=[_input], rep=100, quantiles=QUANTILES
        )
    else:
        raise ValueError(f"Unsupported kernel_operation_mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "cross_entropy",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, 18)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [{"B": 8, "T": 2048}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_cross_entropy,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs
    )
    run_benchmarks(
        bench_test_fn=bench_memory_cross_entropy,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs
    )