sglang0.4.5.post1/python/sglang/test/test_custom_ops.py

# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
import pytest
import torch
from sglang.srt.custom_op import scaled_fp8_quant
from sglang.srt.utils import is_cuda


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scaled_fp8_quant_per_tensor(dtype) -> None:
    def quantize_ref_per_tensor(tensor, inv_scale):
        # The reference implementation that fully aligns to
        # the kernel being tested.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = inv_scale.reciprocal()
        qweight = (tensor.to(torch.float32) * scale).clamp(
            min=finfo.min, max=finfo.max
        )
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight

    def dequantize_per_tensor(tensor, inv_scale, dtype):
        fake_qweight = tensor.to(dtype)
        dq_weight = fake_qweight * inv_scale
        return dq_weight

    # Note that we use a shape % 8 != 0 to cover edge cases,
    # because scaled_fp8_quant is vectorized by 8.
    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
    # Test per-tensor dynamic quantization:
    # scale = max(abs(x)) / FP8_E4M3_MAX
    y, scale = scaled_fp8_quant(x, None)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )
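    # Sanity-check the dynamically computed scale against the formula above.
    # (Assumption: the kernel computes scale = amax / FP8_E4M3_MAX; the loose
    # tolerance allows for reductions performed at lower precision.)
    expected_scale = x.to(torch.float32).abs().max() / torch.finfo(
        torch.float8_e4m3fn
    ).max
    torch.testing.assert_close(
        scale.float().squeeze(), expected_scale, rtol=1e-2, atol=1e-3
    )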
    # Test per-tensor static quantization, reusing the scale from above.
    y, _ = scaled_fp8_quant(x, scale)
    ref_y = quantize_ref_per_tensor(x, scale)
    torch.testing.assert_close(y, ref_y)
    torch.testing.assert_close(
        dequantize_per_tensor(y, scale, dtype),
        dequantize_per_tensor(ref_y, scale, dtype),
    )
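    # Output dtype sanity check. (Assumption: scaled_fp8_quant always emits
    # float8_e4m3fn, matching the reference implementation above.)
    assert y.dtype == torch.float8_e4m3fn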


if is_cuda():

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
        def quantize_ref_per_token(tensor, inv_scale):
            # The reference implementation that fully aligns to
            # the kernel being tested.
            finfo = torch.finfo(torch.float8_e4m3fn)
            scale = inv_scale.reciprocal()
            qweight = (tensor.to(torch.float32) * scale).clamp(
                min=finfo.min, max=finfo.max
            )
            qweight = qweight.to(torch.float8_e4m3fn)
            return qweight

        def dequantize_per_token(tensor, inv_scale, dtype):
            fake_qweight = tensor.to(dtype)
            dq_weight = fake_qweight * inv_scale
            return dq_weight
        # Note that the row length is a multiple of 8 (16 % 8 == 0),
        # because per_token_quant_fp8 is vectorized by 8 elements.
        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)
        # Test per-token dynamic quantization:
        # scale[i] = max(abs(x[i, :])) / FP8_E4M3_MAX for each row (token)
        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
        ref_y = quantize_ref_per_token(x, scale)
        torch.testing.assert_close(y, ref_y)
        torch.testing.assert_close(
            dequantize_per_token(y, scale, dtype),
            dequantize_per_token(ref_y, scale, dtype),
        )
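        # Shape sanity check. (Assumption: the per-token path returns one
        # scale per row/token, as in vLLM's per-token implementation.)
        assert scale.numel() == x.shape[0]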


if __name__ == "__main__":
    # Run all tests in this file via pytest.
    pytest.main([__file__])