  1  <a name="readme-top"></a>
  2  
  3  # Liger Kernel: Efficient Triton Kernels for LLM Training
  4  
  5  
  6  <table style="width: 100%; text-align: center; border-collapse: collapse;">
  7      <tr>
  8          <th style="padding: 10px;" colspan="2">Stable</th>
  9          <th style="padding: 10px;" colspan="2">Nightly</th>
 10          <th style="padding: 10px;">Discord</th>
 11          <th style="padding: 10px;">Gurubase (experimental)</th>
 12      </tr>
 13      <tr>
 14          <td style="padding: 10px;">
 15              <a href="https://pepy.tech/project/liger-kernel">
 16                  <img src="https://static.pepy.tech/badge/liger-kernel" alt="Downloads (Stable)">
 17              </a>
 18          </td>
 19          <td style="padding: 10px;">
 20              <a href="https://pypi.org/project/liger-kernel">
 21                  <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel?color=green">
 22              </a>
 23          </td>
 24          <td style="padding: 10px;">
 25              <a href="https://pepy.tech/project/liger-kernel-nightly">
 26                  <img src="https://static.pepy.tech/badge/liger-kernel-nightly" alt="Downloads (Nightly)">
 27              </a>
 28          </td>
 29          <td style="padding: 10px;">
 30              <a href="https://pypi.org/project/liger-kernel-nightly">
 31                  <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel-nightly?color=green">
 32              </a>
 33          </td>
 34          <td style="padding: 10px;">
 35              <a href="https://discord.gg/gpumode">
 36                  <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
 37              </a>
 38          </td>
 39          <td style="padding: 10px;">
 40              <a href="https://gurubase.io/g/liger-kernel">
 41                  <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
 42              </a>
 43          </td>
 44      </tr>
 45  </table>
 46  
 47  
 48  
 49  <img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">
 50  
 51  [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
 52  
 53  <details>
 54    <summary>Latest News 🔥</summary>
 55  
 56    - [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
  - [2024/10/21] We have released the tech report of Liger Kernel on arXiv: https://arxiv.org/pdf/2410.10989
 58    - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
 59    - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
 60    - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
 61  
 62  </details>
 63  
 64  
**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, and `FusedLinearCrossEntropy`, with more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
 66  
 67  ## Supercharge Your Model with Liger Kernel
 68  
 69  ![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
 70  
 71  With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
 72  
 73  
 74  | Speed Up                 | Memory Reduction        |
 75  |--------------------------|-------------------------|
 76  | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
 77  
 78  > **Note:**
 79  > - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
 80  > - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
 81  
 82  ## Examples
 83  
 84  | **Use Case**                                    | **Description**                                                                                   |
 85  |------------------------------------------------|---------------------------------------------------------------------------------------------------|
| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface)      | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP |
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning)         | Increase throughput by 15% and reduce memory usage by 40% with LLaMA 3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
 88  | [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa)        | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
 89  | [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh)      | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
 90  
 91  ## Key Features
 92  
 93  - **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
 94  - **Time and memory efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, **SwiGLU**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
 95  - **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
 96  - **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
 97  - **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- **Trainer framework integration:** [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)
 99  
100  ## Installation
101  
102  ### Dependencies
103  
104  #### CUDA
105  
106  - `torch >= 2.1.2`
107  - `triton >= 2.3.0`
108  
109  #### ROCm
110  
- `torch >= 2.5.0`: install following the instructions on the official PyTorch website.
- `triton >= 3.0.0`: install from PyPI (e.g., `pip install triton==3.0.0`).
113  
114  ### Optional Dependencies
115  
- `transformers >= 4.x`: Required if you plan to use the transformers model patching APIs. The specific model you are working with dictates the minimum version of transformers.
117  
118  > **Note:**
119  > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
120  
121  To install the stable version:
122  
123  ```bash
124  $ pip install liger-kernel
125  ```
126  
127  To install the nightly version:
128  
129  ```bash
130  $ pip install liger-kernel-nightly
131  ```
132  
133  To install from source:
134  
```bash
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel
pip install -e .
# or, if using transformers (quoted so that shells such as zsh do not expand the brackets)
pip install -e ".[transformers]"
```
142  
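To check an environment against the minimum versions listed under Dependencies, something along the following lines may help. This is a minimal sketch using only the standard library: the helper names `parse_version` and `check_dependencies` are illustrative and not part of Liger Kernel, the minimums shown are the CUDA ones from above, and the comparison only looks at the leading numeric components of each version string.

```python
import re
from importlib import metadata

# Minimum versions for CUDA, copied from the Dependencies section above.
CUDA_MINIMUMS = {"torch": "2.1.2", "triton": "2.3.0"}


def parse_version(version: str) -> tuple:
    """Return the leading numeric components, e.g. '2.5.0+rocm6.2' -> (2, 5, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = tuple(int(part) for part in match.group(0).split("."))
    # Pad short versions such as '2.3' so they compare equal to '2.3.0'.
    return parts + (0,) * max(0, 3 - len(parts))


def check_dependencies(minimums: dict) -> dict:
    """Map each package to 'ok', 'too old (<installed>)', or 'missing'."""
    report = {}
    for package, minimum in minimums.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            report[package] = "missing"
            continue
        ok = parse_version(installed) >= parse_version(minimum)
        report[package] = "ok" if ok else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for package, status in check_dependencies(CUDA_MINIMUMS).items():
        print(f"{package}: {status}")
```

For ROCm, pass a dictionary with the ROCm minimums instead of `CUDA_MINIMUMS`.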
143  
144  ## Getting Started
145  
146  There are a couple of ways to apply Liger kernels, depending on the level of customization required.
147  
148  ### 1. Use AutoLigerKernelForCausalLM
149  
Using `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code is automatically patched using the default settings.
151  
152  ```python
153  from liger_kernel.transformers import AutoLigerKernelForCausalLM
154  
155  # This AutoModel wrapper class automatically monkey-patches the
156  # model with the optimized Liger kernels if the model is supported.
157  model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
158  ```
159  
160  ### 2. Apply Model-Specific Patching APIs
161  
162  Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.
163  
```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
)

# 2. Instantiate the patched model
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```
183  
184  ### 3. Compose Your Own Model
185  
186  You can take individual [kernels](#kernels) to compose your models.
187  
188  ```python
189  from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
190  import torch.nn as nn
191  import torch
192  
193  model = nn.Linear(128, 256).cuda()
194  
195  # fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
196  loss_fn = LigerFusedLinearCrossEntropyLoss()
197  
198  input = torch.randn(4, 128, requires_grad=True, device="cuda")
199  target = torch.randint(256, (4, ), device="cuda")
200  
201  loss = loss_fn(model.weight, input, target)
202  loss.backward()
203  ```
204  
205  ## APIs
206  
207  ### AutoModel
208  
209  | **AutoModel Variant** | **API** |
210  |-----------|---------|
211  | AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |
212  
213  
214  ### Patching
215  
216  | **Model**   | **API**                                                      | **Supported Operations**                                                |
217  |-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
218  | LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama`   | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
219  | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama`   | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
220  | Mistral     | `liger_kernel.transformers.apply_liger_kernel_to_mistral`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
221  | Mixtral     | `liger_kernel.transformers.apply_liger_kernel_to_mixtral`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
222  | Gemma1      | `liger_kernel.transformers.apply_liger_kernel_to_gemma`    | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
223  | Gemma2      | `liger_kernel.transformers.apply_liger_kernel_to_gemma2`   | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
224  | Qwen2 & Qwen2.5      | `liger_kernel.transformers.apply_liger_kernel_to_qwen2`    | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
225  | Qwen2-VL       | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl`    | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
226  | Phi3 & Phi3.5       | `liger_kernel.transformers.apply_liger_kernel_to_phi3`     | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
227  
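When the model family is only known at runtime, the table above can be turned into a lookup. The sketch below is one way to do that, not something shipped with Liger Kernel: the helpers `resolve_patch_api` and `apply_liger_kernel` are hypothetical, and the dictionary keys assume the `model_type` strings used in Hugging Face configs (for example `llama` or `qwen2_vl`). `AutoLigerKernelForCausalLM` already covers the default-settings case, so a lookup like this is mainly useful when you want to pass custom kernel flags.

```python
import importlib

# Hugging Face `model_type` (assumed key) -> patching API from the table above.
PATCH_API_BY_MODEL_TYPE = {
    "llama": "apply_liger_kernel_to_llama",
    "mllama": "apply_liger_kernel_to_mllama",
    "mistral": "apply_liger_kernel_to_mistral",
    "mixtral": "apply_liger_kernel_to_mixtral",
    "gemma": "apply_liger_kernel_to_gemma",
    "gemma2": "apply_liger_kernel_to_gemma2",
    "qwen2": "apply_liger_kernel_to_qwen2",
    "qwen2_vl": "apply_liger_kernel_to_qwen2_vl",
    "phi3": "apply_liger_kernel_to_phi3",
}


def resolve_patch_api(model_type: str) -> str:
    """Return the name of the patching function for a model type."""
    try:
        return PATCH_API_BY_MODEL_TYPE[model_type]
    except KeyError:
        supported = ", ".join(sorted(PATCH_API_BY_MODEL_TYPE))
        raise ValueError(
            f"no Liger patching API listed for {model_type!r}; supported: {supported}"
        ) from None


def apply_liger_kernel(model_type: str, **kernel_flags) -> None:
    """Import liger_kernel lazily and call the matching patching function."""
    module = importlib.import_module("liger_kernel.transformers")
    getattr(module, resolve_patch_api(model_type))(**kernel_flags)
```

For example, `apply_liger_kernel("llama", rope=True, rms_norm=False)` would call `apply_liger_kernel_to_llama(rope=True, rms_norm=False)`, and belongs in the same position as the patching call in the Getting Started example.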
228  
229  
230  ### Kernels
231  
232  | **Kernel**                      | **API**                                                     |
233  |---------------------------------|-------------------------------------------------------------|
234  | RMSNorm                         | `liger_kernel.transformers.LigerRMSNorm`                    |
235  | LayerNorm                       | `liger_kernel.transformers.LigerLayerNorm`                  |
236  | RoPE                            | `liger_kernel.transformers.liger_rotary_pos_emb`            |
237  | SwiGLU                          | `liger_kernel.transformers.LigerSwiGLUMLP`                  |
238  | GeGLU                           | `liger_kernel.transformers.LigerGEGLUMLP`                   |
239  | CrossEntropy                    | `liger_kernel.transformers.LigerCrossEntropyLoss`           |
240  | FusedLinearCrossEntropy         | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
241  | KLDivergence                    | `liger_kernel.transformers.LigerKLDIVLoss`                  |
242  | JSD                             | `liger_kernel.transformers.LigerJSD`                        |
243  | FusedLinearJSD                  | `liger_kernel.transformers.LigerFusedLinearJSD`             |
244  
245  - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
246  - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which divides the channels of each sample into K groups and normalizes activations within each group, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key rotary embedding computations into a single kernel with in-place replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202) are given by
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
It is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with in-place replacement, and achieves speed parity with ~1.5X peak memory reduction.
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202) are given by
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
It is implemented by fusing the elementwise multiplication into a single kernel with in-place replacement, and achieves speed parity with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with in-place replacement of the input, which reduces peak memory by avoiding simultaneous materialization of both the input logits and the gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K).
256  <!-- TODO: verify vocab sizes are accurate  -->
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further reduced by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for a 128K vocab size. **This is highly effective for large batch sizes, long sequence lengths, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with the reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for a 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for a 128K vocab size. **NOTE**: It implements forward and reverse KL when `beta` equals 0 and 1, respectively.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further reduced by fusing the model head with the JSD loss and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for a 128K vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward and reverse KL when `beta` equals 0 and 1, respectively.
261  
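The chunking idea behind FusedLinearCrossEntropy can be illustrated without Triton. The following is a plain-Python sketch, not the Liger implementation: it projects hidden states through a head weight matrix and accumulates the cross entropy loss a few rows at a time, so the full logits matrix is never held at once. The function names and the `chunk_size` parameter are illustrative assumptions, and the sketch omits the gradient computation that the real kernel also performs chunk by chunk.

```python
import math


def logits_for_row(weight, hidden):
    """Project one hidden vector through the head: logits[v] = weight[v] . hidden."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def nll(logits, target):
    """Negative log-likelihood of `target` using a numerically stable log-sum-exp."""
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def linear_cross_entropy_full(weight, hidden_states, targets):
    """Baseline: materialize every row of logits, then average the loss."""
    all_logits = [logits_for_row(weight, h) for h in hidden_states]
    losses = [nll(row, t) for row, t in zip(all_logits, targets)]
    return sum(losses) / len(losses)


def linear_cross_entropy_chunked(weight, hidden_states, targets, chunk_size):
    """Chunked: logits are built for `chunk_size` rows at a time."""
    total = 0.0
    for start in range(0, len(hidden_states), chunk_size):
        chunk = hidden_states[start:start + chunk_size]
        chunk_targets = targets[start:start + chunk_size]
        chunk_logits = [logits_for_row(weight, h) for h in chunk]
        total += sum(nll(row, t) for row, t in zip(chunk_logits, chunk_targets))
        # The previous chunk's logits are released once `chunk_logits` is rebound.
    return total / len(hidden_states)


# Toy sizes: 3-way vocabulary, hidden size 2, 5 tokens.
weight = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]
hidden_states = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.25, 0.25]]
targets = [2, 0, 1, 2, 1]

full = linear_cross_entropy_full(weight, hidden_states, targets)
chunked = linear_cross_entropy_chunked(weight, hidden_states, targets, chunk_size=2)
assert math.isclose(full, chunked, rel_tol=1e-9)
```

Both functions return the same mean loss; only the number of logit rows alive at once differs, which is where the memory saving for large vocabularies comes from.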
262  
263  ### Experimental Kernels
264  
| **Kernel**                      | **API**                                                     |
|---------------------------------|-------------------------------------------------------------|
| Embedding                       | `liger_kernel.transformers.experimental.LigerEmbedding`     |
| Matmul int2xint8                | `liger_kernel.transformers.experimental.matmul`             |
269  
270  - **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- **Matmul int2xint8**: Implemented using cache-tiled matrix multiplication, with the unpacking step fused into the matmul. It achieves a considerable speedup and performs on par with `@torch.compile`.
272  <!-- TODO: be more specific about batch size -->
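To make the unpacking step concrete: 2-bit weights are stored several to a byte and have to be expanded before they can be multiplied. The sketch below is a plain-Python reference, not the Triton kernel, and its bit layout (four unsigned 2-bit values per byte, lowest bits first) is an assumption chosen for illustration; the layout used by `liger_kernel.transformers.experimental.matmul` may differ. The helper names are hypothetical as well.

```python
def pack_int2(values):
    """Pack unsigned 2-bit values (0..3) four per byte, lowest bits first (assumed layout)."""
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    if any(not 0 <= v <= 3 for v in values):
        raise ValueError("values must fit in 2 bits")
    packed = []
    for i in range(0, len(values), 4):
        a, b, c, d = values[i:i + 4]
        packed.append(a | (b << 2) | (c << 4) | (d << 6))
    return packed


def unpack_int2(packed):
    """Inverse of pack_int2: expand each byte into four 2-bit values."""
    values = []
    for byte in packed:
        values.extend((byte >> shift) & 0b11 for shift in (0, 2, 4, 6))
    return values


def dot_int8_int2(activations, packed_weights):
    """Reference dot product (one output element of the matmul) with packed weights."""
    weights = unpack_int2(packed_weights)
    if len(weights) != len(activations):
        raise ValueError("activation and weight lengths differ")
    return sum(a * w for a, w in zip(activations, weights))


weights = [0, 1, 2, 3, 3, 2, 1, 0]
packed = pack_int2(weights)  # 8 weights stored in 2 bytes
assert unpack_int2(packed) == weights

activations = [5, -3, 7, 1, 0, 2, -4, 6]
assert dot_int8_int2(activations, packed) == sum(a * w for a, w in zip(activations, weights))
```

Here the expansion happens as a separate pass before the multiply; fusing it into the matmul, as described above, avoids writing the unpacked weights out first.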
273  
274  ## Contributing, Acknowledgements, and License
275  
276  - [Contributing Guidelines](https://github.com/linkedin/Liger-Kernel/blob/main/docs/CONTRIBUTING.md)
277  - [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
278  - [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)
279  
280  ## Contact
281  
- For issues, open a GitHub issue in this repository
- For open discussion, join [our Discord channel](https://discord.gg/gpumode)
284  - For formal collaboration, send an email to byhsu@linkedin.com
285  
286  ## Cite this work
287  
288  Biblatex entry:
289  ```bib
290  @article{hsu2024ligerkernelefficienttriton,
291        title={Liger Kernel: Efficient Triton Kernels for LLM Training},
292        author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
293        year={2024},
294        eprint={2410.10989},
295        archivePrefix={arXiv},
296        primaryClass={cs.LG},
297        url={https://arxiv.org/abs/2410.10989},
298        journal={arXiv preprint arXiv:2410.10989},
299  }
300  ```
301  
302  ## Star History
303  [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
304  
305  ## Contributors
306  
307  <a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
308    <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
309  </a>
310  
311  <p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
312      <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
313          ↑ Back to Top ↑
314      </a>
315  </p>