# sglang_v0.5.2/flashinfer_0.3.1/tests/test_batch_prefill_kernels.py
"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy
import pytest
import torch
from jit_utils import gen_prefill_attention_modules

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
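    # Build the JIT kernel specs needed by the tests up front so individual test
    # times measure kernel execution rather than compilation.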
flashinfer.jit.build_jit_specs(
gen_prefill_attention_modules(
[torch.float16], # q_dtypes
[
torch.float16,
torch.float8_e4m3fn,
torch.float8_e5m2,
], # kv_dtypes
[128, 256], # head_dims
[0, 1], # pos_encoding_modes
[False], # use_sliding_windows
[False], # use_logits_soft_caps
[False], # use_fp16_qk_reductions
),
verbose=False,
)
yield
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("use_cuda_graph", [True])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_paged_kv_cache(
batch_size,
kv_len,
qo_len,
page_size,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
kv_layout,
pos_encoding_mode,
use_cuda_graph,
logits_soft_cap,
return_lse,
contiguous_kv,
):
if qo_len > kv_len and causal:
pytest.skip("qo_len > kv_len and causal is not supported")
q = torch.randn(
batch_size * qo_len,
num_qo_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
if kv_layout == "HND":
kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
else:
kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
if not contiguous_kv:
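        # Insert an extra size-2 dimension before every original dimension (except the
        # first), allocate, then slice index 1 from each inserted dimension below; the
        # resulting view keeps the logical KV shape but is strided (non-contiguous).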
tmp = [kv_shape[0]]
for v in kv_shape[1:]:
tmp.append(2)
tmp.append(v)
kv_shape = tmp
kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
kv_data = kv_data_fp32.half()
kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
# actual data is stored in non-contiguous memory
assert (
kv_data.stride(-4)
!= kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
)
else:
kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
kv_data = kv_data_fp32.half()
kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
kv_indices_cpu = torch.arange(0, total_num_pages).int()
kv_last_page_len_cpu = torch.full(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
)
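    # (kv_len - 1) % page_size + 1 is the number of valid tokens in each sequence's
    # last page (equal to page_size when kv_len divides evenly).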
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
if not use_cuda_graph:
q_indptr_gpu = q_indptr_cpu.to(0)
kv_indptr_gpu = kv_indptr_cpu.to(0)
kv_indices_gpu = kv_indices_cpu.to(0)
kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr_gpu,
kv_indptr_gpu,
kv_indices_gpu,
kv_last_page_len_gpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
# test with pre-allocated output
o_buffer = torch.empty_like(o)
wrapper.run(q, kv_data, out=o_buffer)
torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
else:
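        # CUDA graph path: allocate the indptr/indices/last-page-len buffers up front and
        # register them with the wrapper so the captured graph always reads from the same
        # device addresses.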
q_indptr_buffer = torch.empty(
batch_size + 1, device="cuda:0", dtype=torch.int32
)
kv_indptr_buffer = torch.empty(
batch_size + 1, device="cuda:0", dtype=torch.int32
)
kv_indices_buffer = torch.empty(
total_num_pages, device="cuda:0", dtype=torch.int32
)
kv_last_page_len_buffer = torch.empty(
batch_size, device="cuda:0", dtype=torch.int32
)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_cuda_graph=True,
qo_indptr_buf=q_indptr_buffer,
paged_kv_indptr_buf=kv_indptr_buffer,
paged_kv_indices_buf=kv_indices_buffer,
paged_kv_last_page_len_buf=kv_last_page_len_buffer,
)
q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len
kv_indptr_warmup = torch.arange(0, batch_size + 1).int()
kv_indices_warmup = torch.arange(0, batch_size).int()
kv_last_page_len_warmup = torch.full(
(batch_size,), page_size, dtype=torch.int32
)
wrapper.plan(
q_indptr_warmup,
kv_indptr_warmup,
kv_indices_warmup,
kv_last_page_len_warmup,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
torch.cuda.current_stream().wait_stream(s)
# capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
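        # After capture, re-plan with the actual metadata (host tensors) and replay the
        # captured graph; the run recorded in the graph picks up the refreshed buffers.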
wrapper.plan(
q_indptr_cpu,
kv_indptr_cpu,
kv_indices_cpu,
kv_last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
g.replay()
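    # Reference check: gather each request's K/V pages into contiguous tensors and
    # compare against the single-request prefill kernel.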
for i in range(batch_size):
perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
ki = torch.cat(
[
kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
kv_data_fp32[
kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]
]
if kv_layout == "HND"
else kv_data_fp32[
kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :
]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
).half()
vi = torch.cat(
[
kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
kv_data_fp32[
kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]
]
if kv_layout == "HND"
else kv_data_fp32[
kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :
]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
).half()
o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
qi,
ki,
vi,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("use_cuda_graph", [False, True])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_tuple_paged_kv_cache(
batch_size,
kv_len,
qo_len,
page_size,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
kv_layout,
pos_encoding_mode,
use_cuda_graph,
logits_soft_cap,
return_lse,
contiguous_kv,
):
if qo_len > kv_len and causal:
pytest.skip("qo_len > kv_len and causal is not supported")
q = torch.randn(
batch_size * qo_len,
num_qo_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
if kv_layout == "HND":
kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim]
else:
kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim]
if not contiguous_kv:
tmp = [kv_shape[0]]
for v in kv_shape[1:]:
tmp.append(2)
tmp.append(v)
kv_shape = tmp
kv_data_fp32 = [
torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
for _ in range(2)
]
kv_data = [kv_data_fp32[i].half() for i in range(2)]
for i in range(2):
kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :]
kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :]
# actual data is stored in non-contiguous memory
assert (
kv_data[i].stride(-4)
!= kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1]
)
else:
kv_data_fp32 = [
torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
for _ in range(2)
]
kv_data = [kv_data_fp32[i].half() for i in range(2)]
kv_data = tuple(kv_data)
kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
kv_indices_cpu = torch.arange(0, total_num_pages).int()
kv_last_page_len_cpu = torch.full(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
)
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
if not use_cuda_graph:
q_indptr_gpu = q_indptr_cpu.to(0)
kv_indptr_gpu = kv_indptr_cpu.to(0)
kv_indices_gpu = kv_indices_cpu.to(0)
kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr_gpu,
kv_indptr_gpu,
kv_indices_gpu,
kv_last_page_len_gpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
else:
q_indptr_buffer = torch.empty(
batch_size + 1, device="cuda:0", dtype=torch.int32
)
kv_indptr_buffer = torch.empty(
batch_size + 1, device="cuda:0", dtype=torch.int32
)
kv_indices_buffer = torch.empty(
total_num_pages, device="cuda:0", dtype=torch.int32
)
kv_last_page_len_buffer = torch.empty(
batch_size, device="cuda:0", dtype=torch.int32
)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_cuda_graph=True,
qo_indptr_buf=q_indptr_buffer,
paged_kv_indptr_buf=kv_indptr_buffer,
paged_kv_indices_buf=kv_indices_buffer,
paged_kv_last_page_len_buf=kv_last_page_len_buffer,
)
q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len
kv_indptr_warmup = torch.arange(0, batch_size + 1).int()
kv_indices_warmup = torch.arange(0, batch_size).int()
kv_last_page_len_warmup = torch.full(
(batch_size,), page_size, dtype=torch.int32
)
wrapper.plan(
q_indptr_warmup,
kv_indptr_warmup,
kv_indices_warmup,
kv_last_page_len_warmup,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
# warmup
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
torch.cuda.current_stream().wait_stream(s)
# capture
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
if return_lse:
o, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o = wrapper.run(q, kv_data)
wrapper.plan(
q_indptr_cpu,
kv_indptr_cpu,
kv_indices_cpu,
kv_last_page_len_cpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
g.replay()
k_cache, v_cache = kv_data_fp32
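    # Reference check: same as in the previous test, but K and V come from the separate
    # tensors of the tuple-style paged cache.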
for i in range(batch_size):
perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
ki = torch.cat(
[
k_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
k_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]]
if kv_layout == "HND"
else k_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
).half()
vi = torch.cat(
[
v_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
v_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]]
if kv_layout == "HND"
else v_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
).half()
o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
qi,
ki,
vi,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_paged_kv_cache_custom_mask(
batch_size,
kv_len,
qo_len,
page_size,
num_kv_heads,
num_qo_heads,
head_dim,
kv_layout,
pos_encoding_mode,
logits_soft_cap,
return_lse,
contiguous_kv,
):
q = torch.randn(
batch_size * qo_len,
num_qo_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
q_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
)
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
if kv_layout == "HND":
kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
else:
kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
if not contiguous_kv:
tmp = [kv_shape[0]]
for v in kv_shape[1:]:
tmp.append(2)
tmp.append(v)
kv_shape = tmp
kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0")
kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
# actual data is stored in non-contiguous memory
assert (
kv_data.stride(-4)
!= kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
)
else:
kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0")
kv_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
* num_pages_per_seq
)
kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
kv_last_page_len = torch.full(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
)
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
custom_mask = torch.tril(
torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"),
diagonal=(kv_len - qo_len),
).reshape(-1)
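    # The mask above is exactly the causal mask (lower-triangular with offset
    # kv_len - qo_len), so the custom-mask run below should match causal=True.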
# use custom mask
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
custom_mask=custom_mask,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o_custom, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o_custom = wrapper.run(q, kv_data)
# use causal
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o_causal, _ = wrapper.run(q, kv_data, return_lse=True)
else:
o_causal = wrapper.run(q, kv_data)
torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
def test_batch_prefill_with_ragged_kv_cache(
batch_size,
kv_len,
qo_len,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
if qo_len > kv_len and causal:
pytest.skip("qo_len > kv_len and causal is not supported")
kv_layout = "NHD"
q = torch.randn(
batch_size * qo_len,
num_qo_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
q_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
)
k = torch.randn(
batch_size * kv_len,
num_kv_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
v = torch.randn(
batch_size * kv_len,
num_kv_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
kv_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
)
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o, _ = wrapper.run(q, k, v, return_lse=True)
else:
o = wrapper.run(q, k, v)
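    # Reference check: slice each request's ragged K/V and compare against the
    # single-request prefill kernel.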
for i in range(batch_size):
o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
q[q_indptr[i] : q_indptr[i + 1]],
k[kv_indptr[i] : kv_indptr[i + 1]],
v[kv_indptr[i] : kv_indptr[i + 1]],
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
o_i = o[q_indptr[i] : q_indptr[i + 1]]
torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
batch_size,
kv_len,
qo_len,
num_kv_heads,
num_qo_heads,
head_dim,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
kv_layout = "NHD"
q = torch.randn(
batch_size * qo_len,
num_qo_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
q_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
)
k = torch.randn(
batch_size * kv_len,
num_kv_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
v = torch.randn(
batch_size * kv_len,
num_kv_heads,
head_dim,
device="cuda:0",
dtype=torch.float16,
)
kv_indptr = (
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
)
workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout
)
custom_mask = torch.tril(
torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"),
diagonal=(kv_len - qo_len),
).reshape(-1)
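    # As in the paged-KV variant, the custom mask equals the causal mask, so both runs
    # below should produce identical outputs.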
# use custom mask
wrapper.plan(
q_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
custom_mask=custom_mask,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o_custom, _ = wrapper.run(q, k, v, return_lse=True)
else:
o_custom = wrapper.run(q, k, v)
# use causal
wrapper.plan(
q_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
causal=True,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
)
if return_lse:
o_causal, _ = wrapper.run(q, k, v, return_lse=True)
else:
o_causal = wrapper.run(q, k, v)
torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
"kv_len, qo_len, prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr",
[
(54, 37, 17, list(range(17)) + list(range(19)) + [0], 100, [18]),
(97, 81, 16, list(range(80)) + [0], 97, [79]),
],
)
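# Multi-item scoring inputs: prefix_len_ptr is the shared-prefix length,
# token_pos_in_items_ptr gives each token's position within its candidate item, with 0
# marking item delimiters (see the reference mask builder below); token_pos_in_items_len
# looks like the padded length of that array, and max_item_len_ptr the longest item
# length per request.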
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_paged_kv_cache_multi_item_scoring(
batch_size,
kv_len,
qo_len,
prefix_len_ptr,
token_pos_in_items_ptr,
token_pos_in_items_len,
max_item_len_ptr,
page_size,
num_kv_heads,
num_qo_heads,
head_dim,
causal,
kv_layout,
pos_encoding_mode,
logits_soft_cap,
return_lse,
):
q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
num_pages_per_seq = (kv_len + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
kv_data = (
torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half()
if kv_layout == "HND"
else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim)
.to(0)
.half()
)
kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
kv_indices_cpu = torch.arange(0, total_num_pages).int()
kv_last_page_len_cpu = torch.full(
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
q_indptr_gpu = q_indptr_cpu.to(0)
kv_indptr_gpu = kv_indptr_cpu.to(0)
kv_indices_gpu = kv_indices_cpu.to(0)
kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
)
wrapper.plan(
q_indptr_gpu,
kv_indptr_gpu,
kv_indices_gpu,
kv_last_page_len_gpu,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0),
token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr)
.to(dtype=torch.uint16)
.to(0),
token_pos_in_items_len=torch.tensor(token_pos_in_items_len)
.to(dtype=torch.uint32)
.to(0),
max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0),
)
if return_lse:
o, _ = wrapper.run_return_lse(q, kv_data)
else:
o = wrapper.run(q, kv_data)
for i in range(batch_size):
perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
ki = torch.cat(
[
kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]]
if kv_layout == "HND"
else kv_data[
kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :
]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
)
vi = torch.cat(
[
kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
.permute(*perm_dims)
.reshape(-1, num_kv_heads, head_dim),
(
kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]]
if kv_layout == "HND"
else kv_data[
kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :
]
)
.permute(*perm_dims_last)
.reshape(-1, num_kv_heads, head_dim),
],
dim=0,
)
def create_2D_multi_item_mask_dense(
is_delimiter, sliding_window_size=-1, prefix_cache_len=None
):
            # Build the dense custom attention mask used for multi-item scoring.
            #
            # Note: the sliding-window variant assumes candidate_i_size < sliding_window_size < prefix_size.
            # Args:
            #   is_delimiter: boolean tensor marking delimiter positions, used to build the
            #       custom attention mask for multi-item scoring; currently assumes qo_len
            #       and kv_len are equal and batch size is 1.
            #   sliding_window_size: window size for sliding-window attention; -1 disables it.
            #   prefix_cache_len: number of cached prefix tokens to prepend to the mask, if any.
delimiter_idx = is_delimiter.nonzero(as_tuple=True)[0]
if len(delimiter_idx) == 0:
return None
else:
first_delimiter_pos = delimiter_idx[0]
seq_len = len(is_delimiter)
pos = torch.arange(seq_len, device=is_delimiter.device)
group_ids = torch.cumsum(is_delimiter, 0)
# Get mask for within-group causal attention
within_group_causal = (group_ids.unsqueeze(1) == group_ids.unsqueeze(0)) & (
pos.unsqueeze(0) <= pos.unsqueeze(1)
)
# Combine all conditions
attention_mask = (
(
within_group_causal
| (
(pos >= first_delimiter_pos).unsqueeze(1)
& (pos < first_delimiter_pos).unsqueeze(0)
) # Prefix attention
)
& ~is_delimiter.unsqueeze(0)
& ~is_delimiter.unsqueeze(1)
) # No delimiter attention
if sliding_window_size > 0 and sliding_window_size < len(is_delimiter):
# Calculate how many positions from right of prefix each token can attend to
group_size = torch.sum(
within_group_causal & ~is_delimiter.unsqueeze(0), dim=1
)
# For prefix: after sliding_window_size position, can see window_size tokens
# For candidate items: can see (sliding_window_size - group_size) tokens from prefix end
prefix_window = torch.where(
pos >= first_delimiter_pos,
sliding_window_size - group_size,
torch.where(
pos < sliding_window_size,
first_delimiter_pos,
sliding_window_size,
),
)
# Starting index of attention window relative to token position for candidate item/group
prefix_start = first_delimiter_pos - prefix_window.unsqueeze(1)
attention_mask = attention_mask & (pos >= prefix_start)
if prefix_cache_len:
patch = torch.ones(
seq_len,
prefix_cache_len,
device=is_delimiter.device,
dtype=torch.bool,
)
attention_mask = torch.concat([patch, attention_mask], dim=1)
return attention_mask.unsqueeze(0).reshape(-1)
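        # Rebuild the multi-item mask densely and check the batched kernel against the
        # single-request kernel run with that mask.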
custom_mask = create_2D_multi_item_mask_dense(
is_delimiter=torch.tensor(token_pos_in_items_ptr).to(0) == 0,
sliding_window_size=-1,
prefix_cache_len=prefix_len_ptr,
)
o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
qi,
ki,
vi,
causal=causal,
pos_encoding_mode=pos_encoding_mode,
logits_soft_cap=logits_soft_cap,
custom_mask=custom_mask,
)
o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy()
o_ref_i_np = o_ref_i.cpu().numpy()
numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
test_batch_prefill_with_paged_kv_cache(
12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
)
test_batch_prefill_with_tuple_paged_kv_cache(
12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
)
test_batch_prefill_with_paged_kv_cache(
12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False, True
)
test_batch_prefill_with_paged_kv_cache_custom_mask(
1, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False, True
)
test_batch_prefill_with_ragged_kv_cache(
12, 54, 37, 8, 8, 128, True, "NONE", 0.0, False
)
test_batch_prefill_with_ragged_kv_cache_custom_mask(
1, 137, 137, 8, 8, 128, "NONE", 0.0, False
)