# sglang0.4.5.post1/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py

import itertools
import math

import cudnn
import torch
import torch.utils.benchmark as benchmark
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)
from sglang.srt.utils import should_use_tensor_core
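
# Microbenchmark of three paged-attention decode backends on identical inputs:
# SGLang's Triton split-KV kernel, FlashInfer's batched decode wrapper, and a
# cuDNN SDPA graph with a page-size-1 paged KV cache. Reported latencies are
# mean forward times in microseconds (see time_fwd).

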
def benchmark_forward(
fn,
*inputs,
repeats=10,
amp=False,
amp_dtype=torch.float16,
**kwinputs,
):
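    """Time fn(*inputs, **kwinputs) with torch.utils.benchmark.Timer,
    optionally under CUDA autocast."""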
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
fn(*inputs, **kwinputs)
t = benchmark.Timer(
stmt="fn_amp(*inputs, **kwinputs)",
globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
m = t.timeit(repeats)
return t, m
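

# Mean latency of one forward call, in microseconds.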
def time_fwd(func, *args, **kwargs):
time_f = benchmark_forward(func, *args, **kwargs)
return time_f[1].mean * 1e6
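

# SGLang's Triton split-KV decode kernel on a flat (token-major) KV cache.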
def decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits,
warmup=10,
):
k_buffer = kv_data[0].view(-1, head_num_kv, head_dim)
v_buffer = kv_data[1].view(-1, head_num_kv, head_dim)
o = torch.empty_like(q)
total_tokens = batch_size * kv_len
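    # Request i owns rows [i * kv_len, (i + 1) * kv_len) of the flat KV buffers.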
    req_to_token = torch.arange(
        0, total_tokens, dtype=torch.int32, device="cuda"
    ).view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda")
sm_scale = 1.0 / (head_dim**0.5)
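    # Split-KV scratch: per-split partial outputs plus one extra slot per split
    # for the log-sum-exp used when merging the splits.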
attn_logits = torch.empty(
(batch_size, head_num_q, num_kv_splits, head_dim + 1),
dtype=torch.float32,
device="cuda",
)
for _ in range(warmup):
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
f = time_fwd(
decode_attention_fwd,
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
return f, o
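

# Build the FlashInfer decode wrapper once (workspace buffer + tensor-core
# heuristic) and return an autograd.Function whose forward warms up and times it.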
def decode_attention_flashinfer(dtype, head_num_q, head_num_kv):
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=dtype,
num_attention_heads=head_num_q,
num_kv_heads=head_num_kv,
)
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
)
class FlashinferAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
dtype,
warmup=10,
):
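            # Page size 1: each token is its own page, so the page indices are a
            # flat arange and every request's last page holds exactly one token.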
total_tokens = batch_size * kv_len
            kv_indptr = (
                torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda")
                * kv_len
            )
            kv_indices = torch.arange(
                0, total_tokens, dtype=torch.int32, device="cuda"
            )
kv_last_page_len = torch.full(
(batch_size,), 1, dtype=torch.int32, device="cuda"
)
flashinfer_decode_wrapper.end_forward()
flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
head_num_q,
head_num_kv,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=dtype,
)
for _ in range(warmup):
o = flashinfer_decode_wrapper.forward(
q.contiguous().view(-1, head_num_q, head_dim), kv_data
)
f = time_fwd(
flashinfer_decode_wrapper.forward,
q.contiguous().view(-1, head_num_q, head_dim),
kv_data,
)
return f, o
return FlashinferAttention
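

# Map torch dtypes to the corresponding cudnn enum values.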
def convert_to_cudnn_type(torch_type):
if torch_type == torch.float16:
return cudnn.data_type.HALF
elif torch_type == torch.bfloat16:
return cudnn.data_type.BFLOAT16
elif torch_type == torch.float32:
return cudnn.data_type.FLOAT
elif torch_type == torch.int32:
return cudnn.data_type.INT32
elif torch_type == torch.int64:
return cudnn.data_type.INT64
else:
raise ValueError("Unsupported tensor data type.")
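

# cuDNN paged SDPA: q/k/v are viewed in (B, H, S, D) layout, the KV cache is
# split into page-size-1 blocks, and a cudnn graph is compiled once before timing.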
def decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10
):
    # Prepare data: contiguous q, k, v
dims_q = (batch_size, head_num_q, 1, head_dim)
strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1)
q_gpu = q.as_strided(dims_q, strides_q)
    o_gpu = torch.empty(
        batch_size * head_num_q * head_dim, dtype=dtype, device="cuda"
    ).as_strided(dims_q, strides_q)
dims_kv = (batch_size, head_num_kv, kv_len, head_dim)
strides_kv = (
kv_len * head_num_kv * head_dim,
head_dim,
head_num_kv * head_dim,
1,
)
k_gpu = kv_data[0].as_strided(dims_kv, strides_kv)
v_gpu = kv_data[1].as_strided(dims_kv, strides_kv)
seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda")
seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda")
attn_scale = 1.0 / (head_dim**0.5)
# Prepare data: paged k,v
block_size = 1
blocks_per_batch = math.ceil(kv_len / block_size)
# [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch
container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0)
container_v_gpu = torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0)
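    # Block p of batch b was placed at container row p * batch_size + b by the
    # cat/chunk above; the page tables record exactly that mapping.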
    page_table_k_gpu = (
        torch.arange(
            batch_size * blocks_per_batch, device="cuda", dtype=torch.int32
        )
        .reshape(blocks_per_batch, 1, batch_size, 1)
        .transpose(0, 2)
    )
page_table_v_gpu = page_table_k_gpu.clone()
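    # Build, validate, and compile the cudnn graph; only graph.execute is timed.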
graph = cudnn.pygraph(
io_data_type=convert_to_cudnn_type(dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)
q = graph.tensor_like(q_gpu)
container_k = graph.tensor_like(container_k_gpu)
container_v = graph.tensor_like(container_v_gpu)
page_table_k = graph.tensor_like(page_table_k_gpu)
page_table_v = graph.tensor_like(page_table_v_gpu)
seq_len_q = graph.tensor_like(seq_len_q_gpu)
seq_len_kv = graph.tensor_like(seq_len_kv_gpu)
o, _ = graph.sdpa(
name="sdpa",
q=q,
        k=container_k,  # Container K: non-contiguous container of K blocks
        v=container_v,  # Container V: non-contiguous container of V blocks
is_inference=True,
attn_scale=attn_scale,
use_causal_mask=False,
use_padding_mask=True,
seq_len_q=seq_len_q,
seq_len_kv=seq_len_kv,
paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks
paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks
paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended)
)
o.set_output(True).set_dim(dims_q).set_stride(strides_q)
graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A])
graph.check_support()
graph.build_plans()
workspace = torch.empty(
graph.get_workspace_size(), device="cuda", dtype=torch.uint8
)
variant_pack = {
q: q_gpu,
container_k: container_k_gpu,
container_v: container_v_gpu,
page_table_k: page_table_k_gpu,
page_table_v: page_table_v_gpu,
seq_len_q: seq_len_q_gpu,
seq_len_kv: seq_len_kv_gpu,
o: o_gpu,
}
for _ in range(warmup):
graph.execute(variant_pack, workspace)
f = time_fwd(
graph.execute,
variant_pack,
workspace,
)
return f, o_gpu.squeeze(dim=2)
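

# Cross-check the three backends on a single config; tolerances are loose
# because the kernels accumulate in different orders.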
def calculate_diff():
dtype = torch.float16
batch_size = 64
kv_len = 4096
head_num_q = 64
head_num_kv = 8
head_dim = 128
q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda")
kv_data = (
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
torch.randn(
batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda"
),
)
_, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply
_, output_flashinfer = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
_, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
print(f"SGLang output={output_sglang}")
print(f"FlashInfer output={output_flashinfer}")
print(f"cuDNN output={output_cudnn}")
if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2):
print("✅ SGLang[Triton] and FlashInfer match")
else:
print("❌ SGLang[Triton] and FlashInfer differ")
if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2):
print("✅ SGLang[Triton] and cuDNN match")
else:
print("❌ SGLang[Triton] and cuDNN differ")
if __name__ == "__main__":
calculate_diff()
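
    # Sweep head configs (MHA 32/32, GQA 64/8 and 40/8), batch sizes, and KV lengths.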
head_dim = 128
dtype = torch.float16
batch_size_range = [2**i for i in range(0, 8, 2)]
kv_len_range = [2**i for i in range(6, 13, 1)]
configs = list(itertools.product(batch_size_range, kv_len_range))
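    # Column header for the timing table printed below (latencies in microseconds).
    print(
        "head_num_q  head_num_kv  batch_size  kv_len  "
        "us_cudnn  us_sglang  us_flashinfer"
    )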
for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]:
attn_flashinfer = decode_attention_flashinfer(
dtype, head_num_q, head_num_kv
).apply
for batch_size, kv_len in configs:
q = torch.randn(
batch_size, head_num_q, head_dim, dtype=dtype, device="cuda"
)
kv_data = (
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
torch.randn(
batch_size * kv_len,
head_num_kv,
head_dim,
dtype=dtype,
device="cuda",
),
)
us_cudnn, output_cudnn = decode_attention_cudnn(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
us_sglang, output_sglang = decode_attention_sglang(
q,
kv_data,
batch_size,
kv_len,
head_num_q,
head_num_kv,
head_dim,
num_kv_splits=8,
)
us_flashinfer, _ = attn_flashinfer(
q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype
)
            print(
                f"{head_num_q}  {head_num_kv}  {batch_size}  {kv_len}  "
                f"{us_cudnn:.2f}  {us_sglang:.2f}  {us_flashinfer:.2f}"
            )