/ benchmark / scripts / benchmark_kl_div.py
benchmark_kl_div.py
  1  import torch
  2  import torch.nn as nn
  3  import triton
  4  from utils import (
  5      QUANTILES,
  6      SingleBenchmarkRunInput,
  7      SingleBenchmarkRunOutput,
  8      _test_memory,
  9      parse_benchmark_script_args,
 10      run_benchmarks,
 11  )
 12  
 13  from liger_kernel.transformers.kl_div import LigerKLDIVLoss
 14  from liger_kernel.utils import infer_device
 15  
 16  device = infer_device()
 17  
 18  S, E = 12, 18
 19  
 20  
 21  def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 22      reduction = "batchmean"
 23      V = input.x
 24      B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
 25      torch_kl_div = nn.KLDivLoss(reduction=reduction)
 26      liger_kl_div = LigerKLDIVLoss(reduction=reduction)
 27  
 28      _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
 29          dim=-1
 30      )
 31      target = torch.randn(B * T, V, device=device).softmax(dim=-1)
 32  
 33      def fwd():
 34          if input.kernel_provider == "liger":
 35              return liger_kl_div(_input, target)
 36          else:
 37              return torch_kl_div(_input, target)
 38  
 39      if input.kernel_operation_mode == "forward":
 40          ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
 41      elif input.kernel_operation_mode == "backward":
 42          y = fwd()
 43  
 44          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 45              lambda: y.backward(retain_graph=True),
 46              quantiles=QUANTILES,
 47              grad_to_none=[_input],
 48              rep=100,
 49          )
 50      elif input.kernel_operation_mode == "full":
 51  
 52          def full():
 53              y = fwd()
 54              y.backward(retain_graph=True)
 55  
 56          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 57              full, quantiles=QUANTILES, rep=100
 58          )
 59      return SingleBenchmarkRunOutput(
 60          y_20=ms_20,
 61          y_50=ms_50,
 62          y_80=ms_80,
 63      )
 64  
 65  
 66  def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
 67      reduction = "batchmean"
 68      torch_kl_div = nn.KLDivLoss(reduction=reduction)
 69      liger_kl_div = LigerKLDIVLoss(reduction=reduction)
 70  
 71      V = input.x
 72      B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
 73  
 74      _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(
 75          dim=-1
 76      )
 77      target = torch.randn(B * T, V, device=device).softmax(dim=-1)
 78  
 79      def fwd():
 80          if input.kernel_provider == "liger":
 81              return liger_kl_div(_input, target)
 82          else:
 83              return torch_kl_div(_input, target)
 84  
 85      def full():
 86          y = fwd()
 87          y.backward(retain_graph=True)
 88  
 89      mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
 90  
 91      return SingleBenchmarkRunOutput(
 92          y_20=mem_20,
 93          y_50=mem_50,
 94          y_80=mem_80,
 95      )
 96  
 97  
 98  if __name__ == "__main__":
 99      args = parse_benchmark_script_args()
100      common_args = {
101          "kernel_name": "kl_div",
102          "x_name": "V",
103          "x_label": "vocab size",
104          "x_values": [2**i for i in range(12, 18)],
105          "kernel_providers": ["liger", "torch"],
106          "extra_benchmark_configs": [{"B": 8, "T": 512}],
107          "overwrite": args.overwrite,
108      }
109  
110      run_benchmarks(
111          bench_test_fn=bench_memory_kldiv,
112          kernel_operation_modes=["full"],
113          metric_name="memory",
114          metric_unit="MB",
115          **common_args,
116      )
117  
118      run_benchmarks(
119          bench_test_fn=bench_speed_kldiv,
120          kernel_operation_modes=["forward", "full"],
121          metric_name="speed",
122          metric_unit="ms",
123          **common_args,
124      )