64 lines
3.1 KiB
Python
64 lines
3.1 KiB
Python
import flashinfer_benchmark
|
|
import pytest
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16, 32])
@pytest.mark.parametrize("s_kv", [1024, 2048])
@pytest.mark.parametrize("page_size", [8, 16])
@pytest.mark.parametrize("is_cuda_graph_compatible", [False, True])
def test_BatchDecodeWithPagedKVCacheWrapper_routine(
    batch_size, s_kv, page_size, is_cuda_graph_compatible
):
    """Drive the ``BatchDecodeWithPagedKVCacheWrapper`` benchmark routine.

    A benchmark command line is assembled from the parametrized values,
    tokenized on whitespace, parsed with ``flashinfer_benchmark.parse_args``
    and handed to ``flashinfer_benchmark.run_test``.

    Args:
        batch_size: Value forwarded as ``--batch_size``.
        s_kv: Value forwarded as ``--s_kv`` (``--s_qo`` is fixed to 1).
        page_size: Value forwarded as ``--page_size``.
        is_cuda_graph_compatible: When false, ``--no_cuda_graph`` is appended
            to the command line; when true, no extra flag is added.
    """
    # An empty flag leaves a trailing space in the command string, which
    # str.split() discards, so no empty token reaches the parser.
    cuda_graph_flag = "" if is_cuda_graph_compatible else "--no_cuda_graph"
    command_line = (
        "--routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 fa2_tc "
        f"--page_size {page_size} --batch_size {batch_size} "
        f"--s_qo 1 --s_kv {s_kv} "
        "--num_qo_heads 64 --num_kv_heads 8 "
        "--head_dim_qk 128 --head_dim_vo 128 "
        "--random_actual_seq_len -vv --refcheck "
        f"{cuda_graph_flag}"
    )
    parsed_args = flashinfer_benchmark.parse_args(command_line.split())
    flashinfer_benchmark.run_test(parsed_args)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16, 32])
@pytest.mark.parametrize("s_kv", [1024, 2048])
@pytest.mark.parametrize("page_size", [8, 16])
@pytest.mark.parametrize("is_cuda_graph_compatible", [False])
def test_BatchPrefillWithPagedKVCacheWrapper_routine(
    batch_size, s_kv, page_size, is_cuda_graph_compatible
):
    """Drive the ``BatchPrefillWithPagedKVCacheWrapper`` benchmark routine.

    A benchmark command line is assembled from the parametrized values,
    tokenized on whitespace, parsed with ``flashinfer_benchmark.parse_args``
    and handed to ``flashinfer_benchmark.run_test``.

    Args:
        batch_size: Value forwarded as ``--batch_size``.
        s_kv: Value forwarded as both ``--s_qo`` and ``--s_kv``.
        page_size: Value forwarded as ``--page_size``.
        is_cuda_graph_compatible: When false, ``--no_cuda_graph`` is appended
            to the command line. Only ``False`` is parametrized here.
    """
    # An empty flag leaves a trailing space in the command string, which
    # str.split() discards, so no empty token reaches the parser.
    cuda_graph_flag = "" if is_cuda_graph_compatible else "--no_cuda_graph"
    command_line = (
        "--routine BatchPrefillWithPagedKVCacheWrapper --backends fa2 "
        f"--page_size {page_size} --batch_size {batch_size} "
        f"--s_qo {s_kv} --s_kv {s_kv} "
        "--num_qo_heads 8 --num_kv_heads 8 "
        "--head_dim_qk 128 --head_dim_vo 128 "
        "--random_actual_seq_len -vv --refcheck --causal "
        f"{cuda_graph_flag}"
    )
    parsed_args = flashinfer_benchmark.parse_args(command_line.split())
    flashinfer_benchmark.run_test(parsed_args)
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [16, 32])
@pytest.mark.parametrize("s_kv", [1024, 2048])
@pytest.mark.parametrize("is_cuda_graph_compatible", [False])
def test_BatchPrefillWithRaggedKVCacheWrapper_routine(
    batch_size, s_kv, is_cuda_graph_compatible
):
    """Drive the ``BatchPrefillWithRaggedKVCacheWrapper`` benchmark routine.

    A benchmark command line is assembled from the parametrized values,
    tokenized on whitespace, parsed with ``flashinfer_benchmark.parse_args``
    and handed to ``flashinfer_benchmark.run_test``.

    Args:
        batch_size: Value forwarded as ``--batch_size``.
        s_kv: Value forwarded as both ``--s_qo`` and ``--s_kv``.
        is_cuda_graph_compatible: When false, ``--no_cuda_graph`` is appended
            to the command line. Only ``False`` is parametrized here.
    """
    # An empty flag leaves a trailing space in the command string, which
    # str.split() discards, so no empty token reaches the parser.
    cuda_graph_flag = "" if is_cuda_graph_compatible else "--no_cuda_graph"
    command_line = (
        "--routine BatchPrefillWithRaggedKVCacheWrapper --backends fa2 "
        f"--batch_size {batch_size} "
        f"--s_qo {s_kv} --s_kv {s_kv} "
        "--num_qo_heads 128 --num_kv_heads 128 "
        "--head_dim_qk 192 --head_dim_vo 128 "
        "-vv --refcheck --causal "
        f"{cuda_graph_flag}"
    )
    parsed_args = flashinfer_benchmark.parse_args(command_line.split())
    flashinfer_benchmark.run_test(parsed_args)
|
|
|
|
|
|
@pytest.mark.parametrize("m", [1024, 4096])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [1024, 2048])
@pytest.mark.parametrize("mma_sm", [1, 2])
def test_gemm_fp8_nt_groupwise(m, n, k, mma_sm):
    """Drive the ``gemm_fp8_nt_groupwise`` benchmark routine.

    The command line always carries ``--no_cuda_graph --refcheck -vv``; it is
    tokenized on whitespace, parsed with ``flashinfer_benchmark.parse_args``
    and handed to ``flashinfer_benchmark.run_test``.

    Args:
        m: Value forwarded as ``--m``.
        n: Value forwarded as ``--n``.
        k: Value forwarded as ``--k``.
        mma_sm: Value forwarded as ``--mma_sm``.
    """
    command_line = (
        "--routine gemm_fp8_nt_groupwise "
        f"--m {m} --n {n} --k {k} "
        f"--mma_sm {mma_sm} "
        "--no_cuda_graph --refcheck -vv"
    )
    parsed_args = flashinfer_benchmark.parse_args(command_line.split())
    flashinfer_benchmark.run_test(parsed_args)
|
|
|
|
|
|
@pytest.mark.parametrize("m", [1024, 4096])
@pytest.mark.parametrize("n", [1024])
@pytest.mark.parametrize("k", [1024, 2048])
@pytest.mark.parametrize("mma_sm", [1, 2])
@pytest.mark.parametrize("group_size", [1, 2])
def test_group_gemm_fp8_nt_groupwise(m, n, k, mma_sm, group_size):
    """Drive the ``group_gemm_fp8_nt_groupwise`` benchmark routine.

    The command line always carries ``--no_cuda_graph --refcheck -vv``; it is
    tokenized on whitespace, parsed with ``flashinfer_benchmark.parse_args``
    and handed to ``flashinfer_benchmark.run_test``.

    Args:
        m: Value forwarded as ``--m``.
        n: Value forwarded as ``--n``.
        k: Value forwarded as ``--k``.
        mma_sm: Value forwarded as ``--mma_sm``.
        group_size: Value forwarded as ``--group_size``.
    """
    command_line = (
        "--routine group_gemm_fp8_nt_groupwise "
        f"--m {m} --n {n} --k {k} "
        f"--mma_sm {mma_sm} --group_size {group_size} "
        "--no_cuda_graph --refcheck -vv"
    )
    parsed_args = flashinfer_benchmark.parse_args(command_line.split())
    flashinfer_benchmark.run_test(parsed_args)
|