123 lines
3.7 KiB
Python
123 lines
3.7 KiB
Python
import math
|
|
from typing import List
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import flashinfer
|
|
|
|
# Feature dimensions of the two MLA cache streams used throughout this file:
# ckv_* tensors are (num_tokens, CKV_DIM), kpe_* tensors are (num_tokens, KPE_DIM).
# (Presumably "compressed KV" and "key positional embedding" — the standard
# MLA split; confirm against the flashinfer API docs.)
CKV_DIM = 512
KPE_DIM = 64
|
|
|
|
|
|
def calculate_last_page_len(kv_len: List[int], page_size: int) -> List[int]:
    """Return the number of valid entries in each request's final page.

    A request of length ``L`` fills its last page with ``L % page_size``
    entries, except when ``L`` is an exact multiple of ``page_size`` — then
    the last page is completely full and holds ``page_size`` entries.
    """
    # ``n % page_size or page_size`` maps a zero remainder to a full page.
    # (Renamed the loop variable: the original shadowed the builtin ``len``.)
    return [n % page_size or page_size for n in kv_len]
|
|
|
|
|
|
# Per-request KV sequence lengths to exercise: a single short request, a
# single long request, and multi-request batches of increasing size with a
# mix of short and long sequences.
kv_len_configs = [
    [45],
    [4096],
    [45, 8, 25],
    [45, 8, 25, 22],
    [45, 8, 25, 22, 400],
    [45, 8, 25, 22, 100],
]
|
|
|
|
|
|
@pytest.mark.parametrize("kv_len", kv_len_configs)
@pytest.mark.parametrize("page_size", [1, 16, 64])
def test_append_mla_paged_kv_cache(kv_len, page_size):
    """Append ragged MLA (ckv/kpe) tokens into a paged cache and verify.

    Builds page-table metadata for a batch of requests with lengths
    ``kv_len``, appends the flat per-token ckv/kpe tensors via
    ``flashinfer.append_paged_mla_kv_cache``, then checks that:
      * each request's tokens landed contiguously in its pages, and
      * the unused tail of each request's last page is untouched (zero).
    """
    nnz_kv = sum(kv_len)
    ckv_append = torch.randn(nnz_kv, CKV_DIM, dtype=torch.float16, device="cuda:0")
    kpe_append = torch.randn(nnz_kv, KPE_DIM, dtype=torch.float16, device="cuda:0")
    num_pages_per_req = torch.tensor(
        # Renamed the loop variable: the original shadowed the builtin ``len``.
        [math.ceil(seq_len / page_size) for seq_len in kv_len],
        dtype=torch.int32,
        device="cuda:0",
    )
    kv_append_length = torch.tensor(kv_len, dtype=torch.int32, device="cuda:0")
    # Exclusive prefix sum over token counts: kv_append_indptr[i] is the row
    # in ckv_append/kpe_append where request i's tokens start.
    kv_append_indptr = torch.cat(
        [
            # Construct the leading zero as int32 on-device directly instead
            # of float-zeros + .int() + .to(0).
            torch.zeros(1, dtype=torch.int32, device="cuda:0"),
            torch.cumsum(kv_append_length, dim=0),
        ]
    ).int()

    # Convert to a Python int once; reused as a size and a range bound below.
    max_num_pages = int(num_pages_per_req.sum())
    ckv_cache = torch.zeros(
        max_num_pages, page_size, CKV_DIM, dtype=torch.float16, device="cuda:0"
    )
    kpe_cache = torch.zeros(
        max_num_pages, page_size, KPE_DIM, dtype=torch.float16, device="cuda:0"
    )
    kv_page_indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda:0"),
            torch.cumsum(num_pages_per_req, dim=0),
        ]
    ).int()
    # Pages are laid out contiguously, one run per request, so the page table
    # is simply 0..max_num_pages-1.
    kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
    kv_last_page_len = torch.tensor(
        calculate_last_page_len(kv_len, page_size), dtype=torch.int32, device="cuda:0"
    )
    batch_indices, positions = flashinfer.get_batch_indices_positions(
        kv_append_indptr,
        flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size),
        nnz_kv,
    )
    flashinfer.append_paged_mla_kv_cache(
        ckv_append,
        kpe_append,
        batch_indices,
        positions,
        ckv_cache,
        kpe_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_last_page_len,
    )

    # Flatten (page, slot, dim) -> (row, dim) so request i occupies the
    # contiguous row range [page_start * page_size, ...).
    ckv_cache = ckv_cache.view(-1, CKV_DIM)
    kpe_cache = kpe_cache.view(-1, KPE_DIM)

    acc_kv_len = 0          # rows consumed in the flat append tensors
    acc_padding_kv_len = 0  # rows consumed in the flat (padded) caches
    for i, seq_len in enumerate(kv_len):
        padded_len = int(num_pages_per_req[i]) * page_size
        # Appended tokens must match the source tensors.
        # (torch.allclose == torch.all(torch.isclose(...)) with defaults.)
        assert torch.allclose(
            kpe_append[acc_kv_len : acc_kv_len + seq_len],
            kpe_cache[acc_padding_kv_len : acc_padding_kv_len + seq_len],
        )
        assert torch.allclose(
            ckv_append[acc_kv_len : acc_kv_len + seq_len],
            ckv_cache[acc_padding_kv_len : acc_padding_kv_len + seq_len],
        )
        # The unused tail of the request's last page must remain zero,
        # proving the kernel did not write out of bounds.
        assert bool(
            torch.all(
                ckv_cache[acc_padding_kv_len + seq_len : acc_padding_kv_len + padded_len]
                == 0
            )
        )
        assert bool(
            torch.all(
                kpe_cache[acc_padding_kv_len + seq_len : acc_padding_kv_len + padded_len]
                == 0
            )
        )
        acc_kv_len += seq_len
        acc_padding_kv_len += padded_len
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-run a fixed set of (kv_len, page_size) combinations directly,
    # bypassing pytest's parametrization.
    smoke_cases = [
        ([45], 64),
        ([4096], 64),
        ([45, 8, 25], 64),
        ([45, 8, 25, 22], 64),
        ([45, 8, 25, 22, 400], 128),
        ([45, 8, 25, 22, 100], 16),
    ]
    for lens, pages in smoke_cases:
        test_append_mla_paged_kv_cache(lens, pages)
|