sglang_v0.5.2/flashinfer_0.3.1/tests/test_block_sparse.py

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import pytest
import scipy as sp
import torch
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules
import flashinfer
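
# This module tests FlashInfer's block-sparse attention wrappers:
# - test_block_sparse_attention: fixed R x C blocks via BlockSparseAttentionWrapper,
#   checked against the dense-mask reference bsr_attention_ref.
# - test_variable_block_sparse_attention_wrapper: per-KV-head variable block sizes via
#   VariableBlockSparseAttentionWrapper, checked against _ref_attention.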


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    # Pre-build the JIT attention modules used by the tests in this module so that
    # kernel compilation happens once up front.
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )
    yield


def bsr_attention_ref(
    q,
    k,
    v,
    indptr,
    indices,
    mask_data,
):
    # Reference: expand the BSR (block sparse row) mask into a dense boolean mask
    # and compute attention with the dense prefill kernel.
    M = q.shape[0]
    N = k.shape[0]
    bsr = sp.sparse.bsr_matrix(
        (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()),
        shape=(M, N),
    )
    dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device)
    o = flashinfer.prefill.single_prefill_with_kv_cache(
        q, k, v, custom_mask=dense_mask
    )
    return o


def set_seed(seed: int = 42):
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


@pytest.mark.parametrize("R", [1, 4, 16])
@pytest.mark.parametrize("C", [1, 4, 16])
@pytest.mark.parametrize("M", [64, 128, 256])
@pytest.mark.parametrize("N", [64, 128, 256])
@pytest.mark.parametrize("num_qo_heads", [1, 4, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("mask_inside_block", [True, False])
def test_block_sparse_attention(
    R, C, M, N, num_qo_heads, num_kv_heads, head_dim, mask_inside_block
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be divisible by num_kv_heads")
    set_seed(33)
    rng = np.random.default_rng()
    MB = M // R
    NB = N // C
    # Random block-sparse pattern (density 0.25) in CSR form over MB x NB blocks of size R x C.
    S = sp.sparse.random(MB, NB, density=0.25, random_state=rng).tocsr()
    indptr = torch.from_numpy(S.indptr).to(0)
    indices = torch.from_numpy(S.indices).to(0)
    nnz = S.nnz
    if mask_inside_block:
        data_mask = (torch.rand((nnz, R, C)) > 0.5).to(0)
    else:
        data_mask = torch.full((nnz, R, C), True, dtype=bool, device=0)
    q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device=0)
    k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0)
    v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0)

    o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask)
    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0)
    sparse_attention_wrapper = flashinfer.sparse.BlockSparseAttentionWrapper(
        workspace_buffer
    )
    sparse_attention_wrapper.plan(
        indptr,
        indices,
        M,
        N,
        R,
        C,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        mask=data_mask if mask_inside_block else None,
    )
    o = sparse_attention_wrapper.run(q, k, v)
    torch.testing.assert_close(o_ref, o, atol=1e-2, rtol=1e-3)

    # test with pre-allocated output
    o_buffer = torch.empty_like(o)
    sparse_attention_wrapper.run(q, k, v, out=o_buffer)
    torch.testing.assert_close(o_ref, o_buffer, atol=1e-2, rtol=1e-3)
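

# Dense reference for the variable-block-size test below: the per-head block mask is
# expanded to an element-level boolean mask and run through the dense prefill kernel.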
def _ref_attention(
    q: torch.Tensor,  # [gqa_group_size, qo_len, head_dim]
    k: torch.Tensor,  # [1, kv_len, head_dim]
    v: torch.Tensor,  # [1, kv_len, head_dim]
    block_mask_map: torch.Tensor,  # [MB, NB]
    block_row_sz: torch.Tensor,  # [MB]
    block_col_sz: torch.Tensor,  # [NB]
) -> torch.Tensor:
    # convert block mask map to element mask
    def _block_mask_to_element_mask(
        block_mask_map: torch.Tensor,  # [MB, NB] bool
        block_row_sz: torch.Tensor,  # [MB] int (rows per block-row)
        block_col_sz: torch.Tensor,  # [NB] int (cols per block-col)
    ) -> torch.Tensor:
        block_row_sz = block_row_sz.to(block_mask_map.device, dtype=torch.long)
        block_col_sz = block_col_sz.to(block_mask_map.device, dtype=torch.long)
        expanded_rows = torch.repeat_interleave(block_mask_map, block_row_sz, dim=0)
        element_mask = torch.repeat_interleave(expanded_rows, block_col_sz, dim=1)
        return element_mask

    dense_mask = _block_mask_to_element_mask(
        block_mask_map, block_row_sz, block_col_sz
    ).to(dtype=torch.bool, device=q.device)
    q = q.transpose(0, 1).contiguous()
    k = k.transpose(0, 1).contiguous()
    v = v.transpose(0, 1).contiguous()
    o = flashinfer.prefill.single_prefill_with_kv_cache(
        q, k, v, custom_mask=dense_mask
    )  # [qo_len, gqa_group_size, head_dim]
    o = o.transpose(0, 1).contiguous()
    return o


@pytest.mark.parametrize("num_qo_heads", [1, 4, 16])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 16])
@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("seq_len", [256, 4096, 8192])
@pytest.mark.parametrize("num_blocks_row", [10, 20])
@pytest.mark.parametrize("num_blocks_col", [50, 100])
@pytest.mark.parametrize("block_density", [0.2, 0.7, 0.9])
def test_variable_block_sparse_attention_wrapper(
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    seq_len: int,
    num_blocks_row: int,
    num_blocks_col: int,
    block_density: float,
):
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be divisible by num_kv_heads")
    if seq_len // num_blocks_row < 1:
        pytest.skip("seq_len must be greater than num_blocks_row")
    if seq_len // num_blocks_col < 1:
        pytest.skip("seq_len must be greater than num_blocks_col")
    set_seed(330)

    def random_partition_batch(
        seq_len: int,
        num_blocks: int,
        bsz: int,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.int32,
    ) -> torch.Tensor:
        # Split [0, seq_len) into `num_blocks` contiguous chunks of random size
        # (each >= 1) for every one of the `bsz` batch entries.
        assert seq_len >= num_blocks
        sizes = torch.empty((bsz, num_blocks), dtype=dtype, device=device)
        for i in range(bsz):
            cut_pts = torch.randperm(seq_len - 1, device=device)[: num_blocks - 1] + 1
            cut_pts, _ = torch.sort(cut_pts)
            row_sizes = torch.diff(
                torch.cat(
                    (
                        torch.tensor([0], device=device),
                        cut_pts,
                        torch.tensor([seq_len], device=device),
                    )
                )
            )
            sizes[i] = row_sizes
        assert sizes.min() >= 1
        assert sizes.max() <= seq_len
        assert torch.all(sizes.sum(dim=-1) == seq_len)
        return sizes.to(device=device)
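
    # Plan and run VariableBlockSparseAttentionWrapper on one random layout, then check
    # each group of query heads (grouped by KV head) against the dense reference.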
    def _test_variable_block_sparse_attention(
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        # qkv: HND
        qo_len = block_row_sz.sum(dim=1)[0].item()
        kv_len = block_col_sz.sum(dim=1)[0].item()
        assert torch.all(block_col_sz.sum(dim=1) == block_col_sz.sum(dim=1)[0])
        assert torch.all(block_row_sz.sum(dim=1) == block_row_sz.sum(dim=1)[0])
        q = torch.randn(num_qo_heads, qo_len, head_dim, device=device, dtype=dtype)
        k = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype)
        v = torch.randn(num_kv_heads, kv_len, head_dim, device=device, dtype=dtype)

        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(
            float_workspace_buffer, backend="auto"
        )
        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            q_data_type=dtype,
        )
        o: torch.Tensor = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]

        o = o.reshape(num_kv_heads, -1, *o.shape[-2:])
        q = q.reshape(num_kv_heads, -1, *q.shape[-2:])
        for kv_head_idx in range(num_kv_heads):
            o_ref = _ref_attention(
                q[kv_head_idx],
                k[kv_head_idx : kv_head_idx + 1, :, :],
                v[kv_head_idx : kv_head_idx + 1, :, :],
                block_mask_map[kv_head_idx],
                block_row_sz[kv_head_idx],
                block_col_sz[kv_head_idx],
            )
            torch.testing.assert_close(o[kv_head_idx], o_ref, atol=1e-2, rtol=1e-2)

    block_row_sz = random_partition_batch(
        seq_len, num_blocks_row, num_kv_heads, device="cuda:0"
    )
    block_col_sz = random_partition_batch(
        seq_len, num_blocks_col, num_kv_heads, device="cuda:0"
    )
    block_mask_map = (
        torch.rand(num_kv_heads, num_blocks_row, num_blocks_col) > block_density
    ).to(device="cuda:0")
    _test_variable_block_sparse_attention(
        num_qo_heads,
        num_kv_heads,
        head_dim,
        block_mask_map,
        block_row_sz,
        block_col_sz,
    )


if __name__ == "__main__":
    # This test verifies the INT32_T overflow issue.
    for seq_len in [16 * 1024, 32 * 1024, 40 * 1024, 48 * 1024, 64 * 1024]:
        test_block_sparse_attention(128, 128, seq_len, seq_len, 1, 1, 128, False)