# test/transformers/test_cross_entropy.py
  1  from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16
  2  
  3  import pytest
  4  import torch
  5  import torch.nn.functional as F
  6  from torch.nn import CrossEntropyLoss
  7  
  8  from liger_kernel.ops.cross_entropy import (
  9      LigerCrossEntropyFunction,
 10      liger_cross_entropy_kernel,
 11  )
 12  from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 13  from liger_kernel.transformers.functional import liger_cross_entropy
 14  from liger_kernel.utils import infer_device
 15  
 16  device = infer_device()
 17  set_seed(42)
 18  
 19  
 20  class CrossEntropyWithZLoss(torch.nn.Module):
 21      def __init__(
 22          self,
 23          lse_square_scale=0.0,
 24          reduction="mean",
 25          ignore_index=-100,
 26          label_smoothing=0.0,
 27          return_z_loss=False,
 28          dtype=torch.float32,
 29      ):
 30          super().__init__()
 31          self.lse_square_scale = lse_square_scale
 32          self.reduction = reduction
 33          self.ignore_index = ignore_index
 34          self.return_z_loss = return_z_loss
 35          self.label_smoothing = label_smoothing
 36          self.dtype = dtype
 37  
 38      def forward(self, logits, targets):
 39          # Loss calculations are all in float32
 40          logits = logits.to(torch.float32)
 41          # Standard cross entropy loss
 42          ce_loss = F.cross_entropy(
 43              logits,
 44              targets,
 45              reduction=self.reduction,
 46              label_smoothing=self.label_smoothing,
 47              ignore_index=self.ignore_index,
 48          )
 49  
 50          # Compute log-sum-exp term
 51          lse = torch.logsumexp(logits, dim=-1)
 52  
 53          # Z-loss term
 54          z_loss = torch.where(
 55              targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0
 56          )
 57          z_loss = z_loss.to(logits.dtype)
 58          if self.reduction == "mean":
 59              z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
 60          elif self.reduction == "sum":
 61              z_loss = z_loss.sum()
 62          else:
 63              z_loss = z_loss
 64          ce_loss = ce_loss.to(self.dtype)
 65          z_loss = z_loss.to(self.dtype)
 66  
 67          # Final loss: cross-entropy loss + Z-loss
 68          total_loss = ce_loss + z_loss
 69          if self.return_z_loss:
 70              return total_loss, z_loss
 71          else:
 72              return total_loss
 73  
 74  
 75  def _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol):
 76      torch.manual_seed(0)
 77      torch_ce = CrossEntropyLoss(reduction=reduction)
 78  
 79      _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
 80      _input = _tensor.detach().clone().requires_grad_(True)
 81      _input2 = _tensor.detach().clone().requires_grad_(True)
 82  
 83      target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
 84  
 85      output = torch_ce(_input, target)
 86      output2 = target_ce(_input2, target)
 87      assert torch.allclose(output, output2, atol=atol, rtol=rtol)
 88  
 89      output.backward()
 90      output2.backward()
 91      assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
 92  
 93  
 94  def _test_correctness_with_ignore_index_once(
 95      target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
 96  ):
 97  
 98      torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
 99  
100      _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
101      _input = _tensor.detach().clone().requires_grad_(True)
102      _input2 = _tensor.detach().clone().requires_grad_(True)
103  
104      target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
105  
106      # Assign some random number of elements as ignore_index
107      num_elements_to_assign = torch.randint(
108          1, B * T // 2, (1,)
109      ).item()  # Random number of elements to set to ignore_index
110      indices_to_assign = torch.randperm(B * T)[
111          :num_elements_to_assign
112      ]  # Randomly select indices
113      target[indices_to_assign] = ignore_index
114  
115      output = torch_ce(_input, target)
116      output2 = target_ce(_input2, target)
117  
118      assert torch.allclose(output, output2, atol=atol, rtol=rtol)
119  
120      output.backward()
121      output2.backward()
122      assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
123  
124  
def _test_correctness_with_label_smoothing_once(
    target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
    """Compare target_ce against torch CE with label smoothing enabled."""
    torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)

    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x_ref = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    loss_ref = torch_ce(x_ref, target)
    loss_liger = target_ce(x_liger, target)
    assert torch.allclose(loss_ref, loss_liger, atol=atol, rtol=rtol)

    loss_ref.backward()
    loss_liger.backward()
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
145  
146  
def _test_correctness_with_label_smoothing_with_ignore_index_once(
    target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
    """Compare against torch CE with both label smoothing and ignore_index."""
    torch_ce = CrossEntropyLoss(
        ignore_index=ignore_index, label_smoothing=label_smoothing
    )

    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x_ref = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Mark a random subset (at least one, fewer than half) as ignore_index.
    num_ignored = torch.randint(1, B * T // 2, (1,)).item()
    target[torch.randperm(B * T)[:num_ignored]] = ignore_index

    loss_ref = torch_ce(x_ref, target)
    loss_liger = target_ce(x_liger, target)
    assert torch.allclose(loss_ref, loss_liger, atol=atol, rtol=rtol)

    loss_ref.backward()
    loss_liger.backward()
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
178  
179  
def _test_correctness_with_softcap_once(
    target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol
):
    """Compare target_ce (which softcaps logits internally) against torch CE
    applied to explicitly softcapped logits, on both forward and backward.
    """
    torch_ce = CrossEntropyLoss(reduction=reduction)

    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    # upcasting to match liger's casting strategy
    _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Reference: softcap = softcap * tanh(x / softcap), then CE;
    # downcasting to original dtype to match liger's output dtype.
    output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype)
    output2 = target_ce(_input2, target)

    assert torch.allclose(output, output2, atol=atol, rtol=rtol)

    output.backward()
    output2.backward()
    # Fix: gradients were computed but never compared, leaving the backward
    # pass untested. The reference grad is float32 while the liger grad is in
    # the working dtype, so cast before comparing.
    assert torch.allclose(_input.grad.to(dtype), _input2.grad, atol=atol, rtol=rtol)
201  
202  
def _test_correctness_with_z_loss_once(
    target_ce,
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
):
    """Compare target_ce against the CrossEntropyWithZLoss reference."""
    torch.manual_seed(0)
    torch_ce = CrossEntropyWithZLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        dtype=dtype,
    )

    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x_ref = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    if return_z_loss:
        loss_ref, z_ref = torch_ce(x_ref, target)
        loss_liger, z_liger = target_ce(x_liger, target)
    else:
        loss_ref = torch_ce(x_ref, target)
        loss_liger = target_ce(x_liger, target)

    assert torch.allclose(loss_ref, loss_liger, atol=atol, rtol=rtol)
    if return_z_loss:
        assert torch.allclose(z_ref, z_liger, atol=atol, rtol=rtol)

    loss_ref.backward()
    loss_liger.backward()
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
244  
245  
def _test_correctness_with_z_loss_with_other_params_once(
    target_ce,
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
    label_smoothing,
    ignore_index,
    reduction,
):
    """Compare target_ce against the reference with z-loss, label smoothing,
    ignore_index, and the given reduction all enabled simultaneously.
    """
    torch.manual_seed(0)
    torch_ce = CrossEntropyWithZLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        dtype=dtype,
    )

    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x_ref = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    # Mark a random subset (at least one, fewer than half) as ignore_index.
    num_ignored = torch.randint(1, B * T // 2, (1,)).item()
    target[torch.randperm(B * T)[:num_ignored]] = ignore_index

    if return_z_loss:
        loss_ref, z_ref = torch_ce(x_ref, target)
        loss_liger, z_liger = target_ce(x_liger, target)
    else:
        loss_ref = torch_ce(x_ref, target)
        loss_liger = target_ce(x_liger, target)

    assert torch.allclose(loss_ref, loss_liger, atol=atol, rtol=rtol)
    if return_z_loss:
        assert torch.allclose(z_ref, z_liger, atol=atol, rtol=rtol)

    loss_ref.backward()
    loss_liger.backward()
    # Verbose variant prints the offending elements on mismatch.
    assert_verbose_allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
302  
303  
def _test_correctness_not_last_layer_once(
    target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
):
    """Verify backward is correct when the CE output is not the final loss node."""
    torch_ce = CrossEntropyLoss(reduction=reduction)

    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x_ref = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    loss_ref = torch_ce(x_ref, target)
    loss_liger = target_ce(x_liger, target)
    assert torch.allclose(loss_ref, loss_liger, atol=atol, rtol=rtol)

    # Scale each loss so the upstream gradient reaching CE is 3, not 1.
    scaled_ref = loss_ref * 3
    scaled_liger = loss_liger * 3

    scaled_ref.backward()
    scaled_liger.backward()
    assert torch.allclose(x_ref.grad, x_liger.grad, atol=atol, rtol=rtol)
326  
327  
def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
    """Check liger_cross_entropy (functional API) matches a direct
    LigerCrossEntropyFunction.apply call, including z-loss and gradients.
    """
    base = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
    x1 = base.clone().requires_grad_(True)
    x2 = base.clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    y1, y1_z = liger_cross_entropy(
        x1,
        target,
        ignore_index=0,
        lse_square_scale=1e-4,
        label_smoothing=0.1,
        reduction="mean",
        softcap=30.0,
        return_z_loss=True,
    )
    # Positional args must line up with the functional keyword args above.
    y2, y2_z = LigerCrossEntropyFunction.apply(
        x2, target, 0, 1e-4, 0.1, "mean", 30.0, True
    )

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
    assert torch.allclose(y1_z, y2_z, atol=atol, rtol=rtol)

    # Use a random upstream gradient rather than the implicit ones().
    grad = torch.randn_like(y2)
    y1.backward(grad)
    y2.backward(grad)

    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
368  
369  
370  #############################################################################
371  # Test the correctness of the liger cross entropy loss
372  #############################################################################
373  
374  
@pytest.mark.parametrize("B, T, V", [(2, 4096, 32000), (3, 423, 32000)])  # llama; weird shapes
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
    """LigerCrossEntropyLoss matches torch CE across shapes/dtypes/reductions."""
    target_ce = LigerCrossEntropyLoss(reduction=reduction)
    _test_correctness_once(target_ce, B, T, V, reduction, scalar, dtype, atol, rtol)
401  
402  
@pytest.mark.parametrize("B, T, V", [(2, 2, 8), (9, 7, 41)])  # includes weird shapes
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 1e-8, 5e-2),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
    """The functional API agrees with LigerCrossEntropyFunction.apply."""
    _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol)
420  
421  
@pytest.mark.parametrize(
    "B, T, V, ignore_index",
    [(2, 4096, 32000, 2), (3, 423, 32000, -123)],  # second case: weird shapes
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_ignore_index(
    B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
):
    """LigerCrossEntropyLoss honors ignore_index like torch CE."""
    target_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    _test_correctness_with_ignore_index_once(
        target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
    )
453  
454  
@pytest.mark.parametrize(
    "B, T, V, label_smoothing",
    [(2, 4096, 32000, 0.1), (3, 423, 32000, 0.1)],  # second case: weird shapes
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_label_smoothing_once(
    B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
    """LigerCrossEntropyLoss matches torch CE when label smoothing is on."""
    target_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
    _test_correctness_with_label_smoothing_once(
        target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
    )
485  
486  
@pytest.mark.parametrize(
    "B, T, V, ignore_index, label_smoothing",
    [(2, 4096, 32000, 1, 0.1), (3, 423, 32000, -300, 0.2)],  # second: weird shapes
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_label_smoothing_with_ignore_index_once(
    B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
    """Label smoothing combined with ignore_index matches torch CE."""
    target_ce = LigerCrossEntropyLoss(
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
    )
    _test_correctness_with_label_smoothing_with_ignore_index_once(
        target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
    )
520  
521  
@pytest.mark.parametrize(
    "B, T, V, softcap",
    [(2, 4096, 32000, 30.0), (3, 423, 32000, 30.0)],  # llama2/mistral; weird shapes
)
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_with_softcap_once(
    B, T, V, softcap, reduction, scalar, dtype, atol, rtol
):
    """Softcapped Liger CE matches torch CE applied to explicitly capped logits."""
    target_ce = LigerCrossEntropyLoss(softcap=softcap, reduction=reduction)
    _test_correctness_with_softcap_once(
        target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol
    )
553  
554  
@pytest.mark.parametrize("B, T, V", [(2, 4096, 32000), (3, 423, 32000)])  # llama2; weird shapes
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
@pytest.mark.parametrize("return_z_loss", [True, False])
@pytest.mark.parametrize("lse_square_scale", [1e-4, 1e-5])  # PaLM; Chameleon
def test_correctness_with_z_loss_once(
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
):
    """Z-loss-augmented Liger CE matches the CrossEntropyWithZLoss reference."""
    target_ce = LigerCrossEntropyLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
    )
    _test_correctness_with_z_loss_once(
        target_ce, B, T, V, scalar, dtype, atol, rtol, lse_square_scale, return_z_loss
    )
613  
614  
@pytest.mark.parametrize("B, T, V", [(2, 4096, 32000), (3, 423, 32000)])  # llama2/mistral; weird shapes
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
@pytest.mark.parametrize(
    "return_z_loss, lse_square_scale",
    [(True, 1e-4), (False, 1e-5)],
)
@pytest.mark.parametrize(
    "label_smoothing, ignore_index, reduction",
    [(0.1, 42, "mean"), (0.2, -42, "sum")],
)
def test_correctness_with_z_loss_with_other_params_once(
    B,
    T,
    V,
    scalar,
    dtype,
    atol,
    rtol,
    lse_square_scale,
    return_z_loss,
    label_smoothing,
    ignore_index,
    reduction,
):
    """Z-loss combined with smoothing/ignore_index/reduction matches reference."""
    target_ce = LigerCrossEntropyLoss(
        lse_square_scale=lse_square_scale,
        return_z_loss=return_z_loss,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
    )
    _test_correctness_with_z_loss_with_other_params_once(
        target_ce,
        B,
        T,
        V,
        scalar,
        dtype,
        atol,
        rtol,
        lse_square_scale,
        return_z_loss,
        label_smoothing,
        ignore_index,
        reduction,
    )
688  
689  
@pytest.mark.parametrize("B, T, V", [(2, 4096, 32000), (3, 423, 32000)])  # llama2/mistral; weird shapes
@pytest.mark.parametrize("reduction", ["sum", "mean"])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
    """Gradients match torch when the CE loss is scaled by a downstream op."""
    target_ce = LigerCrossEntropyLoss(reduction=reduction)
    _test_correctness_not_last_layer_once(
        target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
    )
719  
720  
def test_float32_internal():
    """
    Validate that the kernel's internal softmax math runs in float32 even if
    the input dtype is bfloat16: running the kernel on bf16 inputs and on
    their exact fp32 upcast must produce the same losses and the same
    (downcast) in-place gradients.
    """
    # Set up test parameters
    batch_size = 4
    n_cols = 128256
    n_non_ignore = batch_size
    ignore_index = -100
    label_smoothing = 0.0
    lse_square_scale = 0.0
    softcap = 0.0
    BLOCK_SIZE = 32768
    reduction = "mean"

    # Initialize input tensors
    X_init = torch.randn(batch_size, n_cols, dtype=torch.bfloat16, device=device)
    Y = torch.randint(0, n_cols, (batch_size,), device=device)

    # Run kernel for bfloat16
    X_bf16 = X_init.clone()
    loss_bf16 = torch.zeros(batch_size, dtype=torch.float32, device=device)
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_bf16,
        X_stride=X_bf16.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        z_loss_ptr=loss_bf16,  # dummy ptr, not used
        loss_ptr=loss_bf16,
        loss_stride=loss_bf16.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # Run kernel for float32 (X_init.float() is an exact upcast of the bf16 values)
    X_fp32 = X_init.float()
    loss_fp32 = torch.zeros(batch_size, dtype=torch.float32, device=device)
    liger_cross_entropy_kernel[(batch_size,)](
        X_ptr=X_fp32,
        X_stride=X_fp32.stride(-2),
        Y_ptr=Y,
        Y_stride=Y.stride(-1),
        loss_ptr=loss_fp32,
        z_loss_ptr=loss_fp32,  # dummy ptr, not used
        loss_stride=loss_fp32.stride(-1),
        n_cols=n_cols,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        reduction=reduction,
        softcap=softcap,
        RETURN_Z_LOSS=0,  # False
        HAS_SOFTCAPPING=False,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=32,
    )

    # Fix: torch.allclose returns a bool; without `assert` these comparisons
    # were silently discarded and the test could never fail.
    # The kernel stores gradients in-place into X, so X_bf16 holds the bf16
    # downcast of the fp32 gradient that X_fp32 holds exactly.
    assert torch.allclose(X_bf16, X_fp32.bfloat16())
    assert torch.allclose(loss_bf16, loss_fp32)