sglang_v0.5.2/sglang/sgl-kernel/tests/test_sparse_flash_attn.py

import math
from typing import List, Optional, Tuple
import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import (
convert_vertical_slash_indexes,
convert_vertical_slash_indexes_mergehead,
sparse_attn_func,
)
from test_flash_attention import construct_local_mask, is_fa3_supported
def ref_attn(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
attn_bias=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
softcap=0.0,
upcast=True,
reorder_ops=False,
key_leftpad=None,
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
lse: (batch_size, nheads, seqlen_q)
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
lse_ref = scores.logsumexp(dim=-1)
if softcap > 0:
scores = scores / softcap
scores = scores.tanh()
scores = scores * softcap
if key_padding_mask is not None:
scores.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
key_leftpad=key_leftpad,
)
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
scores = scores + attn_bias
attention = torch.softmax(scores, dim=-1).to(v.dtype)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0
)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), lse_ref
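
# Shape-only usage sketch for the reference above (illustrative values, not
# drawn from the tests below):
#   q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
#   k = torch.randn(2, 256, 2, 64, device="cuda", dtype=torch.float16)  # GQA: 4 query heads per KV head
#   v = torch.randn_like(k)
#   out, lse = ref_attn(q, k, v)
#   # out: (2, 128, 8, 64), lse: (2, 8, 128)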
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: List[int],
kv_lens: List[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
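    """Dense reference for paged (block-table) attention over packed varlen
    sequences; causal by construction, with optional sliding window and soft cap."""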
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
# clone to avoid clobbering the query tensor
q = query[start_idx : start_idx + query_len].clone()
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
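
# Illustrative sketch of how the sparse index layout is read (an assumption
# inferred from how test_sparse_attention builds its inputs, not a documented
# contract of sparse_attn_func): for each row block of block_size_M queries,
# block_offset[b, h, row, :block_count[b, h, row]] holds the starting key
# positions of dense key blocks of width block_size_N, and
# column_index[b, h, row, :column_count[b, h, row]] lists extra individual key
# columns. The helper below is illustrative only and not called by the tests.
def _attended_key_positions(
    block_offset_row, block_count, column_index_row, column_count, block_size_N
):
    """Sorted key positions that one query row block attends to."""
    keys = set()
    for off in block_offset_row[: int(block_count)].tolist():
        keys.update(range(off, off + block_size_N))
    keys.update(column_index_row[: int(column_count)].tolist())
    return sorted(keys)
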
@pytest.mark.skipif(
not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
"seq_lens",
[
(1, 1),
(1, 1024),
(1, 2048),
(1023, 2049),
(1023, 1023),
(32, 32),
(65, 65),
(129, 129),
],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
batch_size,
seq_lens,
num_heads,
head_size,
dtype,
NNZ_S,
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
block_size_M = 64
block_size_N = 64
seqlen_q, seqlen_k = seq_lens
q = torch.randn(
batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
)
k = torch.randn(
batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
)
v = torch.randn(
batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
)
NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        # Not enough keys to hold NNZ_S slash blocks; skip this parametrization.
        return
NNZ_V = seqlen_k - NNZ_S * block_size_N
block_count = torch.tensor(
[NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
).reshape(batch_size, num_heads, NUM_ROWS)
column_count = torch.tensor(
[NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
).reshape(batch_size, num_heads, NUM_ROWS)
block_offset = torch.tensor(
[[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
dtype=torch.int32,
).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
column_index = torch.tensor(
[[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
* batch_size
* NUM_ROWS
* num_heads,
dtype=torch.int32,
).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
out, lse = sparse_attn_func(
q,
k,
v,
block_count,
block_offset,
column_count,
column_index,
return_softmax_lse=True,
)
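    # The block offsets tile keys [0, NNZ_S * block_size_N) and the column
    # indices cover the remaining keys, so every query row attends to the full
    # key range and the dense, non-causal reference below should match.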
ref_out, ref_lse = ref_attn(q, k, v)
    torch.testing.assert_close(
        out, ref_out, atol=2e-2, rtol=1e-2, msg=f"{torch.max(torch.abs(out - ref_out))}"
    )
    torch.testing.assert_close(
        lse, ref_lse, atol=2e-2, rtol=1e-2, msg=f"{torch.max(torch.abs(lse - ref_lse))}"
    )
# sparse attention utils
# original (non-mergehead) variant
@pytest.mark.skipif(
not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
# Prepare small, hand-checkable inputs
q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda") # [BATCH]
kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
vertical_indexes = torch.tensor(
[[[1, 3]]], dtype=torch.int32, device="cuda"
) # [BATCH, N_HEADS, NNZ_V]
slash_indexes = torch.tensor(
[[[2]]], dtype=torch.int32, device="cuda"
) # [BATCH, N_HEADS, NNZ_S]
context_size = 4
block_size_M = 2
block_size_N = 2
    # Call the CUDA kernel wrapper
block_count, block_offset, column_count, column_index = (
convert_vertical_slash_indexes(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
context_size,
block_size_M,
block_size_N,
causal=causal,
)
)
# Manually create expected outputs for this input
# There are 2 rows (blocks): row0 (tokens 0-1), row1 (tokens 2-3)
    # The four outputs describe, per query row block:
    # - block_count: how many slash-derived indices fall into the block
    # - block_offset: the values of those indices
    # - column_count: the number of valid vertical indices
    # - column_index: the vertical indices themselves
expected_column_index = torch.tensor(
[[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
)
    # The non-causal path produces different indices for the same input.
    if not causal:
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
)
# Assert that outputs match expectations
assert torch.equal(column_index, expected_column_index)
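
# The mergehead variant takes the same vertical/slash inputs but adds per-head
# valid-entry counts (vertical_indices_count / slash_indices_count), so
# different heads can use different numbers of indices.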
# mergehead
@pytest.mark.skipif(
not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
# Prepare small, hand-checkable inputs for mergehead version
q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
vertical_indexes = torch.tensor(
[
[
[1, 3], # head 0
[2, 0], # head 1
]
],
dtype=torch.int32,
device="cuda",
) # [BATCH, N_HEADS, NNZ_V]
slash_indexes = torch.tensor(
[
[
[2, 0], # head 0
[1, 3], # head 1
]
],
dtype=torch.int32,
device="cuda",
) # [BATCH, N_HEADS, NNZ_S]
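    # Per-head counts of valid entries in vertical_indexes / slash_indexes:
    # head 0 uses 2 vertical and 1 slash index, head 1 uses 1 vertical and 2.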
vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
context_size = 4
block_size_M = 2
block_size_N = 2
    # Call the CUDA kernel wrapper
block_count, block_offset, column_count, column_index = (
convert_vertical_slash_indexes_mergehead(
q_seqlens,
kv_seqlens,
vertical_indexes,
slash_indexes,
vertical_indices_count,
slash_indices_count,
context_size,
block_size_M,
block_size_N,
causal=causal,
)
)
    # Manually created expected outputs for this input
    # (batch=1, heads=2, num_rows=2, nnz_v=2, nnz_s=2)
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
dtype=torch.int32,
device="cuda",
)
    if not causal:
        # The non-causal path produces different indices for the same input.
expected_column_index = torch.tensor(
[[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
dtype=torch.int32,
device="cuda",
)
# Assert that outputs match expectations
assert torch.equal(column_index, expected_column_index)
# Skipped because the varlen path uses FA2 for testing.
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
# [(1024, 1328), (1, 2048)],
# [(1025, 1328), (2, 2048)],
# [(1025, 2049), (2, 1281)],
# ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
# seq_lens,
# head_size,
# dtype,
# ) -> None:
# torch.set_default_device("cuda")
# torch.cuda.manual_seed_all(0)
# block_size_M = 64
# block_size_N = 64
# num_seqs = len(seq_lens)
# query_lens = [x[0] for x in seq_lens]
# kv_lens = [x[1] for x in seq_lens]
# num_heads = 1
# query = torch.randn(sum(query_lens),
# num_heads,
# head_size,
# dtype=dtype)
# key = torch.randn(sum(kv_lens),
# num_heads,
# head_size,
# dtype=dtype)
# value = torch.randn_like(key)
# cu_query_lens = torch.tensor([0] + query_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# cu_kv_lens = torch.tensor([0] + kv_lens,
# dtype=torch.int32).cumsum(dim=0,
# dtype=torch.int32)
# max_query_len = max(query_lens)
# max_kv_len = max(kv_lens)
# NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
# NNZ_S = 20
# NNZ_V = 2048
# batch_size = len(query_lens)
# block_counts = []
# column_counts = []
# block_offsets = []
# column_indices = []
# for b in range(batch_size):
# block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# columns = kv_lens[b] - NNZ_S * block_size_N
# column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
# block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
# column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
# block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
# column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
# block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
# column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
# out, lse = sparse_attn_varlen_func(
# query,
# key,
# value,
# block_count,
# block_offset,
# column_count,
# column_index,
# cu_seqlens_q=cu_query_lens,
# cu_seqlens_k=cu_kv_lens,
# max_seqlen_q=max_query_len,
# max_seqlen_k=max_kv_len,
# return_softmax_lse=True,
# )
# max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
# block_tables = torch.randint(0,
# 2048,
# (len(query_lens), max_num_blocks_per_seq),
# dtype=torch.int32)
# scale = head_size**-0.5
# ref_out, ref_lse, _ = ref_paged_attn(
# query,
# key,
# value,
# query_lens=query_lens,
# kv_lens=kv_lens,
# block_tables=block_tables,
# scale=scale
# )
# torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(out - ref_out))}"
# torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
# f"{torch.max(torch.abs(lse - ref_lse))}"
if __name__ == "__main__":
pytest.main([__file__])