<a name="readme-top"></a>

# Liger Kernel: Efficient Triton Kernels for LLM Training


<table style="width: 100%; text-align: center; border-collapse: collapse;">
    <tr>
        <th style="padding: 10px;" colspan="2">Stable</th>
        <th style="padding: 10px;" colspan="2">Nightly</th>
        <th style="padding: 10px;">Discord</th>
        <th style="padding: 10px;">Gurubase (experimental)</th>
    </tr>
    <tr>
        <td style="padding: 10px;">
            <a href="https://pepy.tech/project/liger-kernel">
                <img src="https://static.pepy.tech/badge/liger-kernel" alt="Downloads (Stable)">
            </a>
        </td>
        <td style="padding: 10px;">
            <a href="https://pypi.org/project/liger-kernel">
                <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel?color=green">
            </a>
        </td>
        <td style="padding: 10px;">
            <a href="https://pepy.tech/project/liger-kernel-nightly">
                <img src="https://static.pepy.tech/badge/liger-kernel-nightly" alt="Downloads (Nightly)">
            </a>
        </td>
        <td style="padding: 10px;">
            <a href="https://pypi.org/project/liger-kernel-nightly">
                <img alt="PyPI - Version" src="https://img.shields.io/pypi/v/liger-kernel-nightly?color=green">
            </a>
        </td>
        <td style="padding: 10px;">
            <a href="https://discord.gg/gpumode">
                <img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
            </a>
        </td>
        <td style="padding: 10px;">
            <a href="https://gurubase.io/g/liger-kernel">
                <img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
            </a>
        </td>
    </tr>
</table>



<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)

<details>
  <summary>Latest News 🔥</summary>

  - [2024/11/6] We release
[v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
  - [2024/10/21] We have released the tech report of Liger Kernel on arXiv: https://arxiv.org/pdf/2410.10989
  - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
  - [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
  - [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)

</details>


**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

## Supercharge Your Model with Liger Kernel

With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.

| Speed Up                 | Memory Reduction        |
|--------------------------|-------------------------|
|                          |                         |

> **Note:**
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

## Examples

| **Use Case**                                   | **Description**                                                                                   |
|------------------------------------------------|---------------------------------------------------------------------------------------------------|
| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on the Alpaca dataset using 4 A100s with FSDP |
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on the MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |

## Key Features

- **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
- **Time and memory efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, **SwiGLU**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
- **Exact:** Computation is exact, with no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton, with no extra libraries needed! Say goodbye to dependency headaches!
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- **Trainer Framework Integration:** [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift)

## Installation

### Dependencies

#### CUDA

- `torch >= 2.1.2`
- `triton >= 2.3.0`

#### ROCm

- `torch >= 2.5.0`: Install according to the instructions on the official PyTorch website.
- `triton >= 3.0.0`: Install from PyPI (e.g. `pip install triton==3.0.0`).

### Optional Dependencies

- `transformers >= 4.x`: Required if you plan to use the transformers model patching APIs. The specific model you are working with will dictate the minimum version of transformers.

> **Note:**
> Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

To install the stable version:

```bash
pip install liger-kernel
```

To install the nightly version:

```bash
pip install liger-kernel-nightly
```

To install from source:

```bash
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel
pip install -e .
# or if using transformers
pip install -e .[transformers]
```


## Getting Started

There are a couple of ways to apply Liger kernels, depending on the level of customization required.

### 1. Use AutoLigerKernelForCausalLM

Using `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
```

### 2. Apply Model-Specific Patching APIs

Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.

```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
)

# 2. Instantiate the patched model (patching must happen before this call)
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```

### 3. Compose Your Own Model

You can take individual [kernels](#kernels) to compose your models.

```python
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
```

## APIs

### AutoModel

| **AutoModel Variant** | **API** |
|-----------|---------|
| AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |


### Patching

| **Model**   | **API**                                                      | **Supported Operations**                                                |
|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama`   | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama`   | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
| Mistral     | `liger_kernel.transformers.apply_liger_kernel_to_mistral`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
| Mixtral     | `liger_kernel.transformers.apply_liger_kernel_to_mixtral`  | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
| Gemma1      | `liger_kernel.transformers.apply_liger_kernel_to_gemma`    | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
| Gemma2      | `liger_kernel.transformers.apply_liger_kernel_to_gemma2`   | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy         |
| Qwen2 & Qwen2.5 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2`    | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |
| Qwen2-VL    | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy   |
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3`     | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy        |



### Kernels

| **Kernel**                      | **API**                                                     |
|---------------------------------|-------------------------------------------------------------|
| RMSNorm                         | `liger_kernel.transformers.LigerRMSNorm`                    |
| LayerNorm                       | `liger_kernel.transformers.LigerLayerNorm`                  |
| RoPE                            | `liger_kernel.transformers.liger_rotary_pos_emb`            |
| SwiGLU                          | `liger_kernel.transformers.LigerSwiGLUMLP`                  |
| GeGLU                           | `liger_kernel.transformers.LigerGEGLUMLP`                   |
| CrossEntropy                    | `liger_kernel.transformers.LigerCrossEntropyLoss`           |
| FusedLinearCrossEntropy         | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| KLDivergence                    | `liger_kernel.transformers.LigerKLDIVLoss`                  |
| JSD                             | `liger_kernel.transformers.LigerJSD`                        |
| FusedLinearJSD                  | `liger_kernel.transformers.LigerFusedLinearJSD`             |

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization, and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494) normalizes activations across the group dimension for a given sample.
Channels are grouped into K groups over which the normalization is performed. It is implemented by fusing the centering, normalization, and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key rotary embedding into a single kernel with in-place replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with in-place replacement, and achieves parity speed with ~1.5X peak memory reduction.
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication into a single kernel with in-place replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with in-place replacement of the input, which reduces peak memory by avoiding simultaneous materialization of both the input logits and the gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
<!-- TODO: verify vocab sizes are accurate -->
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128K vocab size. **This is highly effective for large batch sizes, long sequence lengths, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with the reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for 128K vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128K vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.

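
The chunking idea behind FusedLinearCrossEntropy can be illustrated in plain Python. This is only a minimal sketch under stated assumptions, not Liger Kernel's Triton implementation: the helper names (`linear`, `nll`, `full_loss`, `chunked_loss`) are hypothetical, nested lists stand in for tensors, and only the forward loss is covered (the description above says the real kernel also computes gradients block by block). It shows that projecting and scoring a few rows at a time gives the same mean loss as materializing the full logits matrix first.

```python
import math


def linear(weight, x):
    """Project one hidden vector x (length H) to V logits; weight is V x H."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def nll(logits, target):
    """Cross entropy for one row, using a numerically stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]


def full_loss(weight, inputs, targets):
    """Baseline: materialize all N x V logits, then average the loss."""
    all_logits = [linear(weight, x) for x in inputs]
    return sum(nll(l, t) for l, t in zip(all_logits, targets)) / len(inputs)


def chunked_loss(weight, inputs, targets, chunk_size):
    """Chunked: at most chunk_size x V logits are alive at any time."""
    total = 0.0
    for start in range(0, len(inputs), chunk_size):
        rows = inputs[start:start + chunk_size]
        labels = targets[start:start + chunk_size]
        chunk_logits = [linear(weight, x) for x in rows]
        total += sum(nll(l, t) for l, t in zip(chunk_logits, labels))
    return total / len(inputs)


# Toy data (hypothetical): vocabulary V = 3, hidden size H = 2, N = 5 rows.
weight = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
inputs = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0], [-1.0, 1.0], [0.0, 3.0]]
targets = [0, 2, 1, 1, 2]

baseline = full_loss(weight, inputs, targets)
chunked = chunked_loss(weight, inputs, targets, chunk_size=2)
print(abs(baseline - chunked) < 1e-9)
```

In this sketch the peak number of live logits drops from N × V to `chunk_size` × V, which is why the saving grows with batch size, sequence length, and vocabulary size.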

### Experimental Kernels

| **Kernel**                      | **API**                                                     |
|---------------------------------|-------------------------------------------------------------|
| Embedding                       | `liger_kernel.transformers.experimental.LigerEmbedding`     |
| Matmul int2xint8                | `liger_kernel.transformers.experimental.matmul`             |

- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- **Matmul int2xint8**: Implemented using cache-tiled matrix multiplication and by fusing the matmul with the unpacking process, which achieves a considerable speedup and performs on par with `@torch.compile`.
<!-- TODO: be more specific about batch size -->

## Contributing, Acknowledgements, and License

- [Contributing Guidelines](https://github.com/linkedin/Liger-Kernel/blob/main/docs/CONTRIBUTING.md)
- [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)

## Contact

- For issues, create a GitHub ticket in this repository
- For open discussion, join [our Discord channel](https://discord.gg/gpumode)
- For formal collaboration, send an email to byhsu@linkedin.com

## Cite this work

Biblatex entry:
```bib
@article{hsu2024ligerkernelefficienttriton,
      title={Liger Kernel: Efficient Triton Kernels for LLM Training},
      author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
      year={2024},
      eprint={2410.10989},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2410.10989},
      journal={arXiv preprint arXiv:2410.10989},
}
```

## Star History
[Star history chart](https://star-history.com/#linkedin/Liger-Kernel&Date)

## Contributors

<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
  <img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
    <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
        ↑ Back to Top ↑
    </a>
</p>