96 lines
3.1 KiB
Python
96 lines
3.1 KiB
Python
"""
|
|
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
|
|
"""
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import flashinfer
|
|
from flashinfer.testing.utils import bench_gpu_time
|
|
|
|
# Attention geometry held constant across the whole benchmark sweep; only
# batch size, sequence length and dtypes are varied by the __main__ loop.
head_dim = 128  # width of each attention head
num_qo_heads = 32  # query/output heads passed to the decode wrapper
num_kv_heads = 4  # key/value heads stored in the paged cache
page_block_size = 16  # tokens held by one KV-cache page
|
|
|
|
|
|
def bench_batch_decode(
    batch_size: int,
    seq_len: int,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_block_size: int,
    q_dtype: torch.dtype,
    kv_dtype: torch.dtype,
) -> None:
    """Benchmark paged-KV batch decode and print its latency and bandwidth.

    Builds a paged KV cache in which every request holds ``seq_len`` tokens,
    plans a ``flashinfer.BatchDecodeWithPagedKVCacheWrapper`` (``"NHD"``
    layout, ``use_tensor_cores=True``) on ``cuda:0`` and times
    ``wrapper.run(q, kv_data)`` with ``bench_gpu_time``. The median of the
    returned measurements is reported.

    Args:
        batch_size: Number of requests in the batch; must be positive.
        seq_len: KV length of every request, in tokens; must be positive.
        num_qo_heads: Number of query/output heads.
        num_kv_heads: Number of key/value heads.
        head_dim: Size of each head.
        page_block_size: Tokens per KV-cache page; must be positive.
        q_dtype: dtype of the query tensor.
        kv_dtype: dtype the KV cache is converted to.

    Raises:
        ValueError: If ``batch_size``, ``seq_len`` or ``page_block_size`` is
            not positive (the last-page-length formula below is only valid
            for non-empty sequences).
    """
    if min(batch_size, seq_len, page_block_size) <= 0:
        raise ValueError(
            "batch_size, seq_len and page_block_size must be positive, got "
            f"{batch_size}, {seq_len}, {page_block_size}"
        )

    device = torch.device("cuda:0")

    np.random.seed(42)
    # Every random input below comes from torch's RNG, so it has to be seeded
    # as well for the benchmark inputs to be reproducible.
    torch.manual_seed(42)

    # Page-table metadata: pages per request, their prefix sum, and the number
    # of valid tokens in each request's final page.
    seq_lens = torch.full((batch_size,), seq_len)
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
    kv_indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
    ).int()
    last_page_len = (seq_lens - (seq_lens_blocks - 1) * page_block_size).int()
    num_blocks = kv_indptr[-1].item()

    q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device=device)
    # Sampled in the default dtype first and converted afterwards, so the same
    # path works for every kv_dtype in the sweep.
    # NOTE(review): the axis of size 2 presumably holds K and V - confirm
    # against the wrapper's "NHD" layout documentation.
    kv_data = torch.randn(
        num_blocks, 2, page_block_size, num_kv_heads, head_dim, device=device
    ).to(kv_dtype)

    # 128 MiB of scratch space handed to the wrapper.
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device=device
    )
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout="NHD", use_tensor_cores=True
    )
    # Pages are assigned contiguously: request i owns pages
    # kv_indptr[i] .. kv_indptr[i + 1] - 1.
    wrapper.plan(
        kv_indptr.to(device),
        torch.arange(num_blocks).int().to(device),
        last_page_len.to(device),
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )

    measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
    ms = np.median(measurements)

    # Bytes counted as traffic: the query plus the whole KV cache.
    io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    # `ms` is in milliseconds: convert to seconds, then bytes/s to 1e9 bytes/s,
    # so the figure really is GB/s (dividing by 1024 * 1024 yields MiB per ms).
    gb_per_s = io / (ms * 1e-3) / 1e9

    print(
        f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}"
    )
    print(f"execution time: {ms}ms")
    print(f"memory bandwidth: {gb_per_s:.2f} GB/s")
|
|
|
|
|
|
if __name__ == "__main__":
    # Axes of the sweep, listed from the slowest-varying to the fastest.
    q_dtypes = (torch.bfloat16,)
    kv_dtypes = (torch.bfloat16, torch.float8_e4m3fn)
    batch_sizes = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512)
    seq_lens = (512, 1024, 2048, 4096, 8192, 16384)

    # Lazily enumerate the full cross product in the same order that four
    # nested loops would visit it (seq_len changes fastest).
    sweep = (
        (q_dt, kv_dt, bs, sl)
        for q_dt in q_dtypes
        for kv_dt in kv_dtypes
        for bs in batch_sizes
        for sl in seq_lens
    )

    for q_dtype, kv_dtype, batch_size, seq_len in sweep:
        bench_batch_decode(
            batch_size,
            seq_len,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_block_size,
            q_dtype,
            kv_dtype,
        )
|