# sglang_v0.5.2/flashinfer_0.3.1/tests/test_single_prefill.py
import math
import pytest
import torch
import flashinfer


def build_causal_mask(qo_len, kv_len):
    """Boolean mask with True at every position a query must NOT attend to.

    The mask is bottom-right aligned: the last query row attends to the full
    kv sequence, matching flashinfer's causal semantics for prefill.
    """
    i = torch.arange(qo_len).unsqueeze(1).to("cuda:0")
    j = torch.arange(kv_len).unsqueeze(0).to("cuda:0")
    offset = kv_len - qo_len
    mask = (j - offset > i).to(torch.bool)
    return mask
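
# Illustrative example (not in the original file): with qo_len=2 and
# kv_len=4, offset = kv_len - qo_len = 2, so the mask is bottom-right
# aligned and only marks strictly-future keys as True (i.e. masked out):
#     build_causal_mask(2, 4) ==
#         [[False, False, False, True ],
#          [False, False, False, False]]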


def _repeat_kv(t: torch.Tensor, num_groups: int) -> torch.Tensor:
    # (L, Hkv, D) -> (L, Hkv * num_groups, D): duplicate each kv head
    # num_groups times so GQA reduces to standard multi-head attention.
    return t.repeat_interleave(num_groups, dim=1)


def single_prefill_with_kv_cache_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Pure-PyTorch reference for single_prefill_with_kv_cache.

    q: (Lq, Hq, D); k/v: (Lk, Hkv, D) with Hq a multiple of Hkv (GQA).
    Returns the attention output with shape (Lq, Hq, D).
    """
    Lq, Hq, D = q.shape
    Lk, Hkv, _ = k.shape
    assert (Lk, Hkv, D) == v.shape
    assert Hq % Hkv == 0
    groups = Hq // Hkv
    k_states = _repeat_kv(k, groups)
    v_states = _repeat_kv(v, groups)
    q_t = q.permute(1, 0, 2)         # (Hq, Lq, D)
    k_t = k_states.permute(1, 2, 0)  # (Hq, D, Lk)
    v_t = v_states.permute(1, 0, 2)  # (Hq, Lk, D)
    scale = 1.0 / math.sqrt(D)
    attn_scores = torch.bmm(q_t, k_t) * scale  # (Hq, Lq, Lk)
    if causal:
        # Mask out keys strictly in the future of each query position.
        causal_mask = build_causal_mask(Lq, Lk)
        attn_scores = attn_scores.masked_fill(causal_mask, float("-inf"))
    # Softmax in fp32 for numerical stability, then cast back to the input
    # dtype before the value matmul.
    attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
    attn_output = torch.bmm(attn_weights, v_t)  # (Hq, Lq, D)
    attn_output = attn_output.permute(1, 0, 2).contiguous()  # (Lq, Hq, D)
    return attn_output
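
# Optional cross-check sketch (not part of the original test): the same
# reference computed via torch.nn.functional.scaled_dot_product_attention,
# assuming torch >= 2.0. SDPA's is_causal flag is top-left aligned when
# Lq != Lk, so an explicit bottom-right-aligned boolean mask (True = attend)
# built from build_causal_mask is passed instead.
def single_prefill_with_kv_cache_sdpa_ref(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    Lq, Hq, _ = q.shape
    Lk, Hkv, _ = k.shape
    groups = Hq // Hkv
    q_t = q.permute(1, 0, 2)                      # (Hq, Lq, D)
    k_t = _repeat_kv(k, groups).permute(1, 0, 2)  # (Hq, Lk, D)
    v_t = _repeat_kv(v, groups).permute(1, 0, 2)  # (Hq, Lk, D)
    attn_mask = ~build_causal_mask(Lq, Lk) if causal else None
    o = torch.nn.functional.scaled_dot_product_attention(
        q_t, k_t, v_t, attn_mask=attn_mask
    )
    return o.permute(1, 0, 2).contiguous()        # (Lq, Hq, D)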
@pytest.mark.parametrize("kv_len", [501, 2042, 3771, 4932])
@pytest.mark.parametrize("qo_len", [37, 127, 577, 1024])
@pytest.mark.parametrize("num_kv_heads", [1])
@pytest.mark.parametrize("num_qo_heads", [4, 7])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
def test_sinqle_prefill_with_paged_kv_cache(
kv_len,
qo_len,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
pos_encoding_mode,
):
torch.manual_seed(0)
torch.cuda.manual_seed(0)
if qo_len > kv_len and causal:
pytest.skip("qo_len > kv_len and causal is not supported")
    q = torch.randn(
        qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    k = torch.randn(
        kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    v = torch.randn(
        kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    o = flashinfer.prefill.single_prefill_with_kv_cache(
        q, k, v, causal=causal, pos_encoding_mode=pos_encoding_mode, backend="fa2"
    )
    o_ref = single_prefill_with_kv_cache_ref(q, k, v, causal=causal)
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
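

# Convenience entry point (not in the original file): run one parameter
# combination directly with `python test_single_prefill.py`, without
# invoking pytest. The values below are one point from the parametrize
# grids above.
if __name__ == "__main__":
    test_single_prefill_with_kv_cache(
        kv_len=2042,
        qo_len=127,
        num_kv_heads=1,
        num_qo_heads=4,
        head_dim=128,
        causal=True,
        pos_encoding_mode="NONE",
    )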