import math

import pytest
import torch
from conftest import VARLEN_INDPTR_PARAMS

import flashinfer
from flashinfer.utils import is_sm100a_supported, is_sm110a_supported


def attention_ref(
    batch_size: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]
    # k and v are expected to already be expanded to num_qo_heads heads
    # (callers repeat_interleave the KV heads for GQA).
    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )

    if causal:
        # Bottom-right aligned causal mask: query i sits at absolute position
        # kv_len - qo_len + i and may attend to keys at positions <= its own.
        mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
            1
        ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
    else:
        mask = torch.ones(qo_len, kv_len, device=q.device)

    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    # Scale by log2(e) to convert the natural-log LSE to base 2, matching the
    # kernel's LSE convention.
    return o_ref, lse_ref * math.log2(math.e)


def attention_varlen_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_indptr: torch.Tensor,
    kv_indptr: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = qo_indptr.shape[0] - 1
    nnz_qo = qo_indptr[-1].item()
    o = torch.empty(nnz_qo, *q.shape[1:-1], v.shape[-1], device=q.device, dtype=q.dtype)
    lse = torch.empty(nnz_qo, q.shape[1], device=q.device, dtype=torch.float32)

    # Run the fixed-length reference on each request's slice of the ragged batch.
    for i in range(batch_size):
        o_i, lse_i = attention_ref(
            1,
            q[qo_indptr[i] : qo_indptr[i + 1]],
            k[kv_indptr[i] : kv_indptr[i + 1]],
            v[kv_indptr[i] : kv_indptr[i + 1]],
            causal,
            sm_scale,
        )

        lse_i = lse_i.flatten(0, 1)
        o[qo_indptr[i] : qo_indptr[i + 1]] = o_i
        lse[qo_indptr[i] : qo_indptr[i + 1]] = lse_i

    return o, lse


@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 17])
|
|
@pytest.mark.parametrize("qo_len", [1, 17, 177, 377, 977])
|
|
@pytest.mark.parametrize("kv_len", [1, 17, 544, 977, 1999])
|
|
@pytest.mark.parametrize("num_qo_heads", [32])
|
|
@pytest.mark.parametrize("num_kv_heads", [8, 32])
|
|
@pytest.mark.parametrize("head_dim_qk", [192, 128])
|
|
@pytest.mark.parametrize("head_dim_vo", [128])
|
|
@pytest.mark.parametrize("sm_scale", [1.0, 1.0 / math.sqrt(192), 1.0 / math.sqrt(128)])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
def test_blackwell_cutlass_fmha(
|
|
batch_size,
|
|
qo_len,
|
|
kv_len,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo,
|
|
sm_scale,
|
|
causal,
|
|
dtype,
|
|
):
|
|
if qo_len > kv_len and causal:
|
|
pytest.skip("qo_len > kv_len and causal is not supported")
|
|
|
|
if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported(
|
|
torch.device("cuda")
|
|
):
|
|
pytest.skip("only SM100A and SM110A are supported on this device")
|
|
torch.manual_seed(42)
|
|
q = torch.randn(
|
|
batch_size * qo_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda"
|
|
)
|
|
qo_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
|
|
)
|
|
k = torch.randn(
|
|
batch_size * kv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda"
|
|
)
|
|
v = torch.randn(
|
|
batch_size * kv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda"
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
|
|
)
|
|
|
|
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
|
|
torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
|
|
kv_layout="NHD",
|
|
backend="cutlass",
|
|
)
|
|
wrapper.plan(
|
|
qo_indptr,
|
|
kv_indptr,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo=head_dim_vo,
|
|
causal=causal,
|
|
sm_scale=sm_scale,
|
|
q_data_type=dtype,
|
|
kv_data_type=dtype,
|
|
)
|
|
o, lse = wrapper.run(q, k, v, return_lse=True)
|
|
|
|
gqa_group_ratio = num_qo_heads // num_kv_heads
|
|
k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
|
|
v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
|
|
o_ref, lse_ref = attention_ref(
|
|
batch_size, q, k_repeated, v_repeated, causal, sm_scale
|
|
)
|
|
|
|
lse_ref = lse_ref.flatten(0, 1)
|
|
if dtype == torch.half:
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
|
|
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("indptr", VARLEN_INDPTR_PARAMS)
|
|
@pytest.mark.parametrize("num_qo_heads", [32])
|
|
@pytest.mark.parametrize("num_kv_heads", [8, 32])
|
|
@pytest.mark.parametrize("head_dim_qk", [192, 128])
|
|
@pytest.mark.parametrize("head_dim_vo", [128])
|
|
@pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
|
def test_blackwell_cutlass_varlen(
|
|
indptr,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo,
|
|
sm_scale,
|
|
causal,
|
|
dtype,
|
|
):
|
|
if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported(
|
|
torch.device("cuda")
|
|
):
|
|
pytest.skip("only SM100A and SM110A are supported on this device")
|
|
torch.manual_seed(42)
|
|
qkv = torch.randn(
|
|
indptr[-1],
|
|
(
|
|
num_qo_heads * head_dim_qk
|
|
+ num_kv_heads * head_dim_qk
|
|
+ num_kv_heads * head_dim_vo
|
|
),
|
|
dtype=dtype,
|
|
device="cuda",
|
|
)
|
|
q = qkv[:, : num_qo_heads * head_dim_qk].view(indptr[-1], num_qo_heads, head_dim_qk)
|
|
k = qkv[
|
|
:,
|
|
num_qo_heads * head_dim_qk : num_qo_heads * head_dim_qk
|
|
+ num_kv_heads * head_dim_qk,
|
|
].view(indptr[-1], num_kv_heads, head_dim_qk)
|
|
v = qkv[:, num_qo_heads * head_dim_qk + num_kv_heads * head_dim_qk :].view(
|
|
indptr[-1], num_kv_heads, head_dim_vo
|
|
)
|
|
qo_indptr = torch.tensor(indptr, device="cuda", dtype=torch.int32)
|
|
kv_indptr = qo_indptr
|
|
|
|
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
|
|
torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
|
|
kv_layout="NHD",
|
|
backend="cutlass",
|
|
)
|
|
|
|
wrapper.plan(
|
|
qo_indptr,
|
|
kv_indptr,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo=head_dim_vo,
|
|
causal=causal,
|
|
sm_scale=sm_scale,
|
|
q_data_type=dtype,
|
|
kv_data_type=dtype,
|
|
)
|
|
o, lse = wrapper.run(q, k, v, return_lse=True)
|
|
|
|
gqa_group_ratio = num_qo_heads // num_kv_heads
|
|
k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
|
|
v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
|
|
|
|
o_ref, lse_ref = attention_varlen_ref(
|
|
q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale
|
|
)
|
|
|
|
if dtype == torch.half:
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-2, atol=1e-2)
|
|
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("qo_indptr_list", [[0, 10, 20, 30, 40, 50, 60, 100]])
|
|
@pytest.mark.parametrize("kv_indptr_list", [[0, 50, 50, 50, 50, 50, 50, 50]])
|
|
@pytest.mark.parametrize("num_qo_heads", [32])
|
|
@pytest.mark.parametrize("num_kv_heads", [8, 32])
|
|
@pytest.mark.parametrize("head_dim_qk", [192, 128])
|
|
@pytest.mark.parametrize("head_dim_vo", [128])
|
|
@pytest.mark.parametrize("sm_scale", [1.0 / math.sqrt(128)])
|
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
|
def test_blackwell_cutlass_qo_kv_varlen(
|
|
qo_indptr_list,
|
|
kv_indptr_list,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo,
|
|
sm_scale,
|
|
dtype,
|
|
):
|
|
causal = False
|
|
if not is_sm100a_supported(torch.device("cuda")) and not is_sm110a_supported(
|
|
torch.device("cuda")
|
|
):
|
|
pytest.skip("only SM100A and SM110A are supported on this device")
|
|
torch.manual_seed(42)
|
|
q = torch.randn(
|
|
qo_indptr_list[-1],
|
|
num_qo_heads,
|
|
head_dim_qk,
|
|
dtype=dtype,
|
|
device="cuda",
|
|
)
|
|
k = torch.randn(
|
|
kv_indptr_list[-1],
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
dtype=dtype,
|
|
device="cuda",
|
|
)
|
|
v = torch.randn(
|
|
kv_indptr_list[-1],
|
|
num_kv_heads,
|
|
head_dim_vo,
|
|
dtype=dtype,
|
|
device="cuda",
|
|
)
|
|
|
|
qo_indptr = torch.tensor(qo_indptr_list, device="cuda", dtype=torch.int32)
|
|
kv_indptr = torch.tensor(kv_indptr_list, device="cuda", dtype=torch.int32)
|
|
|
|
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
|
|
torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
|
|
kv_layout="NHD",
|
|
backend="cutlass",
|
|
)
|
|
|
|
wrapper.plan(
|
|
qo_indptr,
|
|
kv_indptr,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim_qk,
|
|
head_dim_vo=head_dim_vo,
|
|
causal=causal,
|
|
sm_scale=sm_scale,
|
|
q_data_type=dtype,
|
|
kv_data_type=dtype,
|
|
)
|
|
o, lse = wrapper.run(q, k, v, return_lse=True)
|
|
|
|
gqa_group_ratio = num_qo_heads // num_kv_heads
|
|
k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
|
|
v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
|
|
|
|
o_ref, lse_ref = attention_varlen_ref(
|
|
q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale
|
|
)
|
|
|
|
if dtype == torch.half:
|
|
torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-3, atol=1e-3)
|
|
else:
|
|
torch.testing.assert_close(o[10:60], o_ref[10:60], rtol=1e-2, atol=1e-2)
|
|
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
if __name__ == "__main__":
    # Ad-hoc invocations for quick manual smoke testing.
    test_blackwell_cutlass_fmha(
        9,
        377,
        977,
        1,
        1,
        192,
        128,
        1,
        False,
        torch.bfloat16,
    )

    test_blackwell_cutlass_varlen(
        [0, 1274, 2568, 3915, 5194, 6498, 7839, 8192],
        32,
        4,
        128,
        128,
        1,
        True,
        torch.bfloat16,
    )

    test_blackwell_cutlass_qo_kv_varlen(
        [0, 10, 20, 30, 40, 50, 60, 100],
        [0, 50, 50, 50, 50, 50, 50, 50],
        32,
        8,
        128,
        128,
        1,
        torch.bfloat16,
    )