sglang_v0.5.2/sglang/sgl-kernel/benchmark/bench_rotary_embedding.py

97 lines
2.4 KiB
Python

import itertools
import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
FlashInferRotaryEmbedding,
MHATokenToKVPool,
RotaryEmbedding,
create_inputs,
)
from sglang.srt.bench_utils import bench_kineto
configs = [
(batch_size, seq_len, save_kv_cache)
for batch_size, seq_len in (
(1, 1),
(32, 1),
(128, 1),
(512, 1),
(2, 512),
(4, 4096),
)
for save_kv_cache in (False, True)
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "save_kv_cache"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang"],
line_names=["SGL Kernel"],
styles=[("green", "-")],
ylabel="us",
plot_name="bench_rotary_embedding",
args={},
)
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
device = torch.device("cuda")
num_q_heads = 32
num_kv_heads = 8
head_size = 64
dtype = torch.bfloat16
config = dict(
head_size=head_size,
rotary_dim=64,
max_position_embeddings=4096,
base=8000,
is_neox_style=True,
dtype=dtype,
)
rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
inputs = create_inputs(
head_size=head_size,
batch_size=batch_size,
seq_len=seq_len,
device=device,
dtype=dtype,
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
)
query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()
bench_fn = lambda: rope_flashinfer.forward_cuda(
inputs["pos_ids"],
query_flashinfer,
key_flashinfer,
fused_set_kv_buffer_arg=(
FusedSetKVBufferArg(
value=inputs["value"],
k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
k_scale=None,
v_scale=None,
cache_loc=inputs["out_cache_loc"],
)
if save_kv_cache
else None
),
)
time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
return time_s * 1e6
if __name__ == "__main__":
benchmark.run(print_data=True)