/ benchmark / scripts / benchmark_fused_linear_cross_entropy.py
benchmark_fused_linear_cross_entropy.py
  1  import torch
  2  import triton
  3  from utils import (
  4      QUANTILES,
  5      SingleBenchmarkRunInput,
  6      SingleBenchmarkRunOutput,
  7      _test_memory,
  8      parse_benchmark_script_args,
  9      run_benchmarks,
 10  )
 11  
 12  from liger_kernel.transformers.fused_linear_cross_entropy import (
 13      LigerFusedLinearCrossEntropyLoss,
 14  )
 15  from liger_kernel.utils import infer_device
 16  
# Device on which every benchmark module and input tensor in this script is
# allocated; selected by liger_kernel.utils.infer_device.
device = infer_device()
 18  
 19  
 20  class TorchLMHeadCE(torch.nn.Module):
 21      """Ground truth implementation of the linear fused with torch based cross entropy loss.
 22  
 23      :param H: hidden size
 24      :param V: vocab size
 25      :param ignore_index: index to ignore
 26      :param reduction: reduction method
 27      """
 28  
 29      def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
 30          super().__init__()
 31          self.lin = torch.nn.Linear(
 32              in_features=H, out_features=V, bias=False, dtype=dtype
 33          )
 34          self.ce_loss = torch.nn.CrossEntropyLoss(
 35              ignore_index=ignore_index, reduction="mean"
 36          )
 37  
 38      def forward(self, x, y):
 39          logits = self.lin(x)
 40          return self.ce_loss(logits, y)
 41  
 42  
 43  class LigerLMHeadCE(torch.nn.Module):
 44      def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
 45          super().__init__()
 46          self.lin = torch.nn.Linear(
 47              in_features=H, out_features=V, bias=False, dtype=dtype
 48          )
 49          self.ce_loss = LigerFusedLinearCrossEntropyLoss(
 50              ignore_index=ignore_index, reduction="mean"
 51          )
 52  
 53      def forward(self, x, y):
 54          return self.ce_loss(self.lin.weight, x, y)
 55  
 56  
 57  #############################################################################
# Benchmark the memory consumption of the fused linear cross entropy loss
 59  #############################################################################
 60  
 61  
 62  def bench_memory_fused_linear_cross_entropy(
 63      input: SingleBenchmarkRunInput,
 64  ) -> SingleBenchmarkRunOutput:
 65      BT = input.x
 66      H = input.extra_benchmark_config["H"]
 67      V = input.extra_benchmark_config["V"]
 68      dtype = input.extra_benchmark_config["dtype"]
 69      provider = input.kernel_provider
 70  
 71      torch_lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
 72      liger_lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
 73  
 74      _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
 75      target = torch.randint(V, (BT, 1), dtype=torch.long, device=device).squeeze(1)
 76  
 77      def fwd():
 78          if provider == "liger":
 79              return liger_lm_head_ce(_input, target)
 80          elif provider == "huggingface":
 81              return torch_lm_head_ce(_input, target)
 82  
 83      def full():
 84          y = fwd()
 85          y.backward()
 86  
 87      mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
 88      return SingleBenchmarkRunOutput(
 89          y_20=mem_20,
 90          y_50=mem_50,
 91          y_80=mem_80,
 92      )
 93  
 94  
#############################################################################
# Benchmark the speed of the fused linear cross entropy loss
#############################################################################
 98  
 99  
def bench_speed_fused_linear_cross_entropy(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Time the LM head + CE loss with ``triton.testing.do_bench``.

    :param input: benchmark run input. ``input.x`` is the number of tokens
        (B x T), ``input.extra_benchmark_config`` supplies ``H``, ``V`` and
        ``dtype``, ``input.kernel_provider`` selects ``"liger"`` or
        ``"huggingface"``, and ``input.kernel_operation_mode`` selects what is
        timed: ``"forward"`` (loss only), ``"backward"`` (backward through one
        retained graph) or ``"full"`` (forward + backward).
    :return: the 20th/50th/80th percentile timings in ms reported by ``do_bench``.
    :raises ValueError: if the kernel provider or operation mode is not recognised.
    """
    BT = input.x
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Build only the module being timed; each one owns a (V, H) weight.
    if provider == "liger":
        lm_head_ce = LigerLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    elif provider == "huggingface":
        lm_head_ce = TorchLMHeadCE(H=H, V=V, dtype=dtype).to(device)
    else:
        raise ValueError(f"Unknown kernel provider: {provider!r}")

    _input = torch.randn(BT, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (BT,), dtype=torch.long, device=device)

    # Autograd accumulates into an existing ``.grad``; resetting these before
    # each timed run keeps that accumulation out of the measurement.
    grad_tensors = [_input, lm_head_ce.lin.weight]

    def fwd():
        return lm_head_ce(_input, target)

    # NOTE(review): the unpacking order below assumes QUANTILES lists the 0.5
    # quantile first, then 0.2, then 0.8 — confirm against utils.QUANTILES.
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            # retain_graph lets the same graph be backpropagated repeatedly.
            lambda: y.backward(retain_graph=True),
            grad_to_none=grad_tensors,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=grad_tensors,
            rep=100,
            quantiles=QUANTILES,
        )
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode!r}")
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
153  
154  
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Sweep settings shared by the speed and memory benchmarks.
    # NOTE: the "mode" entry in the extra config is not read by the bench
    # functions in this file; they use ``kernel_operation_mode`` instead.
    common_configs = dict(
        kernel_name="fused_linear_cross_entropy",
        x_name="BT",
        x_label="B x T",
        x_values=[2**exponent for exponent in range(12, 16)],
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[
            {"H": 4096, "V": 128256, "mode": "forward", "dtype": torch.bfloat16}
        ],
        overwrite=args.overwrite,
    )

    # (bench function, operation modes, metric name, metric unit), run in order.
    benchmark_plan = (
        (bench_speed_fused_linear_cross_entropy, ["forward", "full"], "speed", "ms"),
        (bench_memory_fused_linear_cross_entropy, ["full"], "memory", "MB"),
    )
    for bench_fn, operation_modes, metric, unit in benchmark_plan:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=operation_modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )