# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_append_paged_kv_cache.py

import argparse
import dataclasses
from typing import Tuple

import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time
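
# Benchmark for flashinfer.append_paged_kv_cache: measures how fast freshly
# computed K/V tensors are appended into a paged KV cache across several model
# shapes and batch/sequence-length patterns, reporting per-layer latency and
# effective memory throughput.
#
# Example invocation (defaults shown):
#   python bench_append_paged_kv_cache.py --seqlen 5000 --batch-size 8 --page-len 16 --dtype float16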


@dataclasses.dataclass(kw_only=True)
class ModelConfig:
    num_kv_heads: int
    num_layers: int
    head_dim: int


def _make_70b(tp: int) -> ModelConfig:
    return ModelConfig(
        num_kv_heads=8 // tp,
        num_layers=80,
        head_dim=128,
    )
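

# Representative model shapes; the names and dimensions suggest Llama-family
# configs (1B/3B/8B, and 70B under tensor parallelism 8), though that mapping
# is a reading of the shorthand rather than something the file states.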
MODELS = {
    "l1b": ModelConfig(
        num_kv_heads=8,
        num_layers=16,
        head_dim=64,
    ),
    "l3b": ModelConfig(
        num_kv_heads=8,
        num_layers=28,
        head_dim=128,
    ),
    "l8b": ModelConfig(
        num_kv_heads=8,
        num_layers=32,
        head_dim=128,
    ),
    "l70b-tp8": _make_70b(8),
}


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seqlen", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--page-len", type=int, default=16)
    parser.add_argument("--dtype", type=str, default="float16")
    args = parser.parse_args()
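
    # Four batch shapes per model:
    #   1. pure decode: every request appends one token;
    #   2. one long prefill plus (batch_size - 1) single-token decodes;
    #   3. a single long prefill;
    #   4. the total length split evenly across the batch.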
    seqlens_ = [
        [1] * args.batch_size,
        [args.seqlen - args.batch_size + 1] + [1] * (args.batch_size - 1),
        [args.seqlen],
        [args.seqlen // args.batch_size] * args.batch_size,
    ]
    seqlen_strlen = max(len(str(seqlens)) for seqlens in seqlens_)
    page_len = int(args.page_len)
    dtype = getattr(torch, args.dtype)
    assert isinstance(dtype, torch.dtype)
    device = torch.device("cuda:0")
    total_pages = int(256000 / page_len)  # page pool backing ~256k tokens per layer

    torch.cuda.profiler.start()

    for model_name, model in MODELS.items():
        # One layer's paged cache: [num_pages, 2 (K/V), page_len, num_heads, head_dim]
        # in flashinfer's "NHD" layout.
        page_shape = (2, page_len, model.num_kv_heads, model.head_dim)
        layer_buf = torch.empty((total_pages,) + page_shape, dtype=dtype, device=device)
        for seqlens in seqlens_:
            # Fresh K/V entries for every new token across the batch.
            k = torch.rand(
                (sum(seqlens), model.num_kv_heads, model.head_dim),
                dtype=dtype,
                device=device,
            )
            v = torch.rand(
                (sum(seqlens), model.num_kv_heads, model.head_dim),
                dtype=dtype,
                device=device,
            )
            # Exclusive prefix sum over per-request token counts: x_indptr[i] is
            # where request i's new tokens start in the flattened k/v tensors.
            x_indptr = torch.tensor([0] + seqlens, device=device, dtype=torch.int32)
            x_indptr = torch.cumsum(x_indptr, 0, dtype=torch.int32)
            # Build the page table: each request gets ceil(seqlen / page_len)
            # consecutive page ids; kv_indptr delimits each request's pages and
            # kv_last_page_len is the number of valid slots in its final page.
            kv_indices_host = []
            kv_indptr_host = [0]
            next_page_id = 0
            for seqlen in seqlens:
                npages = (seqlen + page_len - 1) // page_len
                kv_indices_host.extend(range(next_page_id, next_page_id + npages))
                next_page_id += npages
                kv_indptr_host.append(len(kv_indices_host))
            kv_indices = torch.tensor(kv_indices_host, device=device, dtype=torch.int32)
            kv_indptr = torch.tensor(kv_indptr_host, device=device, dtype=torch.int32)
            kv_last_page_len = torch.tensor(
                [(seqlen - 1) % page_len + 1 for seqlen in seqlens],
                device=device,
                dtype=torch.int32,
            )
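
            # Worked example (hypothetical values): with page_len=4 and
            # seqlens=[5, 3], request 0 spans pages [0, 1] and request 1 uses
            # page [2], so kv_indices=[0, 1, 2], kv_indptr=[0, 2, 3], and
            # kv_last_page_len=[1, 3].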
@torch.cuda.nvtx.range(f"convert model={model_name}, seqlens={seqlens}")
def fn_convert() -> Tuple[torch.Tensor, torch.Tensor]:
return flashinfer.get_batch_indices_positions(
x_indptr,
flashinfer.get_seq_lens(kv_indptr, kv_last_page_len, page_len),
k.shape[0],
)
batch_indices, positions = fn_convert()
convert_latencies = bench_gpu_time(fn_convert)
convert_latency_ms = np.median(convert_latencies)
@torch.cuda.nvtx.range(f"append model={model_name}, seqlens={seqlens}")
def fn() -> None:
flashinfer.append_paged_kv_cache(
k,
v,
batch_indices,
positions,
layer_buf,
kv_indices,
kv_indptr,
kv_last_page_len,
"NHD",
)
latencies = bench_gpu_time(fn)
latency_ms = np.median(latencies)
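
            # A full forward pass appends K/V once per layer, so scale the
            # single-layer latency by num_layers; the index conversion happens
            # once per step, not per layer.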
            all_layers_latency_ms = convert_latency_ms + latency_ms * model.num_layers
            # Bytes moved: k and v are each read from the flat tensors and
            # written into the paged cache, hence 2 (tensors) * 2 (read + write).
            throughput = (
                k.numel() * k.element_size() * 2 * 2 / (latency_ms * 1e-3)
            )
            print(
                f"model: {model_name:8}",
                f"seqlens: {seqlens!r:{seqlen_strlen}}",
                f"convert: {convert_latency_ms * 1e3:2.0f}us",
                f"1layer: {latency_ms * 1e3:2.0f}us",
                f"{model.num_layers}layers: {all_layers_latency_ms * 1e3:3.0f}us",
                f"throughput: {throughput * 1e-9:8.3f}GB/s",
            )
        print("---")

    torch.cuda.profiler.stop()


if __name__ == "__main__":
    main()