test_cross_entropy.py
from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from liger_kernel.ops.cross_entropy import (
    LigerCrossEntropyFunction,
    liger_cross_entropy_kernel,
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.utils import infer_device

device = infer_device()
set_seed(42)


class CrossEntropyWithZLoss(torch.nn.Module):
    """Reference (eager PyTorch) cross entropy with an auxiliary z-loss term.

    The z-loss penalizes large log-sum-exp values:
    ``lse_square_scale * logsumexp(logits)^2`` per non-ignored token, reduced
    the same way as the cross-entropy term. Used as the ground truth when
    validating the Liger kernel's z-loss support.
    """

    def __init__(
        self,
        lse_square_scale=0.0,
        reduction="mean",
        ignore_index=-100,
        label_smoothing=0.0,
        return_z_loss=False,
        dtype=torch.float32,
    ):
        super().__init__()
        self.lse_square_scale = lse_square_scale
        self.reduction = reduction
        self.ignore_index = ignore_index
        self.return_z_loss = return_z_loss
        self.label_smoothing = label_smoothing
        self.dtype = dtype

    def forward(self, logits, targets):
        # Loss calculations are all in float32 (matches the kernel's strategy)
        logits = logits.to(torch.float32)
        # Standard cross entropy loss
        ce_loss = F.cross_entropy(
            logits,
            targets,
            reduction=self.reduction,
            label_smoothing=self.label_smoothing,
            ignore_index=self.ignore_index,
        )

        # Compute log-sum-exp term
        lse = torch.logsumexp(logits, dim=-1)

        # Z-loss term; ignored positions contribute zero
        z_loss = torch.where(
            targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0
        )
        z_loss = z_loss.to(logits.dtype)
        if self.reduction == "mean":
            # Average over non-ignored tokens only, like F.cross_entropy
            z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()
        # reduction == "none": keep the per-token z_loss tensor as-is
        ce_loss = ce_loss.to(self.dtype)
        z_loss = z_loss.to(self.dtype)

        # Final loss: cross-entropy loss + Z-loss
        total_loss = ce_loss + z_loss
        if self.return_z_loss:
            return total_loss, z_loss
        else:
            return total_loss


def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol):
    """Compare forward loss and input gradients of `target_ce` vs torch CE."""
    torch.manual_seed(0)
    torch_ce = CrossEntropyLoss(reduction=reduction)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    output = torch_ce(_input, target)
    output2 = target_ce(_input2, target)
    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_ignore_index_once(
    target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
    """Compare vs torch CE when a random subset of targets is ignore_index."""
    torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(
        1, B * T // 2, (1,)
    ).item()  # Random number of elements to set to ignore_index
    indices_to_assign = torch.randperm(B * T)[
        :num_elements_to_assign
    ]  # Randomly select indices
    target[indices_to_assign] = ignore_index

    output = torch_ce(_input, target)
    output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_label_smoothing_once(
    target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
    """Compare vs torch CE with label smoothing enabled."""
    torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    output = torch_ce(_input, target)
    output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_label_smoothing_with_ignore_index_once(
    target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
    """Compare vs torch CE with both label smoothing and ignore_index."""
    torch_ce = CrossEntropyLoss(
        ignore_index=ignore_index, label_smoothing=label_smoothing
    )

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(
        1, B * T // 2, (1,)
    ).item()  # Random number of elements to set to ignore_index
    indices_to_assign = torch.randperm(B * T)[
        :num_elements_to_assign
    ]  # Randomly select indices
    target[indices_to_assign] = ignore_index

    output = torch_ce(_input, target)
    output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_softcap_once(
    target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol
):
    """Compare vs torch CE applied to softcapped logits: softcap*tanh(x/softcap)."""
    torch_ce = CrossEntropyLoss(reduction=reduction)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    # upcasting to match liger's casting strategy
    _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # downcasting to original dtype
    output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype)
    output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    # NOTE(review): gradients are computed but intentionally not compared here;
    # the two inputs live in different dtypes, so tolerances would not match.
    output.backward()
    output2.backward()


def _test_correctness_with_z_loss_once(
    target_ce,
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
):
    """Compare vs the eager CrossEntropyWithZLoss reference."""
    torch.manual_seed(0)
    torch_ce = CrossEntropyWithZLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        dtype=dtype,
    )

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
    if return_z_loss:
        output, z_output = torch_ce(_input, target)
        output2, z_output2 = target_ce(_input2, target)
    else:
        output = torch_ce(_input, target)
        output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    if return_z_loss:
        assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()

    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_z_loss_with_other_params_once(
    target_ce,
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
    label_smoothing,
    ignore_index,
    reduction,
):
    """Compare z-loss together with label smoothing, ignore_index and reduction."""
    torch.manual_seed(0)
    torch_ce = CrossEntropyWithZLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        dtype=dtype,
    )

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(
        1, B * T // 2, (1,)
    ).item()  # Random number of elements to set to ignore_index
    indices_to_assign = torch.randperm(B * T)[
        :num_elements_to_assign
    ]  # Randomly select indices
    target[indices_to_assign] = ignore_index

    if return_z_loss:
        output, z_output = torch_ce(_input, target)
        output2, z_output2 = target_ce(_input2, target)
    else:
        output = torch_ce(_input, target)
        output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    if return_z_loss:
        assert torch.allclose(z_output, z_output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_not_last_layer_once(
    target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
):
    """Validate gradients when the loss is scaled downstream (not the last op)."""
    torch_ce = CrossEntropyLoss(reduction=reduction)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    _input = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    output = torch_ce(_input, target)
    output2 = target_ce(_input2, target)
    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    # Scale the loss so backward receives a non-unit upstream gradient
    loss1 = output * 3
    loss2 = output2 * 3

    loss1.backward()
    loss2.backward()
    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_functional(
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
):
    """Check the functional wrapper matches LigerCrossEntropyFunction.apply."""
    _input = torch.randn(B * T, V, device=device, dtype=dtype) * scalar

    x1 = _input.clone().requires_grad_(True)
    x2 = _input.clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    y1, y1_z = liger_cross_entropy(
        x1,
        target,
        ignore_index=0,
        lse_square_scale=1e-4,
        label_smoothing=0.1,
        reduction="mean",
        softcap=30.0,
        return_z_loss=True,
    )
    y2, y2_z = LigerCrossEntropyFunction.apply(
        x2, target, 0, 1e-4, 0.1, "mean", 30.0, True
    )

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
    assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol)

    grad = torch.randn_like(y2)

    y1.backward(grad)
    y2.backward(grad)

    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)


#############################################################################
# Test the correctness of the liger cross entropy loss
#############################################################################


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama
        (3, 423, 32000),  # weird shapes
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
    liger_ce = LigerCrossEntropyLoss(reduction=reduction)
    _test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol)


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 2, 8),
        # weird shapes
        (9, 7, 41),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 1e-8, 5e-2),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
    _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol)


@pytest.mark.parametrize(
    "B, T, V, ignore_index",
    [
        (2, 4096, 32000, 2),
        # weird shapes
        (3, 423, 32000, -123),
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_ignore_index(
    B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
    liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    _test_correctness_with_ignore_index_once(
        liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
    )


@pytest.mark.parametrize(
    "B, T, V, label_smoothing",
    [
        (2, 4096, 32000, 0.1),
        # weird shapes
        (3, 423, 32000, 0.1),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_label_smoothing_once(
    B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
    liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
    _test_correctness_with_label_smoothing_once(
        liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
    )


@pytest.mark.parametrize(
    "B, T, V, ignore_index, label_smoothing",
    [
        (2, 4096, 32000, 1, 0.1),
        # weird shapes
        (3, 423, 32000, -300, 0.2),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_label_smoothing_with_ignore_index_once(
    B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
    liger_ce = LigerCrossEntropyLoss(
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
    )
    _test_correctness_with_label_smoothing_with_ignore_index_once(
        liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
    )


@pytest.mark.parametrize(
    "B, T, V, softcap",
    [
        (2, 4096, 32000, 30.0),  # llama2, mistral
        # weird shapes
        (3, 423, 32000, 30.0),
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_softcap_once(
    B, T, V, softcap, reduction, scalar, dtype, atol, rtol
):
    liger_ce = LigerCrossEntropyLoss(softcap=softcap, reduction=reduction)
    _test_correctness_with_softcap_once(
        liger_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol
    )


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama2
        # weird shapes
        (3, 423, 32000),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
@pytest.mark.parametrize("return_z_loss", [True, False])
@pytest.mark.parametrize(
    "lse_square_scale",
    [
        1e-4,  # PaLM
        1e-5,  # Chameleon
    ],
)
def test_correctness_with_z_loss_once(
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
):
    test_ce = LigerCrossEntropyLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
    )
    _test_correctness_with_z_loss_once(
        test_ce,
        B,
        T,
        V,
        scalar,
        dtype,
        atol,
        rtol,
        lse_square_scale,
        return_z_loss,
    )


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama2, mistral
        # weird shapes
        (3, 423, 32000),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
@pytest.mark.parametrize(
    "return_z_loss, lse_square_scale",
    [
        (True, 1e-4),
        (False, 1e-5),
    ],
)
@pytest.mark.parametrize(
    "label_smoothing, ignore_index, reduction",
    [
        (0.1, 42, "mean"),
        (0.2, -42, "sum"),
    ],
)
def test_correctness_with_z_loss_with_other_params_once(
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
    label_smoothing,
    ignore_index,
    reduction,
):
    test_ce = LigerCrossEntropyLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
    )
    _test_correctness_with_z_loss_with_other_params_once(
        test_ce,
        B,
        T,
        V,
        scalar,
        dtype,
        atol,
        rtol,
        lse_square_scale,
        return_z_loss,
        label_smoothing,
        ignore_index,
        reduction,
    )


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama2, mistral
        # # weird shapes
        (3, 423, 32000),
    ],
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
    liger_ce = LigerCrossEntropyLoss(reduction=reduction)
    _test_correctness_not_last_layer_once(
        liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
    )


def test_float32_internal():
    """
    This test validates that the internal softmax calculations occur in float32,
    even if the input dtype is bfloat16.
    """
    # Set up test parameters
    batch_size = 4
    n_cols = 128256
    n_non_ignore = batch_size
    ignore_index = -100
    label_smoothing = 0.0
    lse_square_scale = 0.0
    softcap = 0.0
    BLOCK_SIZE = 32768
    reduction = "mean"

    # Initialize input tensors
    X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device)
    Y = torch.randint(0, n_cols, (batch_size,), device=device)

    # Run kernel for bfloat16
    X_bf16 = X_init.clone()
    loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device)
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_bf16,
        X_stride=X_bf16.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        z_loss_ptr=loss_bf16,  # dummy ptr, not used
        loss_ptr=loss_bf16,
        loss_stride=loss_bf16.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # Run kernel for float32
    X_fp32 = X_init.float()
    loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device)
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_fp32,
        X_stride=X_fp32.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        loss_ptr=loss_fp32,
        z_loss_ptr=loss_fp32,  # dummy ptr, not used
        loss_stride=loss_fp32.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # BUG FIX: the original called torch.allclose without `assert`, discarding
    # the result so the test could never fail. Assert the comparisons.
    assert torch.allclose(X_bf16, X_fp32.bfloat16())
    assert torch.allclose(loss_bf16, loss_fp32)