"""
|
|
Copyright (c) 2023 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import math

import pytest
import torch
from conftest import clear_cuda_cache

import flashinfer
from flashinfer.jit import build_jit_specs
from flashinfer.jit.attention import (
    gen_batch_mla_module,
    gen_batch_prefill_module,
    gen_single_prefill_module,
)
from flashinfer.utils import (
    is_sm90a_supported,
    is_sm100a_supported,
    is_sm110a_supported,
)
@pytest.fixture(autouse=True, scope="module")
|
|
def warmup_jit():
|
|
try:
|
|
modules = []
|
|
for backend in ["fa2", "fa3"]:
|
|
if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
|
|
continue
|
|
|
|
modules.append(
|
|
gen_single_prefill_module(
|
|
backend,
|
|
torch.float16,
|
|
torch.float16,
|
|
torch.float16,
|
|
192,
|
|
128,
|
|
0,
|
|
False,
|
|
False,
|
|
False,
|
|
)
|
|
)
|
|
|
|
for backend in ["fa2", "fa3"]:
|
|
if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
|
|
continue
|
|
|
|
modules.append(
|
|
gen_batch_prefill_module(
|
|
backend,
|
|
torch.float16,
|
|
torch.float16,
|
|
torch.float16,
|
|
torch.int32,
|
|
192,
|
|
128,
|
|
0,
|
|
False,
|
|
False,
|
|
False,
|
|
)
|
|
)
|
|
|
|
for backend in ["fa2", "fa3"]:
|
|
if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")):
|
|
continue
|
|
|
|
modules.append(
|
|
gen_batch_mla_module(
|
|
backend,
|
|
torch.float16,
|
|
torch.float16,
|
|
torch.float16,
|
|
torch.int32,
|
|
512,
|
|
64,
|
|
False,
|
|
)
|
|
)
|
|
|
|
build_jit_specs(modules, verbose=False)
|
|
except Exception as e:
|
|
# abort the test session if warmup fails
|
|
pytest.exit(str(e))
|
|
finally:
|
|
yield
|
|
|
|
|
|
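

# The module-scoped autouse fixture above pre-builds the JIT kernel modules these
# tests rely on (single/batch prefill with head_dim_qk=192, head_dim_vo=128 and batch
# MLA with head_dim_ckv=512, head_dim_kpe=64), so the one-time compilation cost is
# paid up front rather than inside the first parametrized test that needs each kernel.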


def attention_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]
    logits = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )

    if causal:
        mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
            1
        ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
    else:
        mask = torch.ones(qo_len, kv_len, device=q.device)

    logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))
    lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2)
    p = torch.softmax(logits, dim=-1)
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return o_ref, lse_ref * math.log2(math.e)
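

# Note on conventions used by attention_ref: the causal mask is bottom-right aligned
# (query row i may attend to keys up to index kv_len - qo_len + i), matching prefill
# on top of an existing KV cache, and the returned LSE is rescaled by log2(e) because
# the kernels exercised in this file report log-sum-exp in base 2. A minimal usage
# sketch (illustrative only, mirroring test_single_prefill_with_kv_cache below):
#
#     q = torch.randn(8, 4, 192, dtype=torch.half, device="cuda")
#     k = torch.randn(16, 4, 192, dtype=torch.half, device="cuda")
#     v = torch.randn(16, 4, 128, dtype=torch.half, device="cuda")
#     o, lse = flashinfer.single_prefill_with_kv_cache(q, k, v, return_lse=True)
#     o_ref, lse_ref = attention_ref(1, q, k, v, False, 1.0 / 192**0.5)
#     torch.testing.assert_close(lse, lse_ref.squeeze(0), rtol=1e-3, atol=1e-3)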
@pytest.mark.parametrize("kv_len", [5532, 7563])
|
|
@pytest.mark.parametrize("qo_len", [1832, 3928])
|
|
@pytest.mark.parametrize("num_heads", [4, 32, 128])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
|
|
@pytest.mark.parametrize("dtype", [torch.half])
|
|
def test_single_prefill_with_kv_cache(
|
|
kv_len,
|
|
qo_len,
|
|
num_heads,
|
|
causal,
|
|
backend,
|
|
dtype,
|
|
):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if backend == "fa3" and not is_sm90a_supported(device):
|
|
pytest.skip("FA3 is not supported on this device")
|
|
if is_sm110a_supported(device) and num_heads * kv_len > 700000:
|
|
pytest.skip("skip large tests on Thor due to memory limit")
|
|
torch.manual_seed(42)
|
|
head_dim_qk = 192
|
|
head_dim_vo = 128
|
|
q = torch.randn(qo_len, num_heads, head_dim_qk, dtype=dtype, device=device)
|
|
k = torch.randn(kv_len, num_heads, head_dim_qk, dtype=dtype, device=device)
|
|
v = torch.randn(kv_len, num_heads, head_dim_vo, dtype=dtype, device=device)
|
|
o, lse = flashinfer.single_prefill_with_kv_cache(
|
|
q, k, v, causal=causal, backend=backend, return_lse=True
|
|
)
|
|
sm_scale = 1.0 / (head_dim_qk**0.5)
|
|
|
|
o_ref, lse_ref = attention_ref(1, q, k, v, causal, sm_scale)
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
torch.testing.assert_close(lse, lse_ref.squeeze(0), rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [12, 17])
|
|
@pytest.mark.parametrize("kv_len", [544, 977])
|
|
@pytest.mark.parametrize("qo_len", [377, 177])
|
|
@pytest.mark.parametrize("num_heads", [4, 32, 128])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
|
|
@pytest.mark.parametrize("dtype", [torch.half])
|
|
def test_batch_prefill_with_ragged_kv_cache(
|
|
batch_size,
|
|
kv_len,
|
|
qo_len,
|
|
num_heads,
|
|
causal,
|
|
backend,
|
|
dtype,
|
|
):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if backend == "fa3" and not is_sm90a_supported(device):
|
|
pytest.skip("FA3 is not supported on this device")
|
|
torch.manual_seed(42)
|
|
kv_layout = "NHD"
|
|
head_dim_qk = 192
|
|
head_dim_vo = 128
|
|
q = torch.randn(
|
|
batch_size * qo_len, num_heads, head_dim_qk, dtype=dtype, device=device
|
|
)
|
|
q_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len
|
|
)
|
|
|
|
k = torch.zeros(
|
|
batch_size * kv_len, num_heads, head_dim_qk, dtype=dtype, device=device
|
|
)
|
|
v = torch.zeros(
|
|
batch_size * kv_len, num_heads, head_dim_vo, dtype=dtype, device=device
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * kv_len
|
|
)
|
|
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
|
|
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
|
|
workspace_buffer, kv_layout, backend=backend
|
|
)
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
num_heads,
|
|
num_heads,
|
|
head_dim_qk,
|
|
head_dim_vo=head_dim_vo,
|
|
causal=causal,
|
|
)
|
|
o, lse = wrapper.run_return_lse(q, k, v)
|
|
|
|
sm_scale = 1.0 / (head_dim_qk**0.5)
|
|
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
|
|
|
|
lse_ref = lse_ref.flatten(0, 1)
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
# test with pre-allocated output
|
|
o_buffer = torch.empty_like(o)
|
|
lse_buffer = torch.empty_like(lse)
|
|
wrapper.run(q, k, v, out=o_buffer, lse=lse_buffer)
|
|
torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
|
|
torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
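

# In the ragged KV layout used above, queries and keys/values for all requests are
# packed along the first dimension and addressed CSR-style: rows
# q_indptr[i]:q_indptr[i + 1] of q belong to request i, and kv_indptr plays the same
# role for k and v. For example (illustrative values), batch_size=2 and qo_len=3 give
# q_indptr = [0, 3, 6].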


def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads):
    bs_page_num, page_size, ckv_dim = ckv.shape
    page_num = bs_page_num // batch_size
    _, _, kpe_dim = kpe.shape
    ckv = ckv.view(batch_size, page_num * page_size, ckv_dim)
    kpe = kpe.view(batch_size, page_num * page_size, kpe_dim)
    ckv = ckv[:, :kv_len, :]
    kpe = kpe[:, :kv_len, :]
    k = (
        torch.cat([ckv, kpe], dim=-1)
        .view(-1, 1, ckv_dim + kpe_dim)
        .repeat_interleave(num_heads, dim=1)
    )
    v = ckv.repeat_interleave(num_heads, dim=1)

    return k, v
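

# generate_kv_from_cache expands the compressed MLA cache into explicit per-head K/V
# for the reference check: every attention head shares the same compressed KV (ckv,
# head_dim_ckv wide) and positional key part (kpe, head_dim_kpe wide), so K is
# cat([ckv, kpe], dim=-1) repeated across heads and V is just ckv repeated across
# heads. A small sketch (illustrative values only; page_size=1, one request of
# length 2):
#
#     ckv = torch.randn(2, 1, 512)
#     kpe = torch.randn(2, 1, 64)
#     k, v = generate_kv_from_cache(ckv, kpe, kv_len=2, batch_size=1, num_heads=16)
#     # k[:, h, :512] matches ckv.view(2, 512) and k[:, h, 512:] matches
#     # kpe.view(2, 64) for every head index h; v carries only the ckv part.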
@pytest.mark.parametrize("batch_size", [1, 3, 5, 7])
|
|
@pytest.mark.parametrize("kv_len_0", [0, 1, 3, 11])
|
|
@pytest.mark.parametrize("kv_len_1", [17, 33, 79, 114])
|
|
@pytest.mark.parametrize("kv_len_2", [514, 2743, 8736])
|
|
@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
|
|
@pytest.mark.parametrize("num_heads", [16, 64])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("page_size", [1])
|
|
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
|
|
@pytest.mark.parametrize("dtype", [torch.half])
|
|
def test_batch_mla_varlen_page_attention(
|
|
batch_size,
|
|
kv_len_0,
|
|
kv_len_1,
|
|
kv_len_2,
|
|
qo_len,
|
|
num_heads,
|
|
causal,
|
|
page_size,
|
|
backend,
|
|
dtype,
|
|
):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if backend == "fa3" and not is_sm90a_supported(device):
|
|
pytest.skip("FA3 is not supported on this device")
|
|
if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2):
|
|
pytest.skip("qo_len > kv_len not supported for causal attention")
|
|
num_different_kv_len = 3
|
|
kv_lens = torch.tensor([kv_len_0, kv_len_1, kv_len_2], dtype=torch.int32)
|
|
torch.manual_seed(42)
|
|
head_dim_ckv = 512
|
|
head_dim_kpe = 64
|
|
q_nope = torch.randn(
|
|
num_different_kv_len * batch_size * qo_len,
|
|
num_heads,
|
|
head_dim_ckv,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
q_pe = torch.randn(
|
|
num_different_kv_len * batch_size * qo_len,
|
|
num_heads,
|
|
head_dim_kpe,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
pages_nums = torch.tensor(
|
|
[math.ceil(kv_len / page_size) for kv_len in kv_lens],
|
|
dtype=torch.int32,
|
|
)
|
|
pages_nums_indptr = torch.zeros(num_different_kv_len + 1, dtype=torch.int32)
|
|
pages_nums_indptr[1:] = pages_nums.cumsum(0)
|
|
pages_nums_sum = pages_nums_indptr[-1]
|
|
ckv = torch.randn(
|
|
batch_size * pages_nums_sum,
|
|
page_size,
|
|
head_dim_ckv,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
kpe = torch.randn(
|
|
batch_size * pages_nums_sum,
|
|
page_size,
|
|
head_dim_kpe,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
sm_scale = 1.0 / ((128 + 64) ** 0.5)
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
|
|
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
|
workspace_buffer, backend=backend
|
|
)
|
|
q_indptr = (
|
|
torch.arange(
|
|
0, num_different_kv_len * batch_size + 1, device=device, dtype=torch.int32
|
|
)
|
|
* qo_len
|
|
)
|
|
kv_indptr = torch.cat(
|
|
[
|
|
torch.arange(0, batch_size + 1).unsqueeze(-1).int() * pages_nums_sum
|
|
+ pages_nums_indptr[i]
|
|
for i in range(num_different_kv_len)
|
|
],
|
|
dim=-1,
|
|
).flatten()
|
|
kv_indices = torch.arange(
|
|
0, batch_size * pages_nums_sum, device=device, dtype=torch.int32
|
|
)
|
|
kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device).repeat(batch_size)
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_lens,
|
|
num_heads,
|
|
head_dim_ckv,
|
|
head_dim_kpe,
|
|
page_size,
|
|
causal,
|
|
sm_scale,
|
|
q_nope.dtype,
|
|
ckv.dtype,
|
|
)
|
|
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
|
|
|
|
q_rows = (
|
|
torch.arange(0, num_different_kv_len * qo_len)[None, :]
|
|
+ torch.arange(0, batch_size)[:, None] * num_different_kv_len * qo_len
|
|
).int()
|
|
kv_rows = (
|
|
torch.arange(0, pages_nums_sum)[None, :]
|
|
+ torch.arange(0, batch_size)[:, None] * pages_nums_sum
|
|
).int()
|
|
q_rows_arr = [
|
|
q_rows[:, i * qo_len : (i + 1) * qo_len].flatten()
|
|
for i in range(num_different_kv_len)
|
|
]
|
|
kv_rows_arr = [
|
|
kv_rows[:, pages_nums_indptr[i] : pages_nums_indptr[i + 1]].flatten()
|
|
for i in range(num_different_kv_len)
|
|
]
|
|
for i in range(num_different_kv_len):
|
|
k, v = generate_kv_from_cache(
|
|
ckv[kv_rows_arr[i]], kpe[kv_rows_arr[i]], kv_lens[i], batch_size, num_heads
|
|
)
|
|
q = torch.cat([q_nope, q_pe], dim=-1)[q_rows_arr[i]]
|
|
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
|
|
lse_ref = lse_ref.flatten(0, 1)
|
|
o_i = o[q_rows_arr[i]]
|
|
torch.testing.assert_close(o_i, o_ref, rtol=1e-3, atol=1e-3)
|
|
# if kv_lens[i] != 0:
|
|
# torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157])
|
|
@pytest.mark.parametrize("kv_len", [17, 33, 75, 197])
|
|
@pytest.mark.parametrize("qo_len", [3, 7, 17])
|
|
@pytest.mark.parametrize("num_heads", [16])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("page_size", [16, 32])
|
|
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
|
|
@pytest.mark.parametrize("dtype", [torch.half])
|
|
def test_batch_mla_oob_kv_nan(
|
|
batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype
|
|
):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if backend == "fa3" and not is_sm90a_supported(device):
|
|
pytest.skip("FA3 is not supported on this device")
|
|
if causal and qo_len > kv_len:
|
|
pytest.skip("qo_len > kv_len not supported for causal attention")
|
|
torch.manual_seed(42)
|
|
head_dim_ckv = 512
|
|
head_dim_kpe = 64
|
|
q_nope = torch.randn(
|
|
batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device
|
|
)
|
|
q_pe = torch.randn(
|
|
batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device
|
|
)
|
|
pages_num = math.ceil(kv_len / page_size)
|
|
ckv = torch.randn(
|
|
batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device=device
|
|
)
|
|
kpe = torch.randn(
|
|
batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device=device
|
|
)
|
|
|
|
# Fill oob positions with nan
|
|
for i in range(batch_size):
|
|
last_page_len = kv_len - (pages_num - 1) * page_size
|
|
ckv[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")
|
|
kpe[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan")
|
|
|
|
sm_scale = 1.0 / ((128 + 64) ** 0.5)
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
|
|
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
|
workspace_buffer, backend=backend
|
|
)
|
|
q_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num
|
|
)
|
|
kv_indices = torch.arange(
|
|
0, batch_size * pages_num, device=device, dtype=torch.int32
|
|
)
|
|
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device)
|
|
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_lens,
|
|
num_heads,
|
|
head_dim_ckv,
|
|
head_dim_kpe,
|
|
page_size,
|
|
causal,
|
|
sm_scale,
|
|
q_nope.dtype,
|
|
ckv.dtype,
|
|
)
|
|
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
|
|
|
|
k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
|
|
|
|
q = torch.cat([q_nope, q_pe], dim=-1)
|
|
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
|
|
lse_ref = lse_ref.flatten(0, 1)
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
if kv_len != 0:
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 3, 5, 7, 157])
|
|
@pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024])
|
|
@pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17])
|
|
@pytest.mark.parametrize("num_heads", [16])
|
|
@pytest.mark.parametrize("causal", [False, True])
|
|
@pytest.mark.parametrize("page_size", [1, 16])
|
|
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
|
|
@pytest.mark.parametrize("use_cuda_graph", [False])
|
|
@pytest.mark.parametrize("dtype", [torch.half])
|
|
def test_batch_mla_page_attention(
|
|
batch_size,
|
|
kv_len,
|
|
qo_len,
|
|
num_heads,
|
|
causal,
|
|
page_size,
|
|
backend,
|
|
use_cuda_graph,
|
|
dtype,
|
|
):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if backend == "fa3" and not is_sm90a_supported(device):
|
|
pytest.skip("FA3 is not supported on this device")
|
|
if causal and qo_len > kv_len:
|
|
pytest.skip("qo_len > kv_len not supported for causal attention")
|
|
torch.manual_seed(42)
|
|
head_dim_ckv = 512
|
|
head_dim_kpe = 64
|
|
q_nope = torch.randn(
|
|
batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device
|
|
)
|
|
q_pe = torch.randn(
|
|
batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device
|
|
)
|
|
pages_num = math.ceil(kv_len / page_size)
|
|
ckv = torch.randn(
|
|
batch_size * pages_num,
|
|
page_size,
|
|
head_dim_ckv,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
kpe = torch.randn(
|
|
batch_size * pages_num,
|
|
page_size,
|
|
head_dim_kpe,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
|
|
wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
|
workspace_buffer,
|
|
backend=backend,
|
|
use_cuda_graph=True,
|
|
qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device),
|
|
kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device),
|
|
kv_indices=torch.empty(1048576, dtype=torch.int32, device=device),
|
|
kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device=device),
|
|
)
|
|
q_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num
|
|
)
|
|
kv_indices = torch.arange(
|
|
0, batch_size * pages_num, device=device, dtype=torch.int32
|
|
)
|
|
kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device)
|
|
|
|
if use_cuda_graph:
|
|
kv_indptr_warmup = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
|
|
kv_indices_warmup = torch.arange(
|
|
0, batch_size, device=device, dtype=torch.int32
|
|
)
|
|
kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32, device=device)
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr_warmup,
|
|
kv_indices_warmup,
|
|
kv_lens_warmup,
|
|
num_heads,
|
|
head_dim_ckv,
|
|
head_dim_kpe,
|
|
page_size,
|
|
causal,
|
|
sm_scale,
|
|
q_nope.dtype,
|
|
ckv.dtype,
|
|
)
|
|
|
|
# warmup
|
|
s = torch.cuda.Stream()
|
|
s.wait_stream(torch.cuda.current_stream())
|
|
with torch.cuda.stream(s):
|
|
for _ in range(3):
|
|
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
|
|
torch.cuda.current_stream().wait_stream(s)
|
|
|
|
# capture
|
|
g = torch.cuda.CUDAGraph()
|
|
with torch.cuda.graph(g):
|
|
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
|
|
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_lens,
|
|
num_heads,
|
|
head_dim_ckv,
|
|
head_dim_kpe,
|
|
page_size,
|
|
causal,
|
|
sm_scale,
|
|
q_nope.dtype,
|
|
ckv.dtype,
|
|
)
|
|
if use_cuda_graph:
|
|
o.fill_(0)
|
|
lse.fill_(0)
|
|
g.replay()
|
|
else:
|
|
o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True)
|
|
|
|
k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads)
|
|
|
|
q = torch.cat([q_nope, q_pe], dim=-1)
|
|
o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale)
|
|
lse_ref = lse_ref.flatten(0, 1)
|
|
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
|
|
if kv_len != 0:
|
|
torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3)
|
|
|
|
# test with pre-allocated output
|
|
o_buffer = torch.empty_like(o)
|
|
lse_buffer = torch.empty_like(lse)
|
|
wrapper.run(q_nope, q_pe, ckv, kpe, out=o_buffer, lse=lse_buffer)
|
|
torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
|
|
torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 4])
|
|
@pytest.mark.parametrize("max_seq_len", [128, 1024, 4096])
|
|
@pytest.mark.parametrize("page_size", [1, 16, 128])
|
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half])
|
|
def test_cutlass_mla(batch_size, max_seq_len, page_size, dtype):
|
|
device = torch.device("cuda:0")
|
|
clear_cuda_cache(device)
|
|
if not is_sm100a_supported(device) and not is_sm110a_supported(device):
|
|
pytest.skip("Cutlass MLA is not supported on this device")
|
|
|
|
torch.manual_seed(42)
|
|
|
|
num_local_heads = 128
|
|
head_dim_ckv = 512
|
|
head_dim_kpe = 64
|
|
total_page_num = 8192
|
|
|
|
# NOTE(Zihao): use larger scale to detect bugs such as
|
|
# https://github.com/flashinfer-ai/flashinfer/pull/1055
|
|
q_nope_pe = (
|
|
torch.randn(
|
|
batch_size,
|
|
num_local_heads,
|
|
head_dim_ckv + head_dim_kpe,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
* 100
|
|
)
|
|
ckv_kpe = torch.randn(
|
|
total_page_num,
|
|
page_size,
|
|
head_dim_ckv + head_dim_kpe,
|
|
dtype=dtype,
|
|
device=device,
|
|
)
|
|
kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device)
|
|
page_num_per_batch = (max_seq_len + page_size - 1) // page_size
|
|
# Cutlass MLA requires small pages (< 128) are packed into a 128 page.
|
|
assert page_num_per_batch % (128 // page_size) == 0
|
|
page_table = torch.randint(
|
|
0,
|
|
total_page_num,
|
|
(batch_size, page_num_per_batch),
|
|
dtype=torch.int32,
|
|
device=device,
|
|
)
|
|
|
|
mla_ref = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
|
torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device), backend="fa2"
|
|
)
|
|
|
|
# for decode, each query length is 1
|
|
q_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
|
|
kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
|
|
* page_num_per_batch
|
|
)
|
|
kv_indices = page_table.flatten()
|
|
|
|
q_nope = q_nope_pe[..., :head_dim_ckv]
|
|
q_pe = q_nope_pe[..., head_dim_ckv:]
|
|
ckv = ckv_kpe[..., :head_dim_ckv]
|
|
kpe = ckv_kpe[..., head_dim_ckv:]
|
|
|
|
# use head dimension before matrix absorption
|
|
sm_scale = 1.0 / ((128 + 64) ** 0.5)
|
|
mla_ref.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_lens,
|
|
num_local_heads,
|
|
head_dim_ckv,
|
|
head_dim_kpe,
|
|
page_size,
|
|
False, # causal
|
|
sm_scale,
|
|
q_nope.dtype,
|
|
ckv.dtype,
|
|
)
|
|
|
|
o_ref = mla_ref.run(q_nope, q_pe, ckv, kpe, return_lse=False)
|
|
|
|
mla_ans = flashinfer.mla.BatchMLAPagedAttentionWrapper(
|
|
torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device),
|
|
backend="cutlass",
|
|
)
|
|
o_ans = mla_ans.run(q_nope, q_pe, ckv, kpe, kv_len=kv_lens, page_table=page_table)
|
|
torch.testing.assert_close(o_ans, o_ref, rtol=1e-2, atol=1e-2)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_batch_mla_varlen_page_attention(
|
|
1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half
|
|
)
|
|
# test_batch_mla_varlen_page_attention(
|
|
# 155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half
|
|
# )
|
|
# test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)
|