test_group_norm.py
import random

import pytest
import torch

from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()

# Randomized shapes; num_channels must be a multiple of num_groups for GroupNorm.
random_batch_size = random.randint(1, 16)
random_num_groups = random.randint(1, 32)
random_num_channels = random_num_groups * random.randint(1, 16)
random_hidden_size = random.randint(1, 8192)


@pytest.mark.parametrize(
    "batch_size, num_channels, num_groups, hidden_size",
    [
        (1, 1, 1, 3),
        (1, 4, 2, 4),
        (16, 12, 3, 4096),
        (random_batch_size, random_num_channels, random_num_groups, random_hidden_size),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-4),
    ],
)
def test_liger_group_norm(
    batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol
):
    torch.manual_seed(0)

    _tensor = torch.randn(
        batch_size, num_channels, hidden_size, dtype=dtype, device=device
    )

    liger_x = _tensor.clone().detach().requires_grad_(True)
    torch_x = _tensor.clone().detach().requires_grad_(True)

    liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).to(device)
    torch_ln = (
        torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6)
        .to(dtype)
        .to(device)
    )

    # Start both modules from identical affine parameters.
    with torch.no_grad():
        torch_ln.weight.copy_(liger_ln.weight)
        torch_ln.bias.copy_(liger_ln.bias)

    liger_output = liger_ln(liger_x)
    torch_output = torch_ln(torch_x)

    assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol)

    # Backpropagate the same upstream gradient through both modules and compare
    # input and parameter gradients.
    grad_output = torch.randn_like(torch_x)
    liger_output.backward(grad_output, retain_graph=True)
    torch_output.backward(grad_output, retain_graph=True)

    assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol)
    assert torch.allclose(
        liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol
    ), "Bias grads different"
    assert torch.allclose(
        liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol
    ), "Weight grads different"