# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/routines/attention.py


from collections import defaultdict
import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import (
attention_tb_per_sec_with_actual_seq_lens,
attention_tflops_per_sec_with_actual_seq_lens,
bench_gpu_time,
bench_gpu_time_with_cudagraph,
)
from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
print_perf_metrics,
is_close_stats,
)
def run_attention_test(args):
"""
Run an attention test.
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.routine == "BatchDecodeWithPagedKVCacheWrapper":
return testBatchDecodeWithPagedKVCacheWrapper(args)
elif args.routine == "BatchPrefillWithPagedKVCacheWrapper":
return testBatchPrefillWithPagedKVCacheWrapper(args)
elif args.routine == "BatchPrefillWithRaggedKVCacheWrapper":
return testBatchPrefillWithRaggedKVCacheWrapper(args)
elif args.routine == "BatchMLAPagedAttentionWrapper":
return testBatchMLAPagedAttentionWrapper(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")
def parse_attention_args(line, parser):
"""
Parse command line arguments for attention test configuration.
Args:
line: Command line arguments
parser: ArgumentParser object already populated with shared arguments
Returns:
Parsed argument namespace
"""
parser.add_argument(
"--backends",
type=str,
required=False,
nargs="+",
default=["fa2"],
choices=[
"fa2",
"fa2_tc",
"fa3",
"cudnn",
"cutlass",
"trtllm-gen",
"trtllm-gen-native",
],
help="Kernel backends to test. Default: fa2",
)
parser.add_argument(
"--page_size",
type=int,
required=False,
default=0,
help="Page size for paged attention. Required for paged attention. Ignored for non-paged attention.",
)
parser.add_argument(
"--batch_size", type=int, required=True, help="Batch size of test case."
)
parser.add_argument(
"--s_qo",
type=int,
required=False,
default=1,
help="Max sequence length of the query. Should be 1 for decode.",
)
parser.add_argument(
"--s_kv",
type=int,
required=True,
help="Max sequence length of the key and value.",
)
parser.add_argument(
"--num_qo_heads", type=int, required=True, help="Number of query heads."
)
parser.add_argument(
"--num_kv_heads", type=int, required=True, help="Number of key and value heads."
)
parser.add_argument(
"--head_dim_qk",
type=int,
required=False,
help="Head dimension of the query and key for prefill and decode MHA/GQA/MQA.",
)
parser.add_argument(
"--head_dim_vo",
type=int,
required=False,
help="Head dimension of the value and output for prefill and decode MHA/GQA/MQ.",
)
parser.add_argument(
"--head_dim_ckv",
type=int,
required=False,
help="Head dimension of compressed kv-cache tensor (without rope).",
)
parser.add_argument(
"--head_dim_kpe",
type=int,
required=False,
help="Head dimension of the rope part of the kv-cache tensor.",
)
parser.add_argument(
"--q_dtype",
type=str,
required=False,
default="bfloat16",
help="Data type of the query. Currently only bfloat16 is supported.",
)
parser.add_argument(
"--kv_dtype",
type=str,
required=False,
default="bfloat16",
help="Data type of the key and value. Currently only bfloat16 is supported.",
)
parser.add_argument(
"--causal",
action="store_true",
default=False,
help="Causal masking. Note: not padding masking. Only used for prefill tests.",
)
parser.add_argument(
"--random_actual_seq_len",
action="store_true",
default=False,
help="Use random actual sequence lengths for the query and key and value. Random values are generated between 1 and maximum sequence length. If False, use maximum sequence length.",
)
args = parser.parse_args(line)
if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
def sample_actual_seq_lens(max_seqlen, batch_size, device, random_actual_seq_len):
"""
Get an array of actual sequence lengths for given batch size and max sequence length.
If random_actual_seq_len is True, sample actual sequence lengths randomly.
Otherwise, set all actual sequence lengths to max_seqlen.
Args:
max_seqlen: Maximum sequence length.
batch_size: Batch size.
device: Device to sample on.
random_actual_seq_len: Whether to sample actual sequence lengths randomly.
Returns:
actual_seq_lens: Actual sequence lengths for each batch.
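Example (illustrative): sample_actual_seq_lens(8, 4, device, False) returns a
(4, 1, 1, 1) int32 tensor filled with 8; with random_actual_seq_len=True each
entry is drawn uniformly from [1, 8].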
"""
if random_actual_seq_len:
actual_seq_lens = torch.randint(
1, max_seqlen + 1, (batch_size, 1, 1, 1), device=device, dtype=torch.int32
)
else:
actual_seq_lens = torch.full(
(batch_size, 1, 1, 1), max_seqlen, device=device, dtype=torch.int32
)
return actual_seq_lens
def testBatchDecodeWithPagedKVCacheWrapper(args):
"""
Test BatchDecodeWithPagedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa2_tc, cudnn, trtllm-gen, trtllm-gen-native backends.
This test:
1. Creates paged KV cache and query tensors
2. Runs decode attention with different backends
3. Verifies outputs match between backends
4. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBatchDecodeWithPagedKVCacheWrapper")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
# Basic setup
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
q_init_dtype = torch.bfloat16
kv_init_dtype = torch.bfloat16
rtol = 2e-1
atol = 1e-2
# Handle different query data types.
q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported q_dtype: {args.q_dtype}")
# Handle different KV cache data types.
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}")
# Parse and validate backend configurations
backends = args.backends
page_size = args.page_size
batch_size = args.batch_size
s_qo = args.s_qo
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim_qk = args.head_dim_qk
head_dim_vo = args.head_dim_vo
is_cuda_graph_compatible = not args.no_cuda_graph
# return_lse = not args.no_lse # TO-DO: Add support for this
run_refcheck = args.refcheck
# Derived parameters
if "fa2" in backends:
remove_fa2 = False
head_grp_size = (
num_qo_heads // num_kv_heads
) # GQA group size; the FA2 decode kernel does not support a group size of 5.
if head_grp_size == 5:
print(
"[INFO] FA2 backend is not supported for this configuration. Skipping."
)
remove_fa2 = True
if remove_fa2:
backends.remove("fa2")
if "fa2_tc" in backends:
remove_fa2_tc = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] FA2_TC backend does not support FP8. Skipping.")
remove_fa2_tc = True
if remove_fa2_tc:
backends.remove("fa2_tc")
if "cudnn" in backends:
remove_cudnn = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] cuDNN backend does not support FP8. Skipping.")
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
# Sample sequence lengths and create tensors
actual_seq_lens_kv = sample_actual_seq_lens(
s_kv, batch_size, device, args.random_actual_seq_len
)
sum_seq_kv = torch.sum(actual_seq_lens_kv).item()
avg_seq_len_kv = sum_seq_kv // batch_size
if args.verbose >= 1:
print(f"[VERBOSE] Average actual seq len: {avg_seq_len_kv}")
if args.verbose >= 2:
print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }")
# Create query tensor
q = torch.rand(
batch_size, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
# Create KV cache
num_pages_per_seq = (s_kv + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
if args.verbose >= 2:
print(f"[VVERBOSE] {num_pages_per_seq = }")
print(f"[VVERBOSE] {total_num_pages = }")
# Initialize KV cache with appropriate shape and stride
kv_cache_shape = (
total_num_pages,
2, # 2 for key and value
num_kv_heads,
page_size,
head_dim_qk,
)
kv_cache = torch.randn(size=kv_cache_shape, dtype=kv_init_dtype).to(device)
# Keep a copy for TRT-LLM which uses different strides
if "trtllm-gen" in backends:
kv_cache_for_trt = kv_cache.detach().clone()
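# Re-stride the buffer (P=page_size, H=num_kv_heads, D=head_dim_qk): the shape
# stays (pages, 2, H, P, D), but with strides (2*P*H*D, P*H*D, D, H*D, 1) the
# head index varies faster than the in-page token index in memory. as_strided
# reinterprets the same storage, so the logical values shift (harmless for
# random data); the k_cache/v_cache views below inherit this layout for cuDNN.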
kv_cache = kv_cache.as_strided(
kv_cache.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
k_cache_view, v_cache_view = kv_cache[:, 0, :, :, :], kv_cache[:, 1, :, :, :]
if "trtllm-gen" in backends:
# kv_cache now has different tensor strides and therefore different logical values.
# Copy them into kv_cache_for_trt so both caches hold the same logical values
# while keeping their different strides.
kv_cache_for_trt.copy_(kv_cache)
v_cache = v_cache_view.as_strided(
v_cache_view.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
k_cache = k_cache_view.as_strided(
k_cache_view.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
device=device,
)
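# block_tables[i] lists the physical page ids owned by sequence i; with the
# contiguous allocation above, e.g. batch_size=2 and num_pages_per_seq=3 gives
# [[0, 1, 2], [3, 4, 5]]. kv_indptr below is the exclusive prefix sum of
# per-sequence page counts, ceil(actual_seq_len / page_size).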
kv_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(
(actual_seq_lens_kv.flatten() + page_size - 1) // page_size, dim=0
),
]
)
.int()
.to(device)
)
# kv_indptr[-1] is the total number of pages actually used
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
kv_last_page_len = (
torch.where(
actual_seq_lens_kv.flatten() % page_size == 0,
torch.full((batch_size,), page_size, device=device),
actual_seq_lens_kv.flatten() % page_size,
)
.int()
.to(device)
)
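# kv_last_page_len[i] is the number of valid tokens in sequence i's final page,
# e.g. with page_size=16 a 40-token sequence uses 3 pages and its last page
# holds 8 tokens; an exactly-full sequence reports page_size rather than 0.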
ragged_q = (
torch.arange(0, batch_size + 1, device=device) * (num_qo_heads * head_dim_qk)
).long() # For cuDNN
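# Decode has one query token per sequence, so ragged_q gives flat element
# offsets (i * num_qo_heads * head_dim_qk) for sequence i; it is passed to
# cuDNN as batch_offsets_q/batch_offsets_o below.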
scale = float(1.0 / (head_dim_qk**0.5))
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
if args.verbose >= 2:
print(f"[VVERBOSE] {kv_cache.shape = }")
print(f"[VVERBOSE] {kv_cache.stride() = }")
print(f"[VVERBOSE] {block_tables.shape = }")
print(f"[VVERBOSE] {kv_indptr.shape = }")
print(f"[VVERBOSE] {kv_indices.shape = }")
print(f"[VVERBOSE] {kv_last_page_len.shape = }")
print(f"[VVERBOSE] {scale = }")
# Prepare wrappers
backend_wrappers = {}
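# Each selected wrapper is constructed once with persistent
# indptr/indices/last_page_len buffers (used for CUDA-graph capture) and
# planned once; the timed loops below reuse the planned wrappers.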
for backend in backends:
if backend in ["fa2", "fa2_tc", "trtllm-gen"]:
plan_kv_indptr = (
kv_indptr.clone().detach() if backend == "trtllm-gen" else kv_indptr
)
backend_wrappers[backend] = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
"HND",
use_cuda_graph=is_cuda_graph_compatible,
use_tensor_cores=(backend != "fa2"),
paged_kv_indptr_buffer=plan_kv_indptr,
paged_kv_indices_buffer=kv_indices,
paged_kv_last_page_len_buffer=kv_last_page_len,
backend=backend,
)
backend_wrappers[backend].plan(
plan_kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
page_size,
q_data_type=q_dtype,
data_type=kv_dtype,
)
## If FP8, prepare
k_scale, v_scale = None, None
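# Per-tensor FP8 quantization: k/v are divided by amax/256 before casting, and
# the resulting k_scale/v_scale are handed to the backend calls below (for
# trtllm-gen-native they are folded into the bmm1/bmm2 scales) so the kernels
# can rescale the quantized values.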
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_scale = k_data.amax().item() / 256
v_scale = v_data.amax().item() / 256
k_fp8 = (k_data / k_scale).to(kv_dtype)
v_fp8 = (v_data / v_scale).to(kv_dtype)
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
if "trtllm-gen" in backends:
k_data, v_data = torch.chunk(kv_cache_for_trt, 2, dim=1)
k_fp8 = (k_data / k_scale).to(kv_dtype)
v_fp8 = (v_data / v_scale).to(kv_dtype)
kv_cache_for_trt = torch.cat([k_fp8, v_fp8], dim=1)
def run_backend_wrapper(backend):
if backend in ["fa2", "fa2_tc", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale
)
elif backend == "cudnn":
return flashinfer.decode.cudnn_batch_decode_with_kv_cache(
q,
k_cache,
v_cache,
scale,
workspace_buffer,
max_sequence_kv=s_kv,
actual_seq_lens_kv=actual_seq_lens_kv,
block_tables=block_tables,
is_cuda_graph_compatible=is_cuda_graph_compatible,
batch_offsets_q=ragged_q,
batch_offsets_o=ragged_q,
)
elif backend == "trtllm-gen-native":
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=q.contiguous(),
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv,
max_seq_len=s_kv,
bmm1_scale=scale if k_scale is None else k_scale * scale,
bmm2_scale=1.0 if v_scale is None else v_scale,
)
else:
raise ValueError(f"Backend {backend} not supported")
has_reference_output = False
# Iterate over each backend:
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
if is_cuda_graph_compatible and cur_backend != "fa2":
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if args.verbose >= 2:
print(
"[VVERBOSE] Reference output is FP8. Converting to float32 for reference check."
)
reference_output = reference_output.to(torch.float32)
tested_outputs = [output.to(torch.float32) for output in tested_outputs]
for i in range(len(tested_outputs)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=rtol, atol=atol
)
except AssertionError as e:
(
num_different_elements,
num_elements,
num_different_elements_percentage,
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
print(e)
raise
# Compute perf metrics
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
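# Decode processes one query token per sequence per step, so query lengths are
# all 1 for the FLOP/byte accounting.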
actual_seq_lens_q_flat = torch.ones_like(actual_seq_lens_kv_flat)
tflops = attention_tflops_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
False,
median_time,
)
tb_per_sec = attention_tb_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
num_kv_heads,
median_time,
q_dtype=q_dtype,
kv_dtype=kv_dtype,
o_dtype=q_dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["page_size"] = page_size
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
cur_res["s_kv"] = s_kv
cur_res["num_qo_heads"] = num_qo_heads
cur_res["num_kv_heads"] = num_kv_heads
cur_res["head_dim_qk"] = head_dim_qk
cur_res["head_dim_vo"] = head_dim_vo
cur_res["causal"] = False
cur_res["q_dtype"] = q_dtype
cur_res["kv_dtype"] = kv_dtype
cur_res["avg_actual_seq_len"] = avg_seq_len_kv
cur_res["random_actual_seq_len"] = args.random_actual_seq_len
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testBatchPrefillWithPagedKVCacheWrapper(args):
"""
Test BatchPrefillWithPagedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa3, trtllm-gen, trtllm-gen-native, and cudnn backends.
This test:
1. Creates paged KV cache and query tensors for prefill
2. Runs prefill attention with different backends
3. Verifies outputs match between backends (if refcheck enabled)
4. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBatchPrefillWithPagedKVCacheWrapper")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
# Basic setup
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
q_init_dtype = torch.bfloat16
kv_init_dtype = torch.bfloat16
rtol = 2e-1
atol = 1e-2
q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported q_dtype: {args.q_dtype}")
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}")
# Parse and validate backend configurations
backends = args.backends
page_size = args.page_size
batch_size = args.batch_size
s_qo = args.s_qo
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim_qk = args.head_dim_qk
head_dim_vo = args.head_dim_vo
causal = args.causal
is_cuda_graph_compatible = not args.no_cuda_graph
# return_lse = not args.no_lse # TO-DO: Add support for this
run_refcheck = args.refcheck
# Check for backend-specific constraints
if "fa2" in backends:
remove_fa2 = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
print("[INFO] FA2 backend does not support FP8. Skipping.")
remove_fa2 = True
if remove_fa2:
backends.remove("fa2")
if "fa3" in backends:
remove_fa3 = False
device_capability = torch.cuda.get_device_capability()
if device_capability[0] != 9:
print(
f"[INFO] FA3 backend does not support capability {device_capability}. Skipping."
)
remove_fa3 = True
if remove_fa3:
backends.remove("fa3")
if "cudnn" in backends:
remove_cudnn = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] cuDNN backend does not support FP8. Skipping.")
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if "trtllm-gen" in backends:
remove_trtllm = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] trtllm-gen backend does not support FP8. Skipping.")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm-gen")
if "cutlass" in backends:
print("[INFO] CUTLASS backend does not support prefill. Skipping.")
remove_cutlass = True
if remove_cutlass:
backends.remove("cutlass")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return
# Check for layer-specific constraints
layer_not_supported = False
if layer_not_supported:
print("[ERROR] Layer not supported. Exiting.")
return
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
# Randomly sample actual_seq_lens_q. Assume actual_seq_lens_kv is the same as actual_seq_lens_q.
actual_seq_lens_q = sample_actual_seq_lens(
s_qo, batch_size, None, args.random_actual_seq_len
)
actual_seq_lens_kv = actual_seq_lens_q.clone()
avg_seq_len_q = actual_seq_lens_q.sum().item() // batch_size
if args.verbose >= 1:
print(f"[VERBOSE] Average actual seq len: {avg_seq_len_q}")
if args.verbose >= 2:
print(f"[VVERBOSE] {actual_seq_lens_q.flatten() = }")
cumsum_s_qo = torch.sum(actual_seq_lens_q)
q = torch.randn(
cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
# Create KV cache
num_pages_per_seq = (s_kv + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
if args.verbose >= 2:
print(f"[VVERBOSE] {num_pages_per_seq = }")
print(f"[VVERBOSE] {total_num_pages = }")
kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim_qk)
kv_cache = torch.randn(size=kv_cache_shape, dtype=kv_init_dtype).to(device)
kv_cache = kv_cache.as_strided(
kv_cache.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
k_cache_view, v_cache_view = kv_cache[:, 0, :, :, :], kv_cache[:, 1, :, :, :]
v_cache = v_cache_view.as_strided(
v_cache_view.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
k_cache = k_cache_view.as_strided(
k_cache_view.shape,
(
2 * page_size * num_kv_heads * head_dim_qk,
head_dim_qk,
num_kv_heads * head_dim_qk,
1,
),
)
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
device=device,
)
actual_seq_lens_q_device = actual_seq_lens_q.to(device)
actual_seq_lens_kv_device = actual_seq_lens_kv.to(device)
q_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
* head_dim_qk
* num_qo_heads,
]
)
.long()
.to(device)
) # For cuDNN
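# q_indptr holds flat element offsets (token prefix sum * num_qo_heads *
# head_dim_qk) and is passed to cuDNN as batch_offsets_q/batch_offsets_o;
# qo_indptr below is the plain per-sequence token prefix sum used by the
# FlashInfer wrappers.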
qo_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0),
]
)
.int()
.to(device)
)
# Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
kv_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(
(actual_seq_lens_kv_device.flatten() + page_size - 1) // page_size,
dim=0,
),
]
)
.int()
.to(device)
)
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
kv_last_page_len = (
torch.where(
actual_seq_lens_kv_device.flatten() % page_size == 0,
torch.full((batch_size,), page_size, device=device),
actual_seq_lens_kv_device.flatten() % page_size,
)
.int()
.to(device)
)
scale = float(1.0 / (head_dim_qk**0.5))
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
if args.verbose >= 2:
print(f"[VVERBOSE] {kv_cache.shape = }")
print(f"[VVERBOSE] {kv_cache.stride() = }")
print(f"[VVERBOSE] {block_tables.shape = }")
print(f"[VVERBOSE] {qo_indptr.shape = }")
print(f"[VVERBOSE] {qo_indptr.dtype = }")
print(f"[VVERBOSE] {kv_indptr.shape = }")
print(f"[VVERBOSE] {kv_indices.shape = }")
print(f"[VVERBOSE] {kv_last_page_len.shape = }")
print(f"[VVERBOSE] {scale = }")
# Prepare wrappers
backend_wrappers = {}
for backend in backends:
if backend in ["fa2", "fa3", "trtllm-gen"]:
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer,
"HND",
use_cuda_graph=is_cuda_graph_compatible
if backend != "fa2"
else False,
qo_indptr_buf=qo_indptr,
paged_kv_indptr_buf=kv_indptr,
paged_kv_indices_buf=kv_indices,
paged_kv_last_page_len_buf=kv_last_page_len,
backend=backend,
)
)
backend_wrappers[backend].plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim_qk,
page_size,
pos_encoding_mode="NONE",
causal=causal,
q_data_type=q_dtype,
kv_data_type=kv_dtype,
)
k_scale, v_scale = None, None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_data, v_data = torch.chunk(kv_cache, 2, dim=1)
k_scale = k_data.amax().item() / 256
v_scale = v_data.amax().item() / 256
k_fp8 = (k_data / k_scale).to(kv_dtype)
v_fp8 = (v_data / v_scale).to(kv_dtype)
kv_cache = torch.cat([k_fp8, v_fp8], dim=1)
def run_backend_wrapper(backend):
if backend in ["fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run(
q, kv_cache, k_scale=k_scale, v_scale=v_scale
)
elif backend == "cudnn":
return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
q,
k_cache,
v_cache,
scale,
workspace_buffer,
max_token_per_sequence=s_qo,
max_sequence_kv=s_kv,
actual_seq_lens_q=actual_seq_lens_q_device,
actual_seq_lens_kv=actual_seq_lens_kv_device,
block_tables=block_tables,
causal=causal,
return_lse=True,
is_cuda_graph_compatible=is_cuda_graph_compatible,
batch_offsets_q=q_indptr,
batch_offsets_o=q_indptr,
)[0]
elif backend == "trtllm-gen-native":
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=q,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=scale if k_scale is None else k_scale * scale,
bmm2_scale=1.0 if v_scale is None else v_scale,
batch_size=batch_size,
cum_seq_lens_q=qo_indptr,
cum_seq_lens_kv=kv_indptr,
)
else:
raise ValueError(f"Backend {backend} not supported")
has_reference_output = False
# Iterate over each backend:
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
if is_cuda_graph_compatible and cur_backend != "fa2":
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if args.verbose >= 2:
print(
"[VVERBOSE] Reference output is FP8. Converting to float32 for reference check."
)
reference_output = reference_output.to(torch.float32)
tested_outputs = [output.to(torch.float32) for output in tested_outputs]
for i in range(len(tested_backends)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=rtol, atol=atol
)
except AssertionError as e:
(
num_different_elements,
num_elements,
num_different_elements_percentage,
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
print(e)
raise
# Compute perf metrics
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
actual_seq_lens_q_flat = actual_seq_lens_q.flatten().to("cpu")
actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
tflops = attention_tflops_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
causal,
median_time,
)
tb_per_sec = attention_tb_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
num_kv_heads,
median_time,
q_dtype=q_dtype,
kv_dtype=kv_dtype,
o_dtype=q_dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["page_size"] = page_size
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
cur_res["s_kv"] = s_kv
cur_res["num_qo_heads"] = num_qo_heads
cur_res["num_kv_heads"] = num_kv_heads
cur_res["head_dim_qk"] = head_dim_qk
cur_res["head_dim_vo"] = head_dim_vo
cur_res["causal"] = causal
cur_res["q_dtype"] = q_dtype
cur_res["kv_dtype"] = kv_dtype
cur_res["avg_actual_seq_len"] = avg_seq_len_q
cur_res["random_actual_seq_len"] = args.random_actual_seq_len
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testBatchPrefillWithRaggedKVCacheWrapper(args):
"""
Test BatchPrefillWithRaggedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa3, cutlass, and cudnn backends.
This test:
1. Creates ragged KV cache and query tensors for prefill
2. Runs prefill attention with different backends
3. Verifies outputs match between backends (if refcheck enabled)
4. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBatchPrefillWithRaggedKVCacheWrapper")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
# Basic setup
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
q_init_dtype = torch.bfloat16
kv_init_dtype = torch.bfloat16
rtol = 2e-1
atol = 1e-2
q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(f"Unsupported q_dtype: {args.q_dtype}")
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}")
# Parse and validate backend configurations
backends = args.backends
batch_size = args.batch_size
s_qo = args.s_qo
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
num_kv_heads = args.num_kv_heads
head_dim_qk = args.head_dim_qk
head_dim_vo = args.head_dim_vo
causal = args.causal
is_cuda_graph_compatible = not args.no_cuda_graph
# return_lse = not args.no_lse # TO-DO: Add support for this
run_refcheck = args.refcheck
# Check for backend-specific constraints
if "cudnn" in backends:
remove_cudnn = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] CUDNN backend does not support FP8. Skipping.")
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if "cutlass" in backends:
remove_cutlass = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] CUTLASS backend does not support FP8. Skipping.")
remove_cutlass = True
if remove_cutlass:
backends.remove("cutlass")
if "trtllm-gen" in backends:
print("[INFO] trtllm-gen backend does not support ragged prefill. Skipping.")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm-gen")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return
# Check for layer-specific constraints
layer_not_supported = False
if not ((head_dim_qk == 128 and head_dim_qk == head_dim_vo) or head_dim_qk == 192):
print("[ERROR] Head dimension must be 128 or 192")
layer_not_supported = True
if layer_not_supported:
print("[ERROR] Layer not supported. Exiting.")
return
backend_times = {backend: [] for backend in backends}
outputs = {}
# Randomly sample actual_seq_lens_q. Assume actual_seq_lens_kv is the same as actual_seq_lens_q.
actual_seq_lens_q = sample_actual_seq_lens(
s_qo, batch_size, None, args.random_actual_seq_len
)
actual_seq_lens_kv = actual_seq_lens_q.clone()
avg_seq_len_q = actual_seq_lens_q.sum().item() // batch_size
if args.verbose >= 1:
print(f"[VERBOSE] Average actual seq len: {avg_seq_len_q}")
if args.verbose >= 2:
print(f"[VVERBOSE] {actual_seq_lens_q.flatten() = }")
cumsum_s_qo = torch.sum(actual_seq_lens_q)
cumsum_s_kv = torch.sum(actual_seq_lens_kv)
q = torch.randn(
cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
)
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")
k = torch.randn(
cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype
)
v = torch.randn(
cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype
)
block_tables = None
## The following are for BatchPrefillWithRaggedKVCacheWrapper
actual_seq_lens_q_device = actual_seq_lens_q.to(device)
actual_seq_lens_kv_device = actual_seq_lens_kv.to(device)
q_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
* head_dim_qk
* num_qo_heads,
]
)
.long()
.to(device)
) # For cuDNN
k_indptr = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0)
* head_dim_qk
* num_kv_heads,
]
).long()
v_indptr = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0)
* head_dim_vo
* num_kv_heads,
]
).long()
o_indptr = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
* head_dim_vo
* num_qo_heads,
]
).long()
batch_offsets_stats = torch.cat(
[
torch.zeros(
1,
device=actual_seq_lens_q_device.device,
dtype=actual_seq_lens_q_device.dtype,
),
torch.cumsum(actual_seq_lens_q_device.flatten(), dim=0) * num_qo_heads,
]
).cuda()
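# For the ragged cuDNN path, q/k/v/o offsets are flat element offsets (token
# prefix sums scaled by heads * head_dim), and batch_offsets_stats counts one
# entry per token per query head (used with return_lse=True below).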
qo_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0),
]
)
.int()
.to(device)
)
# Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
kv_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(actual_seq_lens_kv_device.view(-1), dim=0),
]
)
.int()
.to(device)
)
scale = float(1.0 / (head_dim_qk**0.5))
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
if args.verbose >= 2:
print(f"[VVERBOSE] {k.shape = }")
print(f"[VVERBOSE] {v.shape = }")
print(f"[VVERBOSE] {qo_indptr.shape = }")
print(f"[VVERBOSE] {kv_indptr.shape = }")
print(f"[VVERBOSE] {scale = }")
# Prepare wrappers
backend_wrappers = {}
for backend in backends:
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer,
"NHD",
use_cuda_graph=is_cuda_graph_compatible
if backend != "fa2"
else False,
qo_indptr_buf=qo_indptr,
kv_indptr_buf=kv_indptr,
backend=backend,
)
)
backend_wrappers[backend].plan(
qo_indptr,
kv_indptr,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo=head_dim_vo,
causal=causal,
q_data_type=q_dtype,
kv_data_type=kv_dtype,
)
k_scale, v_scale = None, None
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
k_scale = k.amax().item() / 256
v_scale = v.amax().item() / 256
k = (k / k_scale).to(kv_dtype)
v = (v / v_scale).to(kv_dtype)
def run_backend_wrapper(backend):
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run_return_lse(q, k, v)[0]
elif backend == "cudnn":
return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
q,
k,
v,
scale,
workspace_buffer,
max_token_per_sequence=s_qo,
max_sequence_kv=s_kv,
actual_seq_lens_q=actual_seq_lens_q_device,
actual_seq_lens_kv=actual_seq_lens_kv_device,
block_tables=block_tables,
causal=causal,
return_lse=True,
batch_offsets_q=q_indptr,
batch_offsets_k=k_indptr,
batch_offsets_v=v_indptr,
batch_offsets_o=o_indptr,
batch_offsets_stats=batch_offsets_stats,
is_cuda_graph_compatible=True,
)[0]
else:
raise ValueError(f"Backend {backend} not supported")
has_reference_output = False
# Iterate over each backend:
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
if is_cuda_graph_compatible and cur_backend != "fa2":
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if args.verbose >= 2:
print(
"[VVERBOSE] Reference output is FP8. Converting to float32 for reference check."
)
reference_output = reference_output.to(torch.float32)
tested_outputs = [output.to(torch.float32) for output in tested_outputs]
for i in range(len(tested_backends)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=rtol, atol=atol
)
except AssertionError as e:
(
num_different_elements,
num_elements,
num_different_elements_percentage,
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
print(e)
raise
# Compute perf metrics
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
actual_seq_lens_q_flat = actual_seq_lens_q.flatten().to("cpu")
actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
tflops = attention_tflops_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
causal,
median_time,
)
tb_per_sec = attention_tb_per_sec_with_actual_seq_lens(
actual_seq_lens_q_flat,
actual_seq_lens_kv_flat,
head_dim_qk,
head_dim_vo,
num_qo_heads,
num_kv_heads,
median_time,
q_dtype=q_dtype,
kv_dtype=kv_dtype,
o_dtype=q_dtype,
)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["page_size"] = 0 # No page size for ragged
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
cur_res["s_kv"] = s_kv
cur_res["num_qo_heads"] = num_qo_heads
cur_res["num_kv_heads"] = num_kv_heads
cur_res["head_dim_qk"] = head_dim_qk
cur_res["head_dim_vo"] = head_dim_vo
cur_res["causal"] = causal
cur_res["q_dtype"] = q_dtype
cur_res["kv_dtype"] = kv_dtype
cur_res["avg_actual_seq_len"] = avg_seq_len_q
cur_res["random_actual_seq_len"] = args.random_actual_seq_len
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testBatchMLAPagedAttentionWrapper(args):
"""
Test BatchMLAPagedAttentionWrapper and equivalent APIs.
Supports fa2 and trtllm-gen-native backends.
This test:
1. Creates paged query and key-value cache tensors
2. Runs MLA with different backends
3. Verifies outputs match between backends
4. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBatchMLAPagedAttentionWrapper")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
# Basic setup
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
q_init_dtype = torch.bfloat16
kv_init_dtype = torch.bfloat16
rtol = 2e-1
atol = 1e-2
# Handle different query data types.
q_dtype = dtype_str_to_torch_dtype(args.q_dtype)
if q_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported q_dtype: {args.q_dtype}")
# Handle different KV cache data types.
kv_dtype = dtype_str_to_torch_dtype(args.kv_dtype)
if kv_dtype not in [torch.bfloat16, torch.float8_e4m3fn]:
raise ValueError(f"Unsupported kv_dtype: {args.kv_dtype}")
backends = args.backends
page_size = args.page_size
batch_size = args.batch_size
s_qo = args.s_qo
s_kv = args.s_kv
num_qo_heads = args.num_qo_heads
# num_kv_heads not used in MLA
# head_dim_qk = args.head_dim_qk
assert args.head_dim_ckv is not None, "head_dim_ckv must be provided for MLA"
assert args.head_dim_kpe is not None, "head_dim_kpe must be provided for MLA"
head_dim_ckv = args.head_dim_ckv
head_dim_kpe = args.head_dim_kpe
is_cuda_graph_compatible = not args.no_cuda_graph
causal = False # False for MLA
run_refcheck = args.refcheck
# Check for backend-specific constraints
if "fa2" in backends:
remove_fa2 = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] FA2 backend does not support FP8. Skipping.")
remove_fa2 = True
if remove_fa2:
backends.remove("fa2")
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
actual_seq_lens_kv = sample_actual_seq_lens(
s_kv, batch_size, device, args.random_actual_seq_len
)
sum_seq_kv = torch.sum(actual_seq_lens_kv).item()
avg_seq_len_kv = sum_seq_kv // batch_size
if args.verbose >= 1:
print(f"[VERBOSE] Average actual seq len: {avg_seq_len_kv}")
if args.verbose >= 2:
print(f"[VVERBOSE] {actual_seq_lens_kv.flatten() = }")
q_nope = torch.rand(
batch_size, num_qo_heads, head_dim_ckv, dtype=q_init_dtype, device="cuda"
)
q_pe = torch.zeros(
batch_size, num_qo_heads, head_dim_kpe, dtype=q_init_dtype, device="cuda"
)
q = torch.cat([q_nope, q_pe], dim=2)
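# MLA query = "nope" part (head_dim_ckv, matched against the compressed KV
# cache) plus rope part (head_dim_kpe); the fa2 wrapper consumes them
# separately, while the trtllm-gen-native path below takes the fused q.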
if args.verbose >= 2:
print(f"[VVERBOSE] {q_nope.shape = }")
print(f"[VVERBOSE] {q_pe.shape = }")
print(f"[VVERBOSE] {q.shape = }")
# Create KV cache
num_pages_per_seq = (s_kv + page_size - 1) // page_size
total_num_pages = num_pages_per_seq * batch_size
# Now initialize the page tables
block_tables = torch.tensor(
[
[k + i * num_pages_per_seq for k in range(num_pages_per_seq)]
for i in range(batch_size)
],
dtype=torch.int,
device=device,
)
if args.verbose >= 2:
print(f"[VVERBOSE] {num_pages_per_seq = }")
print(f"[VVERBOSE] {total_num_pages = }")
print(f"[VVERBOSE] {block_tables.shape = }")
# Initialize KV cache with appropriate shape and stride
ckv_cache_shape = (
total_num_pages,
page_size,
head_dim_ckv,
)
ckv_cache = torch.randn(size=ckv_cache_shape, dtype=kv_init_dtype, device=device)
kpe_cache_shape = (
total_num_pages,
page_size,
head_dim_kpe,
)
kpe_cache = torch.randn(size=kpe_cache_shape, dtype=q_init_dtype, device=device)
kv_cache = torch.cat([ckv_cache, kpe_cache], dim=2)
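# Fused (ckv | kpe) cache along the last dim for the trtllm-gen-native call;
# ckv_cache and kpe_cache stay separate for the fa2 MLA wrapper.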
qo_indptr = torch.arange(0, batch_size + 1, device=device).int()
kv_indptr = (
torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(
(actual_seq_lens_kv.flatten() + page_size - 1) // page_size, dim=0
),
]
)
.int()
.to(device)
)
# kv_indptr[-1] is the total number of pages actually used
kv_indices = torch.zeros(kv_indptr[-1], device=device, dtype=torch.int32)
for i in range(len(kv_indptr) - 1):
start_idx = kv_indptr[i]
end_idx = kv_indptr[i + 1]
kv_indices[start_idx:end_idx] = torch.arange(
i * num_pages_per_seq,
i * num_pages_per_seq + (end_idx - start_idx),
device=device,
)
sm_scale = 1.0 / ((head_dim_ckv + head_dim_kpe) ** 0.5)
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
if args.verbose >= 2:
print(f"[VVERBOSE] {ckv_cache.shape = }")
print(f"[VVERBOSE] {kpe_cache.shape = }")
print(f"[VVERBOSE] {kv_cache.shape = }")
print(f"[VVERBOSE] {qo_indptr.shape = }")
print(f"[VVERBOSE] {kv_indptr.shape = }")
print(f"[VVERBOSE] {kv_indices.shape = }")
print(f"[VVERBOSE] {actual_seq_lens_kv.shape = }")
print(f"[VVERBOSE] {sm_scale = }")
print(f"[VVERBOSE] {workspace_buffer.shape = }")
# Create wrapper
if "fa2" in backends:
fi_fa2_mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
float_workspace_buffer=workspace_buffer,
use_cuda_graph=is_cuda_graph_compatible,
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
kv_len_arr=actual_seq_lens_kv,
backend="fa2",
)
fi_fa2_mla_wrapper.plan(
qo_indptr=qo_indptr,
kv_indptr=kv_indptr,
kv_indices=kv_indices,
kv_len_arr=actual_seq_lens_kv,
num_heads=num_qo_heads,
head_dim_ckv=head_dim_ckv,
head_dim_kpe=head_dim_kpe,
page_size=page_size,
causal=causal,
sm_scale=sm_scale,
q_data_type=q_dtype,
kv_data_type=kv_dtype,
)
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
q = q.to(q_dtype)
q_pe = q_pe.to(q_dtype)
q_nope = q_nope.to(q_dtype)
if kv_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
ckv_cache = ckv_cache.to(kv_dtype)
kpe_cache = kpe_cache.to(kv_dtype)
kv_cache = kv_cache.to(kv_dtype)
def run_backend_wrapper(backend):
if backend == "fa2":
return fi_fa2_mla_wrapper.run(
q_nope, q_pe, ckv_cache, kpe_cache, return_lse=False
)
if backend == "trtllm-gen-native":
return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1),
kv_cache=kv_cache.unsqueeze(1),
workspace_buffer=workspace_buffer,
qk_nope_head_dim=128, # To-do: Why??
kv_lora_rank=head_dim_ckv,
qk_rope_head_dim=head_dim_kpe,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv,
max_seq_len=s_kv,
bmm1_scale=sm_scale,
bmm2_scale=1.0,
).squeeze(1)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
# Iterate over each backend:
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
if cur_backend == "fa2":
has_reference_output = True
reference_output = outputs[cur_backend]
if is_cuda_graph_compatible and cur_backend != "fa2":
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend_wrapper(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=False,
)
# Perform reference check
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 1:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
reference_output = reference_output.to(torch.float32)
tested_outputs = [output.to(torch.float32) for output in tested_outputs]
for i in range(len(tested_outputs)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=rtol, atol=atol
)
except AssertionError as e:
(
num_different_elements,
num_elements,
num_different_elements_percentage,
) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
)
if not args.allow_output_mismatch:
print(e)
raise
# Compute perf metrics
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
actual_seq_lens_kv_flat = actual_seq_lens_kv.flatten().to("cpu")
actual_seq_lens_q_flat = torch.ones_like(
actual_seq_lens_kv.flatten().to("cpu")
)
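# Unlike the other routines, MLA perf is accounted for manually here:
# bytes moved = q_nope + q_pe + ckv_cache + kpe_cache + output
# (1 token/seq * num_qo_heads * head_dim_ckv), and
# FLOPs = 2 * sum_i(s_q_i * s_kv_i) * num_qo_heads * (2*head_dim_ckv + head_dim_kpe).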
o_mem_bytes = (
actual_seq_lens_q_flat.numel()
* num_qo_heads
* head_dim_ckv
* q_dtype.itemsize
)
qkv_mem_bytes = sum(
[
_.numel() * _.element_size()
for _ in [q_nope, q_pe, ckv_cache, kpe_cache]
]
)
total_mem_bytes = o_mem_bytes + qkv_mem_bytes
tb_per_sec = (total_mem_bytes / (median_time * 1e9)).item()
tflops_total = (
2
* torch.dot(
actual_seq_lens_q_flat.to(torch.float32),
actual_seq_lens_kv_flat.to(torch.float32),
)
* num_qo_heads
* (2 * head_dim_ckv + head_dim_kpe)
)
tflops = (tflops_total / (median_time * 1e9)).item()
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
# TO-Do:
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["backend"] = backend
cur_res["page_size"] = page_size
cur_res["batch_size"] = batch_size
cur_res["s_qo"] = s_qo
cur_res["s_kv"] = s_kv
cur_res["num_qo_heads"] = num_qo_heads
cur_res["head_dim_ckv"] = head_dim_ckv
cur_res["head_dim_kpe"] = head_dim_kpe
cur_res["causal"] = False
cur_res["q_dtype"] = q_dtype
cur_res["kv_dtype"] = kv_dtype
cur_res["avg_actual_seq_len"] = avg_seq_len_kv
cur_res["random_actual_seq_len"] = args.random_actual_seq_len
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res