import itertools
import unittest

import torch

from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.test.test_utils import CustomTestCase

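
# Both test classes compare each layer's default forward (the fused/CUDA code
# path when available) against its pure-PyTorch reference, forward_native,
# across a sweep of token counts, hidden sizes, dtypes, and residual modes.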
class TestRMSNorm(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        # Scale inputs down so values stay small in half precision.
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        # With a residual, the layer returns (normalized output, updated residual);
        # without one, it returns the normalized output tensor directly.
        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_rms_norm_test(*params)
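

# Mirrors TestRMSNorm but exercises GemmaRMSNorm, Gemma's variant that applies
# its learned weight as (1 + weight); note the tighter tolerances (1e-3 vs. 1e-2).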
class TestGemmaRMSNorm(CustomTestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_gemma_rms_norm_test(
        self, num_tokens, hidden_size, add_residual, dtype, seed
    ):
        torch.manual_seed(seed)

        layer = GemmaRMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3))

    def test_gemma_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_gemma_rms_norm_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)