# Source: sglang_v0.5.2/sglang/test/srt/cpu/test_norm.py
# Original listing: 91 lines, 2.8 KiB, Python

import itertools
import unittest
from typing import Optional, Tuple, Union
import sgl_kernel
import torch
from utils import make_non_contiguous, precision
from sglang.test.test_utils import CustomTestCase
# Seed torch's global RNG so the torch.randn inputs generated by the tests
# below are reproducible from run to run.
torch.manual_seed(1234)
class TestNorm(CustomTestCase):
    """Compare the sgl_kernel normalization ops with a pure-PyTorch reference.

    ``test_norm`` runs every combination of ``M`` x ``N`` x ``dtype`` through
    ``rmsnorm_cpu``, ``fused_add_rmsnorm_cpu`` and ``l2norm_cpu`` and checks
    each result against ``_forward_native`` using ``precision[dtype]`` as both
    the absolute and the relative tolerance.
    """

    # Number of rows requested for the generated 2-D inputs.
    M = [4096, 1024]
    # Last-dimension sizes requested from torch.randn. 4096 + 13 is an odd
    # size -- presumably chosen to exercise kernel tail handling (TODO confirm).
    N = [4096, 4096 + 13]
    # Element types under test.
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMS normalization computed in float32.

        Args:
            x: Input tensor; normalized over its last dimension.
            weight: Scale multiplied into the normalized result.
            variance_epsilon: Added to the mean square before ``rsqrt``.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a
            ``(normalized, updated_residual)`` tuple, where the updated
            residual is ``x + residual`` cast back to the input dtype.
        """
        orig_dtype = x.dtype
        # Accumulate in float32 regardless of the input dtype.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The pre-normalization sum is what gets handed back as residual.
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Cast back to the input dtype *before* applying the weight.
        x = x.to(orig_dtype) * weight

        if residual is None:
            return x
        return x, residual

    def _norm_test(self, m: int, n: int, dtype: torch.dtype) -> None:
        """Check ``rmsnorm_cpu`` and ``fused_add_rmsnorm_cpu`` for one shape/dtype."""
        x = torch.randn([m, n], dtype=dtype)
        x = make_non_contiguous(x)
        # Read the hidden size back from x rather than reusing n, since the
        # helper's effect on the shape is not visible here (TODO confirm).
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        # --- plain RMSNorm ---
        out = torch.ops.sgl_kernel.rmsnorm_cpu(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        # assert_close takes (actual, expected): the kernel output is the value
        # under test, and rtol is scaled by the reference.
        torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)

        # --- fused residual-add + RMSNorm ---
        # Clone the inputs first: the op's return value is ignored and x /
        # residual themselves are compared below, i.e. the test expects the
        # op to update both tensors in place.
        ref_x = x.clone()
        residual = torch.randn([m, hidden_size], dtype=dtype)
        ref_residual = residual.clone()

        torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
            x, residual, weight, variance_epsilon
        )
        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )

        torch.testing.assert_close(x, ref_x, atol=atol, rtol=rtol)
        torch.testing.assert_close(residual, ref_residual, atol=atol, rtol=rtol)

    def _l2norm_test(self, m: int, n: int, dtype: torch.dtype) -> None:
        """Check ``l2norm_cpu`` against the reference with an all-ones weight."""
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        # With a weight of ones the reference reduces to
        # x * rsqrt(mean(x^2) + eps), which is what l2norm_cpu is compared to.
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        # (actual, expected) order, matching the other comparisons in this class.
        torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)

    def test_norm(self) -> None:
        """Run both checks for every (m, n, dtype) combination as sub-tests."""
        for m, n, dt in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dt):
                self._norm_test(m, n, dt)
                self._l2norm_test(m, n, dt)
# Run the test cases through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()