from test.transformers.test_cross_entropy import CrossEntropyWithZLoss
from test.utils import assert_verbose_allclose, set_seed
from typing import Optional

import pytest
import torch

from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.utils import infer_device

device = infer_device()

# set random seed globally
set_seed()


class TorchLMHeadCE(torch.nn.Module):
    """Ground-truth implementation: a linear LM head followed by a torch-based cross entropy loss.

    :param H: hidden size
    :param V: vocab size
    :param dtype: parameter dtype of the linear layer
    :param bias: whether the linear layer has a bias term
    :param ignore_index: target index to ignore when computing the loss
    :param lse_square_scale: scale applied to lse^2 when computing the z-loss
    :param label_smoothing: amount of label smoothing to apply to the target
    :param reduction: reduction method
    :param softcap: optional value for tanh soft-capping of the logits

    # TODO: if we bump the CI env's `transformers` version to >= 4.46, we should just directly
    # call https://github.com/huggingface/transformers/blob/main/src/transformers/loss/loss_utils.py#L32
    # to be consistent with the Hugging Face model implementation.
    """

    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.ce_loss = CrossEntropyWithZLoss(
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
        )
        self.softcap = softcap

    def forward(self, x, y):
        logits = self.lin(x).to(torch.float32)
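        # Optional logit soft-capping (as used in e.g. Gemma 2): a scaled tanh
        # squashes the logits smoothly into (-softcap, softcap) before the loss.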
        if self.softcap is not None and self.softcap != 0.0:
            logits = self.softcap * torch.tanh(logits / self.softcap)
        return self.ce_loss(logits, y)

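# For reference, a minimal sketch of the loss TorchLMHeadCE computes (assuming
# CrossEntropyWithZLoss implements the PaLM-style z-loss; shown here for
# reduction="mean", with the z-term summed instead under reduction="sum"):
#
#   z = torch.logsumexp(logits, dim=-1)
#   ce = torch.nn.functional.cross_entropy(
#       logits, y, ignore_index=ignore_index, label_smoothing=label_smoothing
#   )
#   loss = ce + lse_square_scale * (z[y != ignore_index] ** 2).mean()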

class LigerLMHeadCE(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        bias: bool = False,
        ignore_index: int = -100,
        lse_square_scale: float = 0.0,
        label_smoothing: float = 0.0,
        reduction: str = "mean",
        softcap: Optional[float] = None,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=bias, dtype=dtype
        )
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap,
        )

    def forward(self, x, y):
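        # Note the argument order: the fused loss takes (weight, input, target, bias).
        # It fuses the LM-head matmul with the cross entropy computation so the
        # full (B*T, V) logits tensor never has to be materialized at once.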
        return self.ce_loss(self.lin.weight, x, y, self.lin.bias)


#############################################################################
# Test the correctness of the fused linear cross entropy loss
#############################################################################


@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (4, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "reduction, scalar, dtype, atol, rtol",
    [
        ("mean", 1.0, torch.bfloat16, 5e-3, 5e-2),
        ("mean", 1.0, torch.float32, 1e-5, 5e-4),
        ("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
        ("sum", 1.0, torch.float32, 1e-3, 5e-2),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize(
    "label_smoothing, ignore_index, lse_square_scale, softcap",
    [
        (0, -100, 0, None),
        (
            0.1,
            42,
            1e-4,
            30.0,
        ),  # pass non-default values once to ensure all parameters work together
    ],
)
def test_correctness(
    B,
    T,
    H,
    V,
    scalar,
    dtype,
    bias,
    lse_square_scale,
    label_smoothing,
    ignore_index,
    reduction,
    softcap,
    atol,
    rtol,
):
    torch_lm_head_ce = TorchLMHeadCE(
        H=H,
        V=V,
        bias=bias,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        softcap=softcap,
        dtype=dtype,
    ).to(device)
    liger_lm_head_ce = LigerLMHeadCE(
        H=H,
        V=V,
        bias=bias,
        lse_square_scale=lse_square_scale,
        label_smoothing=label_smoothing,
        ignore_index=ignore_index,
        reduction=reduction,
        softcap=softcap,
        dtype=dtype,
    ).to(device)

    # initialize the linear layer in both models with the same weights
    torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(
        V, H, device=device, dtype=dtype
    )

    if bias:
        torch_lm_head_ce.lin.bias.data = liger_lm_head_ce.lin.bias.data = torch.rand(
            V, device=device, dtype=dtype
        )

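    # Two independent leaf tensors over the same data, so each implementation
    # accumulates gradients on its own autograd graph.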
    _tensor = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
    _input1 = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
    # Assign some random number of elements as ignore_index
    num_elements_to_assign = torch.randint(
        1, B * T // 2, (1,)
    ).item()  # Random number of elements to set to ignore_index
    indices_to_assign = torch.randperm(B * T)[
        :num_elements_to_assign
    ]  # Randomly select indices
    target[indices_to_assign] = ignore_index
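    # The masked positions must contribute nothing to the loss, so this also
    # exercises ignore_index handling in both implementations.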

    output1 = torch_lm_head_ce(_input1, target)
    output2 = liger_lm_head_ce(_input2, target)

    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

    output1.backward()
    output2.backward()

    assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

    assert_verbose_allclose(
        torch_lm_head_ce.lin.weight.grad,
        liger_lm_head_ce.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )

    if bias:
        assert_verbose_allclose(
            torch_lm_head_ce.lin.bias.grad,
            liger_lm_head_ce.lin.bias.grad,
            atol=atol,
            rtol=rtol,
        )


@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (2, 2, 8, 8),
        # weird shapes
        (9, 7, 41, 41),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        (1.0, torch.bfloat16, 5e-3, 5e-2),
        (1.0, torch.float32, 1e-5, 5e-4),
    ],
)
@pytest.mark.parametrize("bias", [True, False])
def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
    _input = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
    x1 = _input.detach().clone().requires_grad_(True)
    x2 = _input.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

    weight = torch.randn(V, H, device=device, dtype=dtype)
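    # Note: the boolean `bias` parameter is shadowed below by the optional bias
    # tensor that the functional API expects.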
    bias = torch.randn(V, device=device, dtype=dtype) if bias else None

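    # Exercise both public entry points: the functional wrapper and the raw
    # autograd Function. They should produce identical outputs and gradients.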
    y1 = liger_fused_linear_cross_entropy(
        input=x1,
        weight=weight,
        target=target,
        bias=bias,
    )
    y2 = LigerFusedLinearCrossEntropyFunction.apply(x2, weight, target, bias)

    assert torch.allclose(y1, y2, atol=atol, rtol=rtol)

    grad_output = torch.randn_like(y1)

    y1.backward(grad_output)
    y2.backward(grad_output)

    assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
    "B, T, H, V",
    [
        (8, 128, 1024, 4096),
        (4, 47, 31, 123),  # random shape
    ],
)
@pytest.mark.parametrize(
    "cast_dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-3, 5e-2),
        (torch.float16, 5e-3, 5e-2),
    ],
)
def test_amp(B, T, H, V, cast_dtype, atol, rtol):
    dtype = torch.float32
    torch_lm_head_ce = TorchLMHeadCE(
        H=H,
        V=V,
        bias=True,
        label_smoothing=0.0,
        reduction="mean",
        dtype=dtype,
    ).to(device)
    liger_lm_head_ce = LigerLMHeadCE(
        H=H,
        V=V,
        bias=True,
        label_smoothing=0.0,
        reduction="mean",
        dtype=dtype,
    ).to(device)

    # initialize the linear layer in both models with the same weights
    torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(
        V, H, device=device, dtype=dtype
    )

    _tensor = torch.randn(B * T, H, device=device, dtype=dtype)
    _input1 = _tensor.detach().clone().requires_grad_(True)
    _input2 = _tensor.detach().clone().requires_grad_(True)

    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

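    # Run both implementations under autocast so the forward matmuls execute in
    # cast_dtype while the fp32 master weights are preserved.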
    with torch.autocast(device_type=device, dtype=cast_dtype):
        output1 = torch_lm_head_ce(_input1, target)
        output2 = liger_lm_head_ce(_input2, target)

    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

    with torch.autocast(device_type=device, dtype=cast_dtype):
        output1.backward()
        output2.backward()

    assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

    assert_verbose_allclose(
        torch_lm_head_ce.lin.weight.grad,
        liger_lm_head_ce.lin.weight.grad,
        atol=atol,
        rtol=rtol,
    )
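

# A typical way to run just this file (illustrative invocation; assumes pytest
# is available and that you run from the repository root):
#   pytest test/transformers/test_fused_linear_cross_entropy.py -v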