benchmark/scripts/benchmark_jsd.py
benchmark_jsd.py
  1  import torch
  2  import triton
  3  from utils import (
  4      QUANTILES,
  5      SingleBenchmarkRunInput,
  6      SingleBenchmarkRunOutput,
  7      _test_memory,
  8      parse_benchmark_script_args,
  9      run_benchmarks,
 10  )
 11  
 12  from liger_kernel.transformers.jsd import LigerJSD
 13  from liger_kernel.utils import infer_device
 14  
# Device on which every benchmark tensor below is allocated; chosen once at
# import time by liger_kernel.utils.infer_device().
device = infer_device()
 16  
 17  
 18  class TorchJSD(torch.nn.Module):
 19      def __init__(
 20          self,
 21          beta: float = 0.5,
 22          ignore_index: int = -100,
 23          dtype: torch.dtype = torch.float,
 24      ):
 25          super(TorchJSD, self).__init__()
 26          self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
 27          self.beta = beta
 28          self.ignore_index = ignore_index
 29          self.dtype = dtype
 30  
 31      def forward(
 32          self,
 33          log_q: torch.Tensor,  # input
 34          log_p: torch.Tensor,  # target
 35          label=None,
 36      ):
 37          log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
 38          log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
 39          m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
 40          loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
 41              1 - self.beta
 42          ) * self.kl(torch.log(m), log_q).sum(dim=-1)
 43  
 44          if label is not None:
 45              loss = torch.where(label != self.ignore_index, loss, 0.0)
 46              n_non_ignore = (label != self.ignore_index).sum().item()
 47              if n_non_ignore == 0:
 48                  loss = 0.0
 49              else:
 50                  loss = (loss / n_non_ignore).sum()
 51          else:
 52              loss = (loss / log_q.shape[0]).sum()
 53          return loss.to(self.dtype)
 54  
 55  
 56  def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 57      V = input.x
 58      B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
 59      torch_jsd = TorchJSD()
 60      liger_jsd = LigerJSD()
 61  
 62      _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
 63          dim=-1
 64      )
 65      target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)
 66  
 67      def fwd():
 68          if input.kernel_provider == "liger":
 69              return liger_jsd(_input, target)
 70          else:
 71              return torch_jsd(_input, target)
 72  
 73      if input.kernel_operation_mode == "forward":
 74          ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
 75      elif input.kernel_operation_mode == "backward":
 76          y = fwd()
 77  
 78          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 79              lambda: y.backward(retain_graph=True),
 80              quantiles=QUANTILES,
 81              grad_to_none=[_input],
 82              rep=100,
 83          )
 84      elif input.kernel_operation_mode == "full":
 85  
 86          def full():
 87              y = fwd()
 88              y.backward(retain_graph=True)
 89  
 90          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 91              full, quantiles=QUANTILES, rep=100
 92          )
 93      return SingleBenchmarkRunOutput(
 94          y_20=ms_20,
 95          y_50=ms_50,
 96          y_80=ms_80,
 97      )
 98  
 99  
def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory for one forward + backward JSD step at a sweep point.

    ``input.x`` is the vocab size. ``input.extra_benchmark_config`` supplies
    ``B`` and ``T``, giving ``B * T`` rows. ``input.kernel_provider == "liger"``
    selects LigerJSD; any other value selects TorchJSD.
    ``input.kernel_operation_mode`` is not read: every call measures the
    full forward + backward step.

    Returns:
        SingleBenchmarkRunOutput with the 20th/50th/80th-percentile values
        reported by ``_test_memory``.
    """
    reference_jsd = TorchJSD()
    kernel_jsd = LigerJSD()

    vocab = input.x
    cfg = input.extra_benchmark_config
    rows = cfg["B"] * cfg["T"]

    # Input log-probs come from a grad-tracking leaf; the target needs no grad.
    log_q = torch.randn(rows, vocab, requires_grad=True, device=device).log_softmax(
        dim=-1
    )
    log_p = torch.randn(rows, vocab, device=device).log_softmax(dim=-1)

    def step():
        impl = kernel_jsd if input.kernel_provider == "liger" else reference_jsd
        # retain_graph: the shared log_softmax node is reused on every call.
        impl(log_q, log_p).backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(step, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)
129  
130  
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Sweep settings shared by the memory and speed runs.
    common_args = dict(
        kernel_name="jsd",
        x_name="V",
        x_label="vocab size",
        x_values=[1 << exp for exp in range(12, 18)],
        kernel_providers=["liger", "torch"],
        extra_benchmark_configs=[{"B": 4, "T": 2048}],
        overwrite=args.overwrite,
    )

    # (benchmark fn, operation modes, metric name, metric unit); memory runs first.
    sweeps = (
        (bench_memory_jsd, ["full"], "memory", "MB"),
        (bench_speed_jsd, ["forward", "full"], "speed", "ms"),
    )
    for bench_fn, modes, metric, unit in sweeps:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=modes,
            metric_name=metric,
            metric_unit=unit,
            **common_args,
        )