# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
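"""Tests for the fused activation kernels in sgl_kernel (silu_and_mul,
gelu_tanh_and_mul, gelu_and_mul), checked against PyTorch reference
implementations in float16 on device 0 (CUDA GPU assumed)."""
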
import pytest
import sgl_kernel
import torch


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_silu_mul(dim, batch_size, seq_len):
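    # silu_and_mul: SiLU on the first half of the last dim, gated by the second half.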
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
    y = sgl_kernel.silu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
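    # gelu_tanh_and_mul: tanh-approximated GELU on the first half, gated by the second half.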
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
    y = sgl_kernel.gelu_tanh_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
def test_fused_gelu_mul(dim, batch_size, seq_len):
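    # gelu_and_mul: exact GELU on the first half, gated by the second half.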
    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
    y = sgl_kernel.gelu_and_mul(x)
    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    pytest.main([__file__])