# test/transformers/test_group_norm.py
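"""Correctness tests for LigerGroupNorm against torch.nn.GroupNorm."""
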
import random

import pytest
import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()

# Randomized shapes for the fuzzed case below; num_channels is derived from
# num_groups so the channel count is always divisible by the group count,
# as GroupNorm requires.
random_batch_size = random.randint(1, 16)
random_num_groups = random.randint(1, 32)
random_num_channels = random_num_groups * random.randint(1, 16)
random_hidden_size = random.randint(1, 8192)


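# LigerGroupNorm is intended as a drop-in replacement for torch.nn.GroupNorm,
# so each case instantiates both modules with identical parameters and
# compares forward outputs as well as input, weight, and bias gradients.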
@pytest.mark.parametrize(
    "batch_size, num_channels, num_groups, hidden_size",
    [
        (1, 1, 1, 3),
        (1, 4, 2, 4),
        (16, 12, 3, 4096),
        (random_batch_size, random_num_channels, random_num_groups, random_hidden_size),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-4),
    ],
)
def test_liger_group_norm(
    batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol
):
    torch.manual_seed(0)

    _tensor = torch.randn(
        batch_size, num_channels, hidden_size, dtype=dtype, device=device
    )

    liger_x = _tensor.clone().detach().requires_grad_(True)
    torch_x = _tensor.clone().detach().requires_grad_(True)

    liger_gn = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device)
    torch_gn = (
        torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6)
        .to(dtype)
        .to(device)
    )

    # Copy the Liger affine parameters into the reference module so both
    # normalize with identical weight and bias.
    with torch.no_grad():
        torch_gn.weight.copy_(liger_gn.weight)
        torch_gn.bias.copy_(liger_gn.bias)

    liger_output = liger_gn(liger_x)
    torch_output = torch_gn(torch_x)
    assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)

    # Backpropagate the same upstream gradient through both graphs, then
    # compare input and parameter gradients.
    grad_output = torch.randn_like(torch_x)
    liger_output.backward(grad_output)
    torch_output.backward(grad_output)
    assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)
    assert torch.allclose(
        liger_gn.bias.grad, torch_gn.bias.grad, atol=atol, rtol=rtol
    ), "Bias grads different"
    assert torch.allclose(
        liger_gn.weight.grad, torch_gn.weight.grad, atol=atol, rtol=rtol
    ), "Weight grads different"
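

# For context: group norm reshapes (N, C, L) into (N, G, (C // G) * L) and
# normalizes each group with its own mean and biased variance before applying
# the per-channel affine transform. A minimal sketch of that math, assuming
# the 3D input layout used above (_naive_group_norm is a hypothetical helper
# for illustration, not part of the test suite):
def _naive_group_norm(x, num_groups, weight, bias, eps=1e-6):
    n, c, h = x.shape
    grouped = x.reshape(n, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)
    normed = ((grouped - mean) / torch.sqrt(var + eps)).reshape(n, c, h)
    return normed * weight.view(1, c, 1) + bias.view(1, c, 1)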