/ benchmark / scripts / benchmark_cross_entropy.py
benchmark_cross_entropy.py
  1  import torch
  2  import triton
  3  from torch.nn import CrossEntropyLoss
  4  from utils import (
  5      QUANTILES,
  6      SingleBenchmarkRunInput,
  7      SingleBenchmarkRunOutput,
  8      _test_memory,
  9      parse_benchmark_script_args,
 10      run_benchmarks,
 11  )
 12  
 13  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 14  from liger_kernel.utils import infer_device
 15  
# Device returned by liger_kernel.utils.infer_device(); every tensor the
# benchmark functions in this script allocate is created on it via `device=`.
device = infer_device()
 17  
 18  
 19  def bench_memory_cross_entropy(
 20      input: SingleBenchmarkRunInput,
 21  ) -> SingleBenchmarkRunOutput:
 22      torch_ce = CrossEntropyLoss()
 23      liger_ce = LigerCrossEntropyLoss()
 24  
 25      V = input.x
 26      provider = input.kernel_provider
 27      B = input.extra_benchmark_config["B"]
 28      T = input.extra_benchmark_config["T"]
 29  
 30      _input = torch.randn(B * T, V, requires_grad=True, device=device)
 31      target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
 32  
 33      def fwd():
 34          if provider == "liger":
 35              return liger_ce(_input, target)
 36          else:
 37              return torch_ce(_input, target)
 38  
 39      def full():
 40          y = fwd()
 41          y.backward()
 42  
 43      mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
 44      return SingleBenchmarkRunOutput(
 45          y_20=mem_20,
 46          y_50=mem_50,
 47          y_80=mem_80,
 48      )
 49  
 50  
 51  def bench_speed_cross_entropy(
 52      input: SingleBenchmarkRunInput,
 53  ) -> SingleBenchmarkRunOutput:
 54      torch_ce = CrossEntropyLoss()
 55      liger_ce = LigerCrossEntropyLoss()
 56  
 57      V = input.x
 58      provider = input.kernel_provider
 59      mode = input.kernel_operation_mode
 60      B = input.extra_benchmark_config["B"]
 61      T = input.extra_benchmark_config["T"]
 62  
 63      _input = torch.randn(B * T, V, requires_grad=True, device=device)
 64      target = torch.randint(V, (B * T, 1), device=device).squeeze(1)
 65  
 66      def fwd():
 67          if provider == "liger":
 68              return liger_ce(_input, target)
 69          else:
 70              return torch_ce(_input, target)
 71  
 72      if mode == "forward":
 73          ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
 74      elif mode == "backward":
 75          y = fwd()
 76  
 77          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 78              lambda: y.backward(retain_graph=True),
 79              grad_to_none=[_input],
 80              rep=100,
 81              quantiles=QUANTILES,
 82          )
 83      elif mode == "full":
 84  
 85          def full():
 86              y = fwd()
 87              y.backward()
 88  
 89          ms_50, ms_20, ms_80 = triton.testing.do_bench(
 90              full, rep=100, quantiles=QUANTILES
 91          )
 92  
 93      return SingleBenchmarkRunOutput(
 94          y_20=ms_20,
 95          y_50=ms_50,
 96          y_80=ms_80,
 97      )
 98  
 99  
if __name__ == "__main__":
    cli_args = parse_benchmark_script_args()

    # Settings shared by both sweeps: vary the vocab size V over
    # 2**12 .. 2**17 with a fixed B = 8, T = 2048. The "huggingface" provider
    # is any non-"liger" provider in the bench functions, i.e. it runs
    # torch.nn.CrossEntropyLoss.
    shared_config = dict(
        kernel_name="cross_entropy",
        x_name="V",
        x_label="vocab size",
        x_values=[2**exp for exp in range(12, 18)],
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[{"B": 8, "T": 2048}],
        overwrite=cli_args.overwrite,
    )

    # (bench function, operation modes, metric name, metric unit);
    # the speed sweep runs first, then the memory sweep.
    sweeps = (
        (bench_speed_cross_entropy, ["forward", "full"], "speed", "ms"),
        (bench_memory_cross_entropy, ["full"], "memory", "MB"),
    )
    for bench_fn, op_modes, metric, unit in sweeps:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=op_modes,
            metric_name=metric,
            metric_unit=unit,
            **shared_config,
        )