58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
import pytest
|
|
import torch
|
|
|
|
import flashinfer
|
|
|
|
|
|
@pytest.mark.parametrize("contiguous", [True, False])
def test_append_paged_kv_cache(contiguous):
    """Append K/V for a 4-request batch into a paged KV cache and verify placement.

    Every request appends its entire sequence (append length == sequence
    length), so token ``j`` of request ``b`` must land at slot ``j`` of the
    pages assigned to request ``b``.  With ``contiguous=False`` the K/V inputs
    are strided views carved out of one interleaved ``(nnz, 2, heads, dim)``
    tensor, exercising the non-contiguous input path.
    """
    device = "cuda:0"
    nnz_kv = 100
    num_kv_heads = 32
    head_dim = 128

    if contiguous:
        k_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(device)
        v_append = torch.randn(nnz_kv, num_kv_heads, head_dim).half().to(device)
    else:
        kv_append = torch.randn(nnz_kv, 2, num_kv_heads, head_dim).half().to(device)
        # Non-contiguous views: K and V are interleaved along dim 1.
        k_append = kv_append[:, 0]
        v_append = kv_append[:, 1]

    # 45 + 8 + 25 + 22 = nnz_kv
    kv_append_length = torch.tensor(
        [45, 8, 25, 22], dtype=torch.int32, device=device
    )
    # Cast back to int32: cumsum can promote integer inputs to int64.
    kv_append_indptr = torch.cat(
        [torch.zeros(1).int().to(device), torch.cumsum(kv_append_length, dim=0)]
    ).int()

    max_num_pages = 1000
    page_size = 16
    # Single 5-D cache tensor: (pages, 2 [K, V], page_size, heads, dim).
    paged_kv_cache = (
        torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim)
        .half()
        .to(device)
    )
    num_pages_per_req = torch.tensor(
        [3, 1, 2, 2], dtype=torch.int32, device=device
    )
    kv_page_indptr = torch.cat(
        [torch.zeros(1).int().to(device), torch.cumsum(num_pages_per_req, dim=0)]
    ).int()
    # use first 8 pages in the paged-kv
    kv_page_indices = torch.arange(8, dtype=torch.int32, device=device)
    # 45 = (3 - 1) * 16 + 13
    # 8 = (1 - 1) * 16 + 8
    # 25 = (2 - 1) * 16 + 9
    # 22 = (2 - 1) * 16 + 6
    kv_last_page_len = torch.tensor(
        [13, 8, 9, 6], dtype=torch.int32, device=device
    )

    seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size)
    # Each request appends its whole sequence, so append positions start at 0;
    # the read-back check below relies on this.
    assert seq_lens.tolist() == kv_append_length.tolist()

    batch_indices, positions = flashinfer.get_batch_indices_positions(
        kv_append_indptr,
        seq_lens,
        nnz_kv,
    )

    flashinfer.append_paged_kv_cache(
        k_append,
        v_append,
        batch_indices,
        positions,
        paged_kv_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_last_page_len,
    )

    # Read back: with the default NHD layout, index 0 of dim 1 is K and index
    # 1 is V, each of shape (max_num_pages, page_size, heads, dim).
    k_cache = paged_kv_cache[:, 0]
    v_cache = paged_kv_cache[:, 1]
    append_bounds = kv_append_indptr.tolist()
    page_bounds = kv_page_indptr.tolist()
    for req in range(len(append_bounds) - 1):
        start, end = append_bounds[req], append_bounds[req + 1]
        num_tokens = end - start
        pages = kv_page_indices[page_bounds[req] : page_bounds[req + 1]].long()
        # Flatten this request's pages into one token axis; its tokens occupy
        # the first `num_tokens` slots in order.
        k_written = k_cache[pages].reshape(-1, num_kv_heads, head_dim)[:num_tokens]
        v_written = v_cache[pages].reshape(-1, num_kv_heads, head_dim)[:num_tokens]
        # The append is a pure copy, so demand bit-exact equality.
        torch.testing.assert_close(k_written, k_append[start:end], rtol=0, atol=0)
        torch.testing.assert_close(v_written, v_append[start:end], rtol=0, atol=0)