# benchmark/scripts/benchmark_swiglu.py
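"""Benchmark LigerSwiGLUMLP against Hugging Face's LlamaMLP.

Sweeps sequence length and reports latency (ms, via triton.testing.do_bench)
and memory (MB) through the shared harness in `utils`.
"""
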
import torch
import triton
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from liger_kernel.utils import infer_device

# Resolve the benchmark device (e.g. "cuda" or "xpu") from the available hardware.
device = infer_device()


def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
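    """Time the SwiGLU MLP for one (provider, mode, sequence-length) point."""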
    seq_len = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Unpack the per-run benchmark configuration.
    extra_benchmark_config = input.extra_benchmark_config
    bsz = extra_benchmark_config["B"]
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]
    intermediate_size = extra_benchmark_config["intermediate_size"]
    hidden_act = extra_benchmark_config["hidden_act"]

    # Minimal Llama config shared by both MLP implementations.
    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if provider == "liger":
        layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
    elif provider == "huggingface":
        layer = LlamaMLP(config=llama_config).to(device).to(dtype)
    else:
        raise ValueError(f"Invalid provider: {provider} for SwiGLU")

    def fwd():
        return layer(x)

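    # Time only the requested stage; QUANTILES orders the results as (p50, p20, p80).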
    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )
    elif mode == "backward":
        # The MLP preserves the input shape, so randn_like(x) is a valid upstream grad.
        do = torch.randn_like(x)
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(do, retain_graph=True),
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )
    else:

        def full():
            y = fwd()
            y.backward(torch.randn_like(y), retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            quantiles=QUANTILES,
            rep=10,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
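    """Measure SwiGLU MLP memory for one (provider, mode, sequence-length) point."""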
    seq_len = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra_benchmark_config = input.extra_benchmark_config
    bsz = extra_benchmark_config["B"]
    hidden_size = extra_benchmark_config["hidden_size"]
    dtype = extra_benchmark_config["dtype"]
    intermediate_size = extra_benchmark_config["intermediate_size"]
    hidden_act = extra_benchmark_config["hidden_act"]

    llama_config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act=hidden_act,
    )

    x_shape = (bsz, seq_len, hidden_size)

    # initialize input
    x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)

    if provider == "liger":
        layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype)
    elif provider == "huggingface":
        layer = LlamaMLP(config=llama_config).to(device).to(dtype)
    else:
        raise ValueError(f"Invalid provider: {provider} for SwiGLU")

    def fwd():
        return layer(x)

    def full():
        y = fwd()
        y.backward(torch.randn_like(y), retain_graph=True)

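    # Measure only the requested stage; _test_memory returns stats at the same quantiles.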
    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
    elif mode == "backward":
        do = torch.randn_like(x)
        y = fwd()
        mem_50, mem_20, mem_80 = _test_memory(
            lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES
        )
    else:
        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


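# Run from benchmark/scripts/ so the local `utils` module resolves, e.g.:
#   python benchmark_swiglu.py
# (an --overwrite flag is assumed from the use of args.overwrite below; see
# parse_benchmark_script_args for the exact CLI surface)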
if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "swiglu",
        "x_name": "T",
        "x_label": "sequence length",
        # Sequence lengths 2^10 .. 2^13 (1024 to 8192).
        "x_values": [2**i for i in range(10, 14)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "B": 4,
                "hidden_size": 4096,
                "dtype": torch.bfloat16,
                "intermediate_size": 11008,
                "hidden_act": "silu",
            }
        ],
        "overwrite": args.overwrite,
    }

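    # Two sweeps: forward-pass latency, then full (forward + backward) memory.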
    run_benchmarks(
        bench_test_fn=bench_speed_swiglu,
        kernel_operation_modes=["forward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_swiglu,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )