# benchmark_jsd.py
"""Benchmark the Liger JSD module against a plain PyTorch reference.

Running this file as a script executes a memory benchmark ("full" mode) and a
speed benchmark ("forward" and "full" modes) over a sweep of vocabulary sizes
for both the "liger" and "torch" providers.
"""
import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.utils import infer_device

device = infer_device()


class TorchJSD(torch.nn.Module):
    """Reference generalized Jensen-Shannon divergence built from ``KLDivLoss``.

    With ``M = lerp(Q, P, beta)`` the per-row loss is
    ``beta * KL(P || M) + (1 - beta) * KL(Q || M)``, where ``Q`` and ``P`` are
    given as log-probabilities.

    Args:
        beta: Interpolation weight between the two distributions.
        ignore_index: Label value whose rows are excluded from the loss.
        dtype: Dtype the final scalar loss is cast to.
    """

    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ):
        super().__init__()
        # log_target=True: both arguments of self.kl are log-probabilities.
        self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(
        self,
        log_q: torch.Tensor,  # input
        log_p: torch.Tensor,  # target
        label=None,
    ):
        """Return the mean JSD over rows as a scalar tensor of ``self.dtype``.

        Args:
            log_q: Log-probabilities of the input distribution; flattened to
                ``(-1, last_dim)``.
            log_p: Log-probabilities of the target distribution; flattened the
                same way.
            label: Optional tensor with one entry per flattened row. Rows whose
                label equals ``ignore_index`` contribute zero and are left out
                of the normalisation count.

        Returns:
            A 0-dim tensor. When every row is ignored the result is 0.
        """
        # Compute in float32 regardless of the incoming dtype.
        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
        log_p = log_p.view(-1, log_p.size(-1))
        log_q = log_q.view(-1, log_q.size(-1))

        # m = (1 - beta) * q + beta * p, evaluated in probability space.
        log_m = torch.log(torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta))
        loss = self.beta * self.kl(log_m, log_p).sum(dim=-1) + (
            1 - self.beta
        ) * self.kl(log_m, log_q).sum(dim=-1)

        if label is not None:
            # Flatten so the mask lines up with the per-row loss.
            keep = label.reshape(-1) != self.ignore_index
            loss = torch.where(keep, loss, 0.0)
            # Clamp instead of branching on .item(): if every row is ignored the
            # masked loss is already all zeros, so dividing by 1 yields a proper
            # zero tensor (the previous float 0.0 broke the .to() call below).
            n_non_ignore = keep.sum().clamp(min=1)
            loss = (loss / n_non_ignore).sum()
        else:
            loss = (loss / log_q.shape[0]).sum()
        return loss.to(self.dtype)


def _build_jsd_forward(input: SingleBenchmarkRunInput):
    """Create random inputs and a zero-argument forward closure for one run.

    Reads ``input.x`` as the vocabulary size ``V`` and ``B``/``T`` from
    ``input.extra_benchmark_config``; both tensors have shape ``(B * T, V)``.

    Returns:
        A ``(fwd, _input)`` pair: ``fwd()`` evaluates the selected JSD module on
        the generated tensors, and ``_input`` is the input log-probability
        tensor (a non-leaf result of ``log_softmax``).
    """
    V = input.x
    config = input.extra_benchmark_config
    B, T = config["B"], config["T"]

    # Any provider other than "liger" falls back to the PyTorch reference.
    jsd = LigerJSD() if input.kernel_provider == "liger" else TorchJSD()

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
        dim=-1
    )
    target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

    def fwd():
        return jsd(_input, target)

    return fwd, _input


def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time the JSD forward, backward, or forward+backward pass.

    Args:
        input: Benchmark description; ``kernel_operation_mode`` must be one of
            ``"forward"``, ``"backward"`` or ``"full"``.

    Returns:
        The timings reported by ``triton.testing.do_bench`` for ``QUANTILES``,
        unpacked in the order (50th, 20th, 80th) -- presumably matching the
        order of ``QUANTILES``; verify against ``utils``.

    Raises:
        ValueError: If ``kernel_operation_mode`` is not a supported mode.
    """
    fwd, _input = _build_jsd_forward(input)
    mode = input.kernel_operation_mode

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd, quantiles=QUANTILES, rep=100
        )
    elif mode == "backward":
        # Run forward once; only the backward pass is timed.
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full, quantiles=QUANTILES, rep=100
        )
    else:
        # Without this branch the timings below would be unbound.
        raise ValueError(f"Unsupported kernel_operation_mode: {mode!r}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of one JSD forward+backward pass via ``_test_memory``.

    Args:
        input: Benchmark description (provider, vocabulary size, ``B``/``T``).

    Returns:
        The values returned by ``_test_memory`` for ``QUANTILES``, unpacked in
        the order (50th, 20th, 80th) -- presumably matching the order of
        ``QUANTILES``; verify against ``utils``.
    """
    fwd, _ = _build_jsd_forward(input)

    def full():
        y = fwd()
        y.backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    # Settings shared by the memory and speed sweeps.
    common_args = {
        "kernel_name": "jsd",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, 18)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 4, "T": 2048}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_memory_jsd,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )

    run_benchmarks(
        bench_test_fn=bench_speed_jsd,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )