"""
|
||
Copyright (c) 2024 by FlashInfer team.
|
||
|
||
Licensed under the Apache License, Version 2.0 (the "License");
|
||
you may not use this file except in compliance with the License.
|
||
You may obtain a copy of the License at
|
||
|
||
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
||
Unless required by applicable law or agreed to in writing, software
|
||
distributed under the License is distributed on an "AS IS" BASIS,
|
||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
See the License for the specific language governing permissions and
|
||
limitations under the License.
|
||
"""
|
||
|
||
import numpy as np
|
||
import pytest
|
||
import scipy as sp
|
||
import torch
|
||
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules
|
||
|
||
import flashinfer
|
||
|
||
|
||
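# Module-scoped warmup: pre-build the fp16 decode/prefill JIT attention modules
# (head_dim 128/256, no sliding window, no logits soft cap) once, so kernel
# compilation time is not attributed to individual test cases.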
@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )
    yield


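# Reference implementation: expand the block-sparse (BSR) mask into a dense
# M x N boolean mask and run the dense single_prefill kernel with it.
# BSR layout sketch (scipy.sparse.bsr_matrix semantics; sizes below are
# illustrative only):
#   indptr    : (MB + 1,)    block-row offsets into `indices`
#   indices   : (nnz,)       block-column index of each stored block
#   mask_data : (nnz, R, C)  per-element mask inside each stored block
# e.g. with R = C = 2 and MB = NB = 2, indptr = [0, 1, 2], indices = [1, 0]
# places mask_data[0] in the top-right 2x2 tile and mask_data[1] in the
# bottom-left 2x2 tile of the 4x4 dense mask; all other entries are False.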
def bsr_attention_ref(
    q,
    k,
    v,
    indptr,
    indices,
    mask_data,
):
    M = q.shape[0]
    N = k.shape[0]
    bsr = sp.sparse.bsr_matrix(
        (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()),
        shape=(M, N),
    )
    dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device)
    o = flashinfer.prefill.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask)
    return o


def set_seed(seed: int = 42):
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


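# Fixed-size block-sparse test: the M x N attention mask is tiled into R x C
# blocks, the nonzero block pattern is drawn from a random CSR matrix at 25%
# density, and (optionally) each stored block also carries a random
# element-level mask. The wrapper output is compared against the dense-mask
# reference above.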
@pytest.mark.parametrize("R", [1, 4, 16])
@pytest.mark.parametrize("C", [1, 4, 16])
@pytest.mark.parametrize("M", [64, 128, 256])
@pytest.mark.parametrize("N", [64, 128, 256])
@pytest.mark.parametrize("num_qo_heads", [1, 4, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("mask_inside_block", [True, False])
def test_block_sparse_attention(
    R, C, M, N, num_qo_heads, num_kv_heads, head_dim, mask_inside_block
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be divisible by num_kv_heads")

    set_seed(33)
    rng = np.random.default_rng()

    MB = M // R
    NB = N // C
    S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr()
    indptr = torch.from_numpy(S.indptr).to(0)
    indices = torch.from_numpy(S.indices).to(0)
    nnz = S.nnz
    if mask_inside_block:
        data_mask = (torch.rand((nnz, R, C)) > 0.5).to(0)
    else:
        data_mask = torch.full((nnz, R, C), True, dtype=bool, device=0)
    q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device=0)
    k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0)
    v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0)

    o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0)
    sparse_attention_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper(
        workspace_buffer
    )

    sparse_attention_wrapper.plan(
        indptr,
        indices,
        M,
        N,
        R,
        C,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        mask=data_mask if mask_inside_block else None,
    )

    o = sparse_attention_wrapper.run(q, k, v)
    torch.testing.assert_close(o_ref, o, atol=1e-2, rtol=1e-3)

    # test with pre-allocated output
    o_buffer = torch.empty_like(o)
    sparse_attention_wrapper.run(q, k, v, out=o_buffer)
    torch.testing.assert_close(o_ref, o_buffer, atol=1e-2, rtol=1e-3)


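# Reference for the variable-block-size case: blocks are expanded to an element
# mask with repeat_interleave. Illustrative sketch (shapes assumed, mirroring
# _block_mask_to_element_mask below): with block_row_sz = [2, 1] and
# block_col_sz = [1, 3], a 2 x 2 block_mask_map expands to a 3 x 4 element mask
# in which block (i, j) covers a block_row_sz[i] x block_col_sz[j] tile.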
def _ref_attention(
    q: torch.Tensor,  # [gqa_group_size, qo_len, head_dim]
    k: torch.Tensor,  # [1, kv_len, head_dim]
    v: torch.Tensor,  # [1, kv_len, head_dim]
    block_mask_map: torch.Tensor,  # [MB, NB]
    block_row_sz: torch.Tensor,  # [MB]
    block_col_sz: torch.Tensor,  # [NB]
) -> torch.Tensor:
    # convert block mask map to element mask
    def _block_mask_to_element_mask(
        block_mask_map: torch.Tensor,  # [MB, NB] – bool
        block_row_sz: torch.Tensor,  # [MB] – int (rows per block-row)
        block_col_sz: torch.Tensor,  # [NB] – int (cols per block-col)
    ) -> torch.Tensor:
        block_row_sz = block_row_sz.to(block_mask_map.device, dtype=torch.long)
        block_col_sz = block_col_sz.to(block_mask_map.device, dtype=torch.long)
        expanded_rows = torch.repeat_interleave(block_mask_map, block_row_sz, dim=0)
        element_mask = torch.repeat_interleave(expanded_rows, block_col_sz, dim=1)

        return element_mask

    dense_mask = _block_mask_to_element_mask(
        block_mask_map, block_row_sz, block_col_sz
    ).to(dtype=torch.bool, device=q.device)

    q = q.transpose(0, 1).contiguous()
    k = k.transpose(0, 1).contiguous()
    v = v.transpose(0, 1).contiguous()
    o = flashinfer.prefill.single_prefill_with_kv_cache(
        q, k, v, custom_mask=dense_mask
    )  # [qo_len, gqa_group_size, head_dim]
    o = o.transpose(0, 1).contiguous()

    return o


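# Variable-block-size test: each KV head gets its own block mask, and the block
# row/column sizes are drawn as a random partition of seq_len, so blocks are
# not required to be uniform in size.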
@pytest.mark.parametrize("num_qo_heads", [1, 4, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 16])
@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("seq_len", [256, 4096, 8192])
@pytest.mark.parametrize("num_blocks_row", [10, 20])
@pytest.mark.parametrize("num_blocks_col", [50, 100])
@pytest.mark.parametrize("block_density", [0.2, 0.7, 0.9])
def test_variable_block_sparse_attention_wrapper(
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    num_blocks_row: int,
    num_blocks_col: int,
    block_density: float,
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be divisible by num_kv_heads")
    if seq_len // num_blocks_row < 1:
        pytest.skip("seq_len must be at least num_blocks_row")
    if seq_len // num_blocks_col < 1:
        pytest.skip("seq_len must be at least num_blocks_col")

    set_seed(330)

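    # Split seq_len into `num_blocks` strictly positive segments, independently
    # for each of the `bsz` rows (one row per KV head here). For example
    # (hypothetical draw), seq_len=10 and num_blocks=3 could yield [2, 5, 3]:
    # every size is >= 1 and the sizes sum to seq_len.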
    def random_partition_batch(
        seq_len: int,
        num_blocks: int,
        bsz: int,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.int32,
    ) -> torch.Tensor:
        assert seq_len >= num_blocks
        sizes = torch.empty((bsz, num_blocks), dtype=dtype, device=device)
        for i in range(bsz):
            cut_pts = torch.randperm(seq_len - 1, device=device)[: num_blocks - 1] + 1
            cut_pts, _ = torch.sort(cut_pts)
            row_sizes = torch.diff(
                torch.cat(
                    (
                        torch.tensor([0], device=device),
                        cut_pts,
                        torch.tensor([seq_len], device=device),
                    )
                )
            )
            sizes[i] = row_sizes

        assert sizes.min() >= 1
        assert sizes.max() <= seq_len
        assert torch.all(sizes.sum(dim=-1) == seq_len)

        return sizes.to(device=device)

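    # Plan/run the VariableBlockSparseAttentionWrapper on random data and check
    # each KV head (together with its group of query heads) against the dense
    # reference in _ref_attention.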
    def _test_variable_block_sparse_attention(
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        # qkv: HND
        qo_len = block_row_sz.sum(dim=1)[0].item()
        kv_len = block_col_sz.sum(dim=1)[0].item()
        assert torch.all(block_col_sz.sum(dim=1) == block_col_sz.sum(dim=1)[0])
        assert torch.all(block_row_sz.sum(dim=1) == block_row_sz.sum(dim=1)[0])

        q = torch.randn(num_qo_heads, qo_len, head_dim, device=device, dtype=dtype)
        k = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype)
        v = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype)

        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(
            float_workspace_buffer, backend="auto"
        )

        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_data_type=dtype,
        )

        o: torch.Tensor = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]
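        # Group query heads by their KV head (GQA layout: consecutive
        # num_qo_heads // num_kv_heads query heads share one KV head) so each
        # group can be checked against a single-KV-head reference.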
        o = o.reshape(num_kv_heads, -1, *o.shape[-2:])
        q = q.reshape(num_kv_heads, -1, *q.shape[-2:])
        for kv_head_idx in range(num_kv_heads):
            o_ref = _ref_attention(
                q[kv_head_idx],
                k[kv_head_idx : kv_head_idx + 1, :, :],
                v[kv_head_idx : kv_head_idx + 1, :, :],
                block_mask_map[kv_head_idx],
                block_row_sz[kv_head_idx],
                block_col_sz[kv_head_idx],
            )
            torch.testing.assert_close(o[kv_head_idx], o_ref, atol=1e-2, rtol=1e-2)

    block_row_sz = random_partition_batch(
        seq_len, num_blocks_row, num_kv_heads, device="cuda:0"
    )
    block_col_sz = random_partition_batch(
        seq_len, num_blocks_col, num_kv_heads, device="cuda:0"
    )
    block_mask_map = (
        torch.rand(num_kv_heads, num_blocks_row, num_blocks_col) > block_density
    ).to(device="cuda:0")

    _test_variable_block_sparse_attention(
        num_qo_heads,
        num_kv_heads,
        head_dim,
        block_mask_map,
        block_row_sz,
        block_col_sz,
    )


if __name__ == "__main__":
    # Running this file directly (rather than through pytest) exercises large
    # sequence lengths to check for int32 (INT32_T) overflow in the
    # block-sparse attention path.
    for seq_len in [16 * 1024, 32 * 1024, 40 * 1024, 48 * 1024, 64 * 1024]:
        test_block_sparse_attention(128, 128, seq_len, seq_len, 1, 1, 128, False)