62 lines
2.3 KiB
Python
62 lines
2.3 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from flashinfer import autotune, bmm_fp8
|
|
|
|
|
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
|
finfo = torch.finfo(dtype)
|
|
min_val, max_val = x.aminmax()
|
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
|
scale = finfo.max / amax
|
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
|
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
|
|
|
|
|
@pytest.mark.parametrize("b", [1, 16])
@pytest.mark.parametrize("m", [48, 128])
@pytest.mark.parametrize("n", [80, 64])
@pytest.mark.parametrize("k", [64, 256])
@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"])
@pytest.mark.parametrize("auto_tuning", [True, False])
def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning):
    """Compare ``bmm_fp8`` with a bfloat16 ``torch.bmm`` reference.

    Both operands are drawn in bfloat16 on the CUDA device, quantized with
    ``to_float8``, and multiplied by ``bmm_fp8`` into a preallocated output.
    The flattened output must have a cosine similarity above 0.99 with the
    reference product. Parameter combinations the test treats as invalid are
    skipped up front.
    """
    e5m2 = torch.float8_e5m2
    if input_dtype == e5m2 and mat2_dtype == e5m2:
        pytest.skip("Invalid combination: both input and mat2 are e5m2")
    if backend == "cutlass" and e5m2 in (input_dtype, mat2_dtype):
        pytest.skip("Invalid combination: cutlass does not support e5m2")
    if auto_tuning and backend != "cutlass":
        pytest.skip("Invalid combination: auto_tuning only supported for cutlass")

    lhs = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16)
    # mat2 row major -> column major: drawn as [b, n, k], then viewed transposed.
    rhs = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(
        -2, -1
    )
    expected = torch.bmm(lhs, rhs)

    lhs_fp8, lhs_inv_scale = to_float8(lhs, dtype=input_dtype)
    rhs_fp8, rhs_inv_scale = to_float8(rhs, dtype=mat2_dtype)

    out = torch.empty([b, m, n], device="cuda", dtype=res_dtype)
    with autotune(auto_tuning):
        bmm_fp8(
            lhs_fp8,
            rhs_fp8,
            lhs_inv_scale,
            rhs_inv_scale,
            res_dtype,
            out,
            backend=backend,
        )

    similarity = F.cosine_similarity(expected.reshape(-1), out.reshape(-1), dim=0)
    assert similarity > 0.99
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; pass it to the interpreter
    # so running this file directly exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__]))
|