# sglang_v0.5.2/flashinfer_0.3.1/tests/test_page.py
import pytest
import torch
import flashinfer
@pytest.mark.parametrize("contiguous", [True, False])
def test_append_paged_kv_cache(contiguous):
    """Append ragged K/V for 4 requests into a paged KV cache and verify it.

    Builds a batch of 4 requests with append lengths 45 + 8 + 25 + 22 = 100
    tokens, scatters them into the first 8 pages of a paged cache via
    ``flashinfer.append_paged_kv_cache``, then gathers the pages back and
    checks they match the source tensors exactly.

    Args:
        contiguous: if True, K and V are independent contiguous tensors;
            if False, they are strided views into one interleaved tensor,
            exercising the non-contiguous input path of the kernel.
    """
    nnz_kv = 100
    num_kv_heads = 32
    head_dim = 128
    if contiguous:
        k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
        v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(0)
    else:
        # K/V interleaved on dim 1 -> slicing yields non-contiguous views.
        kv_append = torch.randn(nnz_kv, 2, num_kv_heads, head_dim).half().to(0)
        k_append = kv_append[:, 0]
        v_append = kv_append[:, 1]
    # 45 + 8 + 25 + 22 = nnz_kv
    kv_append_length = torch.tensor([45, 8, 25, 22], dtype=torch.int32, device="cuda:0")
    kv_append_indptr = torch.cat(
        [torch.zeros(1).int().to(0), torch.cumsum(kv_append_length, dim=0)]
    ).int()
    max_num_pages = 1000
    page_size = 16
    # Layout: (page, k_or_v, slot_in_page, head, dim); randn-filled so that a
    # missed write would be detected by the comparison below.
    paged_kv_cache = (
        torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0)
    )
    num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0")
    kv_page_indptr = torch.cat(
        [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)]
    ).int()
    # use first 8 pages in the paged-kv
    kv_page_indices = torch.arange(8, dtype=torch.int32, device="cuda:0")
    # 45 = (3 - 1) * 16 + 13
    # 8 = (1 - 1) * 16 + 8
    # 25 = (2 - 1) * 16 + 9
    # 22 = (2 - 1) * 16 + 6
    kv_last_page_len = torch.tensor([13, 8, 9, 6], dtype=torch.int32, device="cuda:0")
    batch_indices, positions = flashinfer.get_batch_indices_positions(
        kv_append_indptr,
        flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size),
        nnz_kv,
    )
    flashinfer.append_paged_kv_cache(
        k_append,
        v_append,
        batch_indices,
        positions,
        paged_kv_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_last_page_len,
    )
    # Verify: each request's appended tokens must occupy positions
    # 0..len-1 of its assigned pages (here every sequence length equals its
    # append length, so the append starts at position 0).
    batch_size = kv_append_length.numel()
    for req in range(batch_size):
        begin = int(kv_append_indptr[req])
        end = int(kv_append_indptr[req + 1])
        req_len = end - begin
        pages = kv_page_indices[kv_page_indptr[req] : kv_page_indptr[req + 1]].long()
        # Flatten this request's pages to (num_pages * page_size, heads, dim).
        k_cached = paged_kv_cache[pages, 0].reshape(-1, num_kv_heads, head_dim)
        v_cached = paged_kv_cache[pages, 1].reshape(-1, num_kv_heads, head_dim)
        torch.testing.assert_close(k_cached[:req_len], k_append[begin:end])
        torch.testing.assert_close(v_cached[:req_len], v_append[begin:end])