# benchmark_fused_linear_cross_entropy.py
"""Benchmarks for the fused linear + cross-entropy loss kernel.

Compares ``LigerFusedLinearCrossEntropyLoss`` ("liger") against a plain
``torch.nn.Linear`` + ``torch.nn.CrossEntropyLoss`` baseline ("huggingface"),
measuring both speed (ms) and peak memory (MB) across batch*seq sizes.
"""

import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.utils import infer_device

device = infer_device()


class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, x, y):
        logits = self.lin(x)
        return self.ce_loss(logits, y)


class LigerLMHeadCE(torch.nn.Module):
    """LM head using the Liger fused linear cross-entropy kernel.

    Same contract as :class:`TorchLMHeadCE`, but the projection and the loss
    are fused: the linear weight is passed directly to the fused loss, so the
    full (BT, V) logits tensor is never materialized.
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, x, y):
        return self.ce_loss(self.lin.weight, x, y)


def _make_bench_setup(BT: int, H: int, V: int, dtype: torch.dtype):
    """Build both loss modules plus a shared (input, target) pair for one run."""
    torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
    # Random class indices in [0, V); CrossEntropyLoss expects a long tensor.
    target = torch.randint(V, (BT,), dtype=torch.long, device=device)
    return torch_lm_head_ce, liger_lm_head_ce, _input, target


def _make_fwd(provider: str, torch_lm_head_ce, liger_lm_head_ce, _input, target):
    """Return a zero-arg closure running one forward pass for `provider`.

    Raises ValueError for an unknown provider instead of silently returning
    None (which would otherwise crash opaquely in ``y.backward()``).
    """

    def fwd():
        if provider == "liger":
            return liger_lm_head_ce(_input, target)
        elif provider == "huggingface":
            return torch_lm_head_ce(_input, target)
        raise ValueError(f"Unknown kernel provider: {provider}")

    return fwd


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################


def bench_memory_fused_linear_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Measure peak memory of a full forward+backward pass.

    :param input: benchmark config; ``x`` is BT (batch*seq), extras carry
        H, V, dtype, and ``kernel_provider`` selects the implementation.
    :returns: 20/50/80th-percentile peak memory over 10 iterations.
    """
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    torch_lm_head_ce, liger_lm_head_ce, _input, target = _make_bench_setup(
        BT, H, V, dtype
    )
    fwd = _make_fwd(provider, torch_lm_head_ce, liger_lm_head_ce, _input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


# #############################################################################
# # Test the speed of the fused linear cross entropy loss
# #############################################################################


def bench_speed_fused_linear_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Measure latency of forward, backward, or forward+backward.

    :param input: benchmark config; ``kernel_operation_mode`` must be one of
        "forward", "backward", or "full".
    :returns: 20/50/80th-percentile latency in ms.
    :raises ValueError: if the operation mode is not recognized (previously
        this fell through to an UnboundLocalError).
    """
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    torch_lm_head_ce, liger_lm_head_ce, _input, target = _make_bench_setup(
        BT, H, V, dtype
    )
    fwd = _make_fwd(provider, torch_lm_head_ce, liger_lm_head_ce, _input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            # retain_graph so the same forward result can be re-differentiated
            # on every timed iteration.
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "fused_linear_cross_entropy",
        "x_name": "BT",
        "x_label": "B x T",
        "x_values": [2**i for i in range(12, 16)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_cross_entropy,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_cross_entropy,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )