sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_batch_attention.py


from __future__ import annotations

import itertools
from typing import List, Sequence, Tuple

import numpy as np
import pandas as pd
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time
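
# Microbenchmark comparing two flashinfer batch-attention schedulers on the same
# paged KV cache: the existing BatchPrefillWithPagedKVCacheWrapper (FA2 backend)
# versus the newer unified BatchAttention API. For each sequence-length
# configuration and each point in the parameter sweep, it reports latency, the
# Q/KV memory footprint, and the implied memory bandwidth.
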
def run_bench(
    kv_lens: Sequence[int],
    qo_lens: Sequence[int],
    *,
    page_block_size: int,
    num_kv_heads: int,
    num_qo_heads: int,
    head_dim: int,
    device: int = 0,
    causal: bool = True,
) -> Tuple[float, float, float, float, float]:
    seq_lens = torch.tensor(kv_lens, dtype=torch.int32)
    q_lens = torch.tensor(qo_lens, dtype=torch.int32)
    # Number of KV pages per request, rounding up to whole pages.
    seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
    # CSR-style offsets over query tokens and KV pages.
    q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
    kv_indptr = torch.cat(
        [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
    ).int()
    num_blocks = kv_indptr[-1].item()
    q = torch.rand(
        q_indptr[-1].item(), num_qo_heads, head_dim, dtype=torch.bfloat16, device=device
    )
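    # Paged KV cache in NHD layout: [num_blocks, 2 (K and V), page_size,
    # num_kv_heads, head_dim].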
    kv_data = torch.randn(
        num_blocks,
        2,
        page_block_size,
        num_kv_heads,
        head_dim,
        dtype=torch.bfloat16,
        device=device,
    )

    # Old scheduler: BatchPrefillWithPagedKVCacheWrapper on the FA2 backend.
    wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device),  # 128 MiB workspace
        kv_layout="NHD",
        backend="fa2",
    )
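    # Number of valid entries in each request's last (possibly partial) page.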
    last_page_len = (seq_lens - 1) % page_block_size + 1
    wrapper_old.plan(
        q_indptr.to(device),
        kv_indptr.to(device),
        torch.arange(num_blocks, dtype=torch.int32, device=device),
        last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_block_size,
        causal=causal,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
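    # bench_gpu_time returns per-iteration GPU latencies in milliseconds.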
    measurements_old = bench_gpu_time(lambda: wrapper_old.run(q, kv_data))
    ms_old = np.mean(measurements_old)

    # New scheduler: the unified BatchAttention API.
    wrapper = flashinfer.BatchAttention(kv_layout="NHD")
    wrapper.plan(
        q_indptr.to(device),
        kv_indptr.to(device),
        torch.arange(num_blocks, dtype=torch.int32, device=device),
        seq_lens.to(device),
        num_qo_heads,
        num_kv_heads,
        head_dim,  # head_dim_qk
        head_dim,  # head_dim_vo
        page_block_size,
        causal=causal,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    measurements_new = bench_gpu_time(lambda: wrapper.run(q, kv_data))
    ms_new = np.mean(measurements_new)

    # Treat attention as memory-bound: bytes moved = Q read once + KV read once.
    total_bytes = (
        q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    )
    mem_MB = total_bytes / 1024**2
    bw_old = total_bytes / (ms_old * 1e-3) / 1024**3  # GiB/s
    bw_new = total_bytes / (ms_new * 1e-3) / 1024**3  # GiB/s
    return ms_old, ms_new, mem_MB, bw_old, bw_new
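
# Optional sanity check (a sketch, not part of the timed benchmark): the two
# schedulers should produce numerically close outputs for identical inputs.
# Depending on the flashinfer version, BatchAttention.run may return an
# (output, lse) tuple rather than a single tensor:
#
#     out_old = wrapper_old.run(q, kv_data)
#     out_new = wrapper.run(q, kv_data)
#     torch.testing.assert_close(out_old, out_new, rtol=2e-2, atol=2e-2)

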
def synthesize_seq_len_configs() -> List[List[Tuple[int, int]]]:
    cfgs: List[List[Tuple[int, int]]] = [
        [(8192, 1)] * 128,  # decode-only
        [(4096, 128)] * 4,  # prefill-only
        [(600, 1)] * 122 + [(10_000, 17)] * 8,  # hybrid
        [(8192, 1)] * 127 * 2 + [(8192, 4096)] * 1,  # hybrid (chunked-prefill)
    ]

    def _rand_case(bsz: int, lo: int, hi: int) -> List[Tuple[int, int]]:
        # Every `stride`-th request is a chunked prefill consuming stride + 1
        # query tokens; the rest are single-token decodes whose KV length is
        # cut to 5% of the sampled value.
        stride, sparsity = 16, 0.05
        full = np.random.randint(lo, hi, size=bsz)
        out = []
        for i, kv_len in enumerate(full):
            if i % stride == 0:
                out.append((kv_len, stride + 1))
            else:
                out.append((int(kv_len * sparsity), 1))
        return out

    cfgs.append(_rand_case(256, 1000, 8192))
    cfgs.append(_rand_case(128, 2000, 16_000))
    return cfgs


def main() -> None:
    np.random.seed(42)
    torch.random.manual_seed(42)

    seq_len_cfgs = synthesize_seq_len_configs()
    sweep = {
        "page_block_size": (1, 8, 16),
        "head_dim": (64, 128),
        "num_kv_heads": (4,),
        "num_qo_heads": (28,),
    }
    records = []
    for cfg_id, pairs in enumerate(seq_len_cfgs, start=1):
        kv_lens = [p[0] for p in pairs]
        qo_lens = [p[1] for p in pairs]

        for pbs, hd, n_kv, n_qo in itertools.product(
            sweep["page_block_size"],
            sweep["head_dim"],
            sweep["num_kv_heads"],
            sweep["num_qo_heads"],
        ):
            ms_old, ms_new, mem_MB, bw_old, bw_new = run_bench(
                kv_lens,
                qo_lens,
                page_block_size=pbs,
                num_kv_heads=n_kv,
                num_qo_heads=n_qo,
                head_dim=hd,
                device=0,
                causal=True,
            )
            records.extend(
                [
                    {
                        "scheduler": "BatchPrefillWithPagedKVCacheWrapper",
                        "seq_cfg_id": cfg_id,
                        "page_size": pbs,
                        "head_dim": hd,
                        "num_kv_heads": n_kv,
                        "num_qo_heads": n_qo,
                        "time_ms": ms_old,
                        "memory_MB": mem_MB,
                        "bandwidth_GB_s": bw_old,
                    },
                    {
                        "scheduler": "BatchAttentionWrapper",
                        "seq_cfg_id": cfg_id,
                        "page_size": pbs,
                        "head_dim": hd,
                        "num_kv_heads": n_kv,
                        "num_qo_heads": n_qo,
                        "time_ms": ms_new,
                        "memory_MB": mem_MB,
                        "bandwidth_GB_s": bw_new,
                    },
                ]
            )

    df = pd.DataFrame(
        records,
        columns=[
            "scheduler",
            "seq_cfg_id",
            "page_size",
            "head_dim",
            "num_kv_heads",
            "num_qo_heads",
            "time_ms",
            "memory_MB",
            "bandwidth_GB_s",
        ],
    )
    print(df.to_markdown(index=False, floatfmt=".2f"))
    df.to_csv("bench_batch_attention.csv", index=False)


if __name__ == "__main__":
    main()