test_fused_linear_cross_entropy.py
1 from test.transformers.test_cross_entropy import CrossEntropyWithZLoss 2 from test.utils import assert_verbose_allclose, set_seed 3 from typing import Optional 4 5 import pytest 6 import torch 7 8 from liger_kernel.ops.fused_linear_cross_entropy import ( 9 LigerFusedLinearCrossEntropyFunction, 10 ) 11 from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy 12 from liger_kernel.transformers.fused_linear_cross_entropy import ( 13 LigerFusedLinearCrossEntropyLoss, 14 ) 15 from liger_kernel.utils import infer_device 16 17 device = infer_device() 18 19 # set random seed globally 20 set_seed() 21 22 23 class TorchLMHeadCE(torch.nn.Module): 24 """Ground truth implementation of the linear fused with torch based cross entropy loss. 25 26 :param H: hidden size 27 :param V: vocab size 28 :param ignore_index: index to ignore 29 :param reduction: reduction method 30 :param label_smoothing: label_smoothing to apply on target 31 :param lse_square_scale: scaler of lse ^ 2 to compute z loss 32 33 # TODO: if we bump CI env's `transformers` version to >= 4.46, we should just directly 34 # call https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32 35 # to be consistent with Hugging Face model implementation. 36 """ 37 38 def __init__( 39 self, 40 H: int, 41 V: int, 42 dtype: torch.dtype, 43 bias: bool = False, 44 ignore_index: int = -100, 45 lse_square_scale: float = 0.0, 46 label_smoothing: float = 0.0, 47 reduction: str = "mean", 48 softcap: Optional[float] = None, 49 ): 50 super().__init__() 51 self.lin = torch.nn.Linear( 52 in_features=H, out_features=V, bias=bias, dtype=dtype 53 ) 54 self.ce_loss = CrossEntropyWithZLoss( 55 ignore_index=ignore_index, 56 lse_square_scale=lse_square_scale, 57 label_smoothing=label_smoothing, 58 reduction=reduction, 59 ) 60 self.softcap = softcap 61 62 def forward(self, x, y): 63 logits = self.lin(x).to(torch.float32) 64 if self.softcap is not None and self.softcap != 0.0: 65 logits = self.softcap * torch.tanh(logits / self.softcap) 66 return self.ce_loss(logits, y) 67 68 69 class LigerLMHeadCE(torch.nn.Module): 70 def __init__( 71 self, 72 H: int, 73 V: int, 74 dtype: torch.dtype, 75 bias: bool = False, 76 ignore_index: int = -100, 77 lse_square_scale: float = 0.0, 78 label_smoothing: float = 0.0, 79 reduction: str = "mean", 80 softcap: Optional[float] = None, 81 ): 82 super().__init__() 83 self.lin = torch.nn.Linear( 84 in_features=H, out_features=V, bias=bias, dtype=dtype 85 ) 86 self.ce_loss = LigerFusedLinearCrossEntropyLoss( 87 ignore_index=ignore_index, 88 lse_square_scale=lse_square_scale, 89 label_smoothing=label_smoothing, 90 reduction=reduction, 91 softcap=softcap, 92 ) 93 94 def forward(self, x, y): 95 return self.ce_loss(self.lin.weight, x, y, self.lin.bias) 96 97 98 ############################################################################# 99 # Test the correctness of the fused linear cross entropy loss 100 ############################################################################# 101 102 103 @pytest.mark.parametrize( 104 "B, T, H, V", 105 [ 106 (8, 128, 1024, 4096), 107 (4, 47, 31, 123), # random shape 108 ], 109 ) 110 @pytest.mark.parametrize( 111 "reduction, scalar, dtype, atol, rtol", 112 [ 113 ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2), 114 ("mean", 1.0, torch.float32, 1e-5, 5e-4), 115 ("sum", 1.0, torch.bfloat16, 5e-0, 5e1), 116 ("sum", 1.0, torch.float32, 1e-3, 5e-2), 117 ], 118 ) 119 @pytest.mark.parametrize("bias", [True, False]) 120 @pytest.mark.parametrize( 121 "label_smoothing, ignore_index, lse_square_scale, softcap", 122 [ 123 (0, -100, 0, None), 124 ( 125 0.1, 126 42, 127 1e-4, 128 30.0, 129 ), # Pass non-default values once to ensure all params work along 130 ], 131 ) 132 def test_correctness( 133 B, 134 T, 135 H, 136 V, 137 scalar, 138 dtype, 139 bias, 140 lse_square_scale, 141 label_smoothing, 142 ignore_index, 143 reduction, 144 softcap, 145 atol, 146 rtol, 147 ): 148 torch_lm_head_ce = TorchLMHeadCE( 149 H=H, 150 V=V, 151 bias=bias, 152 lse_square_scale=lse_square_scale, 153 label_smoothing=label_smoothing, 154 ignore_index=ignore_index, 155 reduction=reduction, 156 softcap=softcap, 157 dtype=dtype, 158 ).to(device) 159 liger_lm_head_ce = LigerLMHeadCE( 160 H=H, 161 V=V, 162 bias=bias, 163 lse_square_scale=lse_square_scale, 164 label_smoothing=label_smoothing, 165 ignore_index=ignore_index, 166 reduction=reduction, 167 softcap=softcap, 168 dtype=dtype, 169 ).to(device) 170 171 # init the linear in all CEs with the same weights 172 torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( 173 V, H, device=device, dtype=dtype 174 ) 175 176 if bias: 177 torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand( 178 V, device=device, dtype=dtype 179 ) 180 181 _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar 182 _input1 = _tensor.detach().clone().requires_grad_(True) 183 _input2 = _tensor.detach().clone().requires_grad_(True) 184 185 target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) 186 # Assign some random number of elements as ignore_index 187 num_elements_to_assign = torch.randint( 188 1, B * T // 2, (1,) 189 ).item() # Random number of elements to set to ignore_index 190 indices_to_assign = torch.randperm(B * T)[ 191 :num_elements_to_assign 192 ] # Randomly select indices 193 target[indices_to_assign] = ignore_index 194 195 output1 = torch_lm_head_ce(_input1, target) 196 output2 = liger_lm_head_ce(_input2, target) 197 198 assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) 199 200 output1.backward() 201 output2.backward() 202 203 assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) 204 205 assert_verbose_allclose( 206 torch_lm_head_ce.lin.weight.grad, 207 liger_lm_head_ce.lin.weight.grad, 208 atol=atol, 209 rtol=rtol, 210 ) 211 212 if bias: 213 assert_verbose_allclose( 214 torch_lm_head_ce.lin.bias.grad, 215 liger_lm_head_ce.lin.bias.grad, 216 atol=atol, 217 rtol=rtol, 218 ) 219 220 221 @pytest.mark.parametrize( 222 "B, T, H, V", 223 [ 224 (2, 2, 8, 8), 225 # weird shapes 226 (9, 7, 41, 41), 227 ], 228 ) 229 @pytest.mark.parametrize( 230 "scalar, dtype, atol, rtol", 231 [ 232 (1.0, torch.bfloat16, 5e-3, 5e-2), 233 (1.0, torch.float32, 1e-5, 5e-4), 234 ], 235 ) 236 @pytest.mark.parametrize("bias", [True, False]) 237 def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol): 238 _input = torch.randn(B * T, H, device=device, dtype=dtype) * scalar 239 x1 = _input.detach().clone().requires_grad_(True) 240 x2 = _input.detach().clone().requires_grad_(True) 241 242 target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) 243 244 weight = torch.randn(V, H, device=device, dtype=dtype) 245 bias = torch.randn(V, device=device, dtype=dtype) if bias else None 246 247 y1 = liger_fused_linear_cross_entropy( 248 input=x1, 249 weight=weight, 250 target=target, 251 bias=bias, 252 ) 253 y2 = LigerFusedLinearCrossEntropyFunction.apply(x2, weight, target, bias) 254 255 assert torch.allclose(y1, y2, atol=atol, rtol=rtol) 256 257 grad_output = torch.randn_like(y1) 258 259 y1.backward(grad_output) 260 y2.backward(grad_output) 261 262 assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) 263 264 265 @pytest.mark.parametrize( 266 "B, T, H, V", 267 [ 268 (8, 128, 1024, 4096), 269 (4, 47, 31, 123), # random shape 270 ], 271 ) 272 @pytest.mark.parametrize( 273 "cast_dtype, atol, rtol", 274 [ 275 (torch.bfloat16, 5e-3, 5e-2), 276 (torch.float16, 5e-3, 5e-2), 277 ], 278 ) 279 def test_amp(B, T, H, V, cast_dtype, atol, rtol): 280 dtype = torch.float32 281 torch_lm_head_ce = TorchLMHeadCE( 282 H=H, 283 V=V, 284 bias=True, 285 label_smoothing=0.0, 286 reduction="mean", 287 dtype=dtype, 288 ).to(device) 289 liger_lm_head_ce = LigerLMHeadCE( 290 H=H, 291 V=V, 292 bias=True, 293 label_smoothing=0.0, 294 reduction="mean", 295 dtype=dtype, 296 ).to(device) 297 298 # init the linear in all CEs with the same weights 299 torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand( 300 V, H, device=device, dtype=dtype 301 ) 302 303 _tensor = torch.randn(B * T, H, device=device, dtype=dtype) 304 _input1 = _tensor.detach().clone().requires_grad_(True) 305 _input2 = _tensor.detach().clone().requires_grad_(True) 306 307 target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) 308 309 with torch.autocast(device_type=device, dtype=cast_dtype): 310 output1 = torch_lm_head_ce(_input1, target) 311 output2 = liger_lm_head_ce(_input2, target) 312 313 assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) 314 315 with torch.autocast(device_type=device, dtype=cast_dtype): 316 output1.backward() 317 output2.backward() 318 319 assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) 320 321 assert_verbose_allclose( 322 torch_lm_head_ce.lin.weight.grad, 323 liger_lm_head_ce.lin.weight.grad, 324 atol=atol, 325 rtol=rtol, 326 )