# sglang_v0.5.2/flashinfer_0.3.1/tests/test_mla_page.py
import math
from typing import List

import pytest
import torch

import flashinfer

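# MLA paged-KV dimensions: the compressed KV latent (ckv) and the decoupled
# rotary key (kpe) are cached separately; 512 + 64 matches DeepSeek-style MLA.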
CKV_DIM = 512
KPE_DIM = 64
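

# Number of valid entries in the last page of each request; a full page when
# the sequence length is an exact multiple of page_size.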
def calculate_last_page_len(kv_len: List[int], page_size: int):
    return [
        seq_len % page_size if seq_len % page_size != 0 else page_size
        for seq_len in kv_len
    ]

kv_len_configs = [
    [45],
    [4096],
    [45, 8, 25],
    [45, 8, 25, 22],
    [45, 8, 25, 22, 400],
    [45, 8, 25, 22, 100],
]

@pytest.mark.parametrize("kv_len", kv_len_configs)
@pytest.mark.parametrize("page_size", [1, 16, 64])
def test_append_mla_paged_kv_cache(kv_len, page_size):
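    # Append nnz_kv token entries (ragged over the batch) into zero-initialized
    # paged ckv/kpe caches, then check both the written and the padding regions.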
    nnz_kv = sum(kv_len)
    ckv_append = torch.randn(nnz_kv, CKV_DIM, dtype=torch.float16, device="cuda:0")
    kpe_append = torch.randn(nnz_kv, KPE_DIM, dtype=torch.float16, device="cuda:0")
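    # Page-table bookkeeping: pages per request and a ragged indptr over the
    # appended entries.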
    num_pages_per_req = torch.tensor(
        [math.ceil(seq_len / page_size) for seq_len in kv_len],
        dtype=torch.int32,
        device="cuda:0",
    )
    kv_append_length = torch.tensor(kv_len, dtype=torch.int32, device="cuda:0")
    kv_append_indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda:0"),
            torch.cumsum(kv_append_length, dim=0),
        ]
    ).int()
    max_num_pages = int(num_pages_per_req.sum().item())
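    # Zero-initialized paged caches; pages are assigned to requests
    # contiguously, so padding slots must stay zero after the append.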
    ckv_cache = torch.zeros(
        max_num_pages, page_size, CKV_DIM, dtype=torch.float16, device="cuda:0"
    )
    kpe_cache = torch.zeros(
        max_num_pages, page_size, KPE_DIM, dtype=torch.float16, device="cuda:0"
    )
    kv_page_indptr = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device="cuda:0"),
            torch.cumsum(num_pages_per_req, dim=0),
        ]
    ).int()
    kv_page_indices = torch.arange(max_num_pages, dtype=torch.int32, device="cuda:0")
    kv_last_page_len = torch.tensor(
        calculate_last_page_len(kv_len, page_size), dtype=torch.int32, device="cuda:0"
    )
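    # Map each ragged append entry to its (request index, position-in-sequence)
    # pair; sequence lengths are recovered from the page table via get_seq_lens.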
    batch_indices, positions = flashinfer.get_batch_indices_positions(
        kv_append_indptr,
        flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size),
        nnz_kv,
    )
    flashinfer.append_paged_mla_kv_cache(
        ckv_append,
        kpe_append,
        batch_indices,
        positions,
        ckv_cache,
        kpe_cache,
        kv_page_indices,
        kv_page_indptr,
        kv_last_page_len,
    )
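    # Flatten (num_pages, page_size, dim) -> (num_pages * page_size, dim) so a
    # request's entries occupy kv_len[i] contiguous rows at its page offset.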
    ckv_cache = ckv_cache.view(-1, CKV_DIM)
    kpe_cache = kpe_cache.view(-1, KPE_DIM)
    acc_kv_len = 0
    acc_padding_kv_len = 0
    for i in range(len(kv_len)):
        # The first kv_len[i] rows of request i's pages must match the
        # appended entries.
        assert torch.all(
            torch.isclose(
                kpe_append[acc_kv_len : acc_kv_len + kv_len[i]],
                kpe_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]],
            )
        )
        assert torch.all(
            torch.isclose(
                ckv_append[acc_kv_len : acc_kv_len + kv_len[i]],
                ckv_cache[acc_padding_kv_len : acc_padding_kv_len + kv_len[i]],
            )
        )
        # The remainder of the request's last page is padding and must remain
        # zero, i.e. the kernel must not write past the sequence length.
        assert bool(
            torch.all(
                ckv_cache[
                    acc_padding_kv_len
                    + kv_len[i] : acc_padding_kv_len
                    + num_pages_per_req[i] * page_size
                ]
                == 0
            )
        )
        assert bool(
            torch.all(
                kpe_cache[
                    acc_padding_kv_len
                    + kv_len[i] : acc_padding_kv_len
                    + num_pages_per_req[i] * page_size
                ]
                == 0
            )
        )
        acc_kv_len += kv_len[i]
        acc_padding_kv_len += num_pages_per_req[i] * page_size


if __name__ == "__main__":
    test_append_mla_paged_kv_cache([45], 64)
    test_append_mla_paged_kv_cache([4096], 64)
    test_append_mla_paged_kv_cache([45, 8, 25], 64)
    test_append_mla_paged_kv_cache([45, 8, 25, 22], 64)
    test_append_mla_paged_kv_cache([45, 8, 25, 22, 400], 128)
    test_append_mla_paged_kv_cache([45, 8, 25, 22, 100], 16)