sglang_v0.5.2/flashinfer_0.3.1/tests/test_norm.py

203 lines
6.7 KiB
Python

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
import flashinfer
from flashinfer.utils import get_compute_capability
def llama_rms_norm(x, w, eps=1e-6):
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * w.float()
x = x.to(orig_dtype)
return x
def gemma_rms_norm(x, w, eps=1e-6):
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w.float())
x = x.to(orig_dtype)
return x
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
orig_dtype = x.dtype
x = x + residual
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w.float())
x = x.to(orig_dtype)
return x, residual
def fused_add_rms_norm(x, residual, weight, eps):
orig_dtype = x.dtype
x = x.to(torch.float32)
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = (x * weight.float()).to(orig_dtype)
return x, residual
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous):
if contiguous:
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
else:
x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
x = x[:, :hidden_size]
major, _ = get_compute_capability(x.device)
if major < 9 and enable_pdl:
pytest.skip("PDL is only available for Hopper and later GPUs")
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = llama_rms_norm(x, w)
if specify_out:
y = torch.empty_like(x)
flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguous):
eps = 1e-6
if contiguous:
x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
else:
x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
x = x[:, :hidden_size]
major, _ = get_compute_capability(x.device)
if major < 9 and enable_pdl:
pytest.skip("PDL is only available for Hopper and later GPUs")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
x_native, residual_native = fused_add_rms_norm(
x.clone(), residual.clone(), weight, eps
)
x_fused = x.clone()
residual_fused = residual.clone()
flashinfer.fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_gemma_norm(
batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous
):
if contiguous:
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
else:
x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
x = x[:, :hidden_size]
major, _ = get_compute_capability(x.device)
if major < 9 and enable_pdl:
pytest.skip("PDL is only available for Hopper and later GPUs")
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = gemma_rms_norm(x, w)
if specify_out:
y = torch.empty_like(x)
flashinfer.norm.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = flashinfer.norm.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_gemma_fused_add_rmsnorm(
batch_size, hidden_size, dtype, enable_pdl, contiguous
):
eps = 1e-6
if contiguous:
x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
else:
x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
x = x[:, :hidden_size]
major, _ = get_compute_capability(x.device)
if major < 9 and enable_pdl:
pytest.skip("PDL is only available for Hopper and later GPUs")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
x_native, residual_native = gemma_fused_add_rms_norm(
x.clone(), residual.clone(), weight, eps
)
x_fused = x.clone()
residual_fused = residual.clone()
flashinfer.gemma_fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
# test_norm(1, 1024, torch.float16, False, True)
test_fused_add_rmsnorm(1, 16384, torch.float16, True, True)