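"""Tests for FlashInfer's JIT-customized attention variants.

Each test compiles a small CUDA attention variant (custom mask, flash-sigmoid,
logits dumping, debug printing) for the FA2 or FA3 backend and checks the
kernel output against a PyTorch reference implementation.
"""
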
import functools
import math

import pytest
import torch

import flashinfer
from flashinfer.decode import single_decode_with_kv_cache_with_jit_module
from flashinfer.jit.attention import (
    gen_customize_single_decode_module,
    gen_customize_single_prefill_module,
)
from flashinfer.prefill import single_prefill_with_kv_cache_with_jit_module
from flashinfer.utils import MaskMode, is_sm90a_supported


def test_single_decode_mask():
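    """JIT-compile a single-decode variant that applies a bit-packed custom
    mask to the logits and compare against a masked-softmax PyTorch reference."""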
    torch.manual_seed(42)
    variant_decl = r"""
struct SingleDecodeWithCustomMask : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint8_t* custom_mask_ptr;
  uint32_t window_left, qo_len, kv_len;
  float sm_scale_log2;

  // Create closure
  template <typename Params>
  __device__ __host__ SingleDecodeWithCustomMask(const Params& params, uint32_t batch_idx,
                                                 uint8_t* smem_ptr) {
    custom_mask_ptr = params.custom_mask;
    qo_len = 1;
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sm_scale_log2 = params.sm_scale * math::log2e;
  }

  REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    const uint32_t offset = kv_idx;
    return ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1);
  })

  REGISTER_OUTPUT_TRANSFORM(params, output, batch_idx, qo_idx, qo_head_idx, m, d, scale, {
    return output;
  })
};
"""
    jit_module = gen_customize_single_decode_module(
        "single_decode_custom_mask",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        128,  # head_dim_qk
        128,  # head_dim_vo
        ["custom_mask"],  # additional_tensor_names
        ["uint8_t"],  # additional_tensor_dtypes
        ["sm_scale"],  # additional_scalar_names
        ["double"],  # additional_scalar_dtypes
        "SingleDecodeWithCustomMask",
        variant_decl,
    ).build_and_load()

    f = functools.partial(single_decode_with_kv_cache_with_jit_module, jit_module)

    q = torch.randn(32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(254, 32, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(254, 32, 128, dtype=torch.float16, device="cuda")
    sm_scale = 1.0 / math.sqrt(128)

    custom_mask = torch.randint(0, 2, (254,), dtype=torch.uint8, device="cuda")
    packed_custom_mask = flashinfer.packbits(custom_mask, bitorder="little")

    o = f(q, k, v, packed_custom_mask, sm_scale)

    p = torch.einsum("hd,nhd->hn", q.float(), k.float()) * sm_scale
    p[:, torch.nonzero(torch.logical_not(custom_mask)).squeeze()] = -float("inf")
    o_ref = torch.einsum("hn,nhd->hd", torch.softmax(p, dim=-1), v.float()).half()
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)


flash_sigmoid_sm80_decl = r"""
struct FlashSigmoid : AttentionVariantBase {
  static constexpr bool use_softmax = false;

  uint32_t window_left, qo_len, kv_len;
  float sigmoid_scale_log2;
  float sigmoid_bias_log2;

  // Create closure
  template <typename Params>
  __device__ __host__ FlashSigmoid(const Params& params, uint32_t batch_idx,
                                   uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sigmoid_bias_log2 = params.sigmoid_bias * math::log2e;
    sigmoid_scale_log2 = params.logits_scale * math::log2e;
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits * sigmoid_scale_log2 + sigmoid_bias_log2)));
  });

  REGISTER_OUTPUT_TRANSFORM(params, output, batch_idx, qo_idx, qo_head_idx, m, d, scale, {
    return output;
  })
};
"""

flash_sigmoid_sm90_decl = r"""
struct FlashSigmoid : AttentionVariantBase {
  float logits_scale_log2, sigmoid_bias_log2e;
  // Init
  template <typename MainloopParams, typename BlockCoord>
  __device__ __host__ FlashSigmoid(const MainloopParams& params, const BlockCoord& block_coord) {
    logits_scale_log2 = params.additional_params.logits_scale * math::log2e;
    sigmoid_bias_log2e = params.additional_params.sigmoid_bias * math::log2e;
  }

  template <int NUM_ROWS_PER_THREAD>
  __device__ auto GetAttentionUpdater() {
    return DefaultUpdater<NUM_ROWS_PER_THREAD>();
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits * logits_scale_log2 + sigmoid_bias_log2e)));
  });
};
"""


def test_flash_sigmoid():
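    """Single-prefill FlashSigmoid variant (sigmoid attention, no softmax) on
    the FA2 backend, checked against a PyTorch sigmoid-attention reference."""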
    torch.manual_seed(42)
    variant_decl = flash_sigmoid_sm80_decl
    jit_module = gen_customize_single_prefill_module(
        "fa2",  # backend
        "single_prefill_flash_sigmoid",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        128,  # head_dim_qk
        128,  # head_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
        ["double", "double"],  # additional_scalar_dtypes
        "FlashSigmoid",
        variant_decl,
    ).build_and_load()

    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)

    q = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1027, 8, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1027, 8, 128, dtype=torch.float16, device="cuda")
    logits_scale = 1.0 / math.sqrt(128)
    sigmoid_bias = 0.25
    o = f(q, k, v, logits_scale, sigmoid_bias, mask_mode=MaskMode.NON_CAUSAL.value)

    p = torch.sigmoid(
        torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * logits_scale + sigmoid_bias
    )
    o_ref = torch.einsum("hmn,nhd->mhd", p, v.float()).half()
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)


def test_dump_logits():
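    """Variant that writes the pre-softmax logits (scaled by sm_scale) into an
    extra `output_logits` tensor; checks both the attention output and the
    dumped logits."""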
    torch.manual_seed(42)
    variant_decl = r"""
struct DumpLogits : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint32_t window_left, qo_len, kv_len;
  float sm_scale_log2;

  // Create closure
  template <typename Params>
  __device__ __host__ DumpLogits(const Params& params, uint32_t batch_idx,
                                 uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sm_scale_log2 = params.sm_scale * math::log2e;
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    if (qo_idx < qo_len && kv_idx < kv_len) {
      params.output_logits[qo_head_idx * (qo_len * kv_len) + qo_idx * kv_len + kv_idx] = logits * params.sm_scale;
    }
    return logits;
  });
};
"""
    jit_module = gen_customize_single_prefill_module(
        "fa2",  # backend
        "single_prefill_dump_logits",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        128,  # head_dim_qk
        128,  # head_dim_vo
        ["output_logits"],  # additional_tensor_names
        ["float"],  # additional_tensor_dtypes
        ["sm_scale"],  # additional_scalar_names
        ["double"],  # additional_scalar_dtypes
        "DumpLogits",
        variant_decl,
    ).build_and_load()

    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)

    q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
    logits = torch.empty(32, 128, 1023, dtype=torch.float32, device="cuda")
    sm_scale = 1.0 / math.sqrt(128)
    o = f(q, k, v, logits, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)

    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(logits, p, rtol=2e-2, atol=2e-2)


@pytest.mark.parametrize("use_tensor_cores", [False, True])
def test_batch_decode_flash_sigmoid(use_tensor_cores):
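    """Batch decode with a paged KV cache using custom FlashSigmoid jit_args,
    with and without tensor cores, against an einsum-based sigmoid reference."""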
    torch.manual_seed(42)
    variant_decl = flash_sigmoid_sm80_decl
    jit_args = (
        f"batch_decode_flash_sigmoid_sm80_{use_tensor_cores}",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        torch.int32,  # idtype
        128,  # hidden_dim_qk
        128,  # hidden_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
        ["double", "double"],  # additional_scalar_dtypes
        "FlashSigmoid",
        variant_decl,
    )

    float_workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
    )
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        float_workspace_buffer,
        kv_layout="NHD",
        use_tensor_cores=use_tensor_cores,
        jit_args=jit_args,
    )

    batch_size = 128
    seq_len_per_request = 1024
    kv_indptr_host = torch.arange(
        0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32
    )
    page_size = 1
    kv_indices_host = torch.arange(
        0, batch_size * seq_len_per_request, dtype=torch.int32
    )
    last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32)
    num_qo_heads = 32
    num_kv_heads = 32
    head_dim = 128

    wrapper.plan(
        kv_indptr_host,
        kv_indices_host,
        last_page_len_host,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        q_data_type=torch.float16,
        kv_data_type=torch.float16,
    )

    q = torch.randn(
        batch_size,
        num_qo_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    k_cache = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    v_cache = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )

    logits_scale = 1.0 / math.sqrt(128)
    sigmoid_bias = 0.25

    o = wrapper.run(q, (k_cache, v_cache), logits_scale, sigmoid_bias)
    p = torch.sigmoid(
        torch.einsum(
            "bhd,bnhd->bhn",
            q.view(batch_size, num_qo_heads, head_dim).float(),
            k_cache.view(
                batch_size, seq_len_per_request, num_kv_heads, head_dim
            ).float(),
        )
        * logits_scale
        + sigmoid_bias
    )
    o_ref = (
        torch.einsum(
            "bhn,bnhd->bhd",
            p,
            v_cache.view(
                batch_size, seq_len_per_request, num_kv_heads, head_dim
            ).float(),
        )
        .half()
        .reshape(batch_size, num_qo_heads, head_dim)
    )

    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)


def test_batch_prefill_flash_sigmoid():
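    """Batch prefill (ragged and paged KV cache) with FlashSigmoid jit_args on
    the FA2 backend, checked against an einsum-based sigmoid reference."""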
    torch.manual_seed(42)
    variant_decl = flash_sigmoid_sm80_decl
    jit_args = (
        "batch_prefill_flash_sigmoid_sm80",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        torch.int32,  # idtype
        128,  # hidden_dim_qk
        128,  # hidden_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
        ["double", "double"],  # additional_scalar_dtypes
        "FlashSigmoid",
        variant_decl,
    )

    float_workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
    )
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args
    )

    batch_size = 128
    seq_len_per_request = 1024
    qo_indptr_host = torch.arange(
        0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32
    )
    kv_indptr_host = torch.arange(
        0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32
    )

    num_qo_heads = 32
    num_kv_heads = 32
    head_dim = 128

    wrapper.plan(
        qo_indptr_host,
        kv_indptr_host,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=False,
        q_data_type=torch.float16,
        kv_data_type=torch.float16,
    )

    q = torch.randn(
        batch_size * seq_len_per_request,
        num_qo_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    k = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    v = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    logits_scale = 1.0 / math.sqrt(128)
    sigmoid_bias = 0.25

    o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)

    wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD", backend="fa2", jit_args=jit_args
    )
    kv_indices_host = torch.arange(
        0,
        batch_size * seq_len_per_request,
        dtype=torch.int32,
    )
    paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32)
    wrapper_paged.plan(
        qo_indptr_host,
        kv_indptr_host,
        kv_indices_host,
        paged_kv_last_page_len_host,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o_paged = wrapper_paged.run(q, (k, v), logits_scale, sigmoid_bias)

    p = torch.sigmoid(
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).float(),
            k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(),
        )
        * logits_scale
        + sigmoid_bias
    )
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(),
        )
        .half()
        .reshape(batch_size * seq_len_per_request, num_qo_heads, head_dim)
    )
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(o_paged, o_ref, rtol=2e-2, atol=2e-2)


def test_batch_prefill_sm90_flash_sigmoid():
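    """Same as test_batch_prefill_flash_sigmoid, but on the FA3 (SM90A)
    backend; skipped when SM90A is not available."""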
    if not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("SM90A is not supported")

    torch.manual_seed(42)
    variant_decl = flash_sigmoid_sm90_decl
    jit_args = (
        "batch_prefill_flash_sigmoid",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        torch.int32,  # idtype
        128,  # hidden_dim_qk
        128,  # hidden_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["logits_scale", "sigmoid_bias"],  # additional_scalar_names
        ["double", "double"],  # additional_scalar_dtypes
        "FlashSigmoid",
        variant_decl,
    )

    float_workspace_buffer = torch.empty(
        128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
    )
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
    )

    batch_size = 128
    seq_len_per_request = 1024
    qo_indptr_host = torch.arange(
        0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32
    )
    kv_indptr_host = torch.arange(
        0, batch_size * seq_len_per_request + 1, seq_len_per_request, dtype=torch.int32
    )

    num_qo_heads = 32
    num_kv_heads = 32
    head_dim = 128

    wrapper.plan(
        qo_indptr_host,
        kv_indptr_host,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=False,
        q_data_type=torch.float16,
        kv_data_type=torch.float16,
    )

    q = torch.randn(
        batch_size * seq_len_per_request,
        num_qo_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    k = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    v = torch.randn(
        batch_size * seq_len_per_request,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda",
    )
    logits_scale = 1.0 / math.sqrt(128)
    sigmoid_bias = 0.25

    o = wrapper.run(q, k, v, logits_scale, sigmoid_bias)
    wrapper_paged = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        float_workspace_buffer, kv_layout="NHD", backend="fa3", jit_args=jit_args
    )
    kv_indices_host = torch.arange(
        0,
        batch_size * seq_len_per_request,
        dtype=torch.int32,
    )
    paged_kv_last_page_len_host = torch.full((batch_size,), 1, dtype=torch.int32)
    wrapper_paged.plan(
        qo_indptr_host,
        kv_indptr_host,
        kv_indices_host,
        paged_kv_last_page_len_host,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o_paged = wrapper_paged.run(q, (k, v), logits_scale, sigmoid_bias)

    p = torch.sigmoid(
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, seq_len_per_request, num_qo_heads, head_dim).float(),
            k.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(),
        )
        * logits_scale
        + sigmoid_bias
    )
    o_ref = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            p,
            v.view(batch_size, seq_len_per_request, num_kv_heads, head_dim).float(),
        )
        .half()
        .reshape(batch_size * seq_len_per_request, num_qo_heads, head_dim)
    )
    torch.testing.assert_close(o, o_ref, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(o_paged, o_ref, rtol=2e-2, atol=2e-2)


def test_debug_print_logits():
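    """Variant that printf's any logit >= 5 from inside the kernel; the
    attention output must still match the softmax reference."""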
    torch.manual_seed(42)
    variant_decl = r"""
struct DebugPrintLogits : AttentionVariantBase {
  static constexpr bool use_softmax = true;

  uint32_t window_left, qo_len, kv_len;
  float sm_scale_log2;

  // Create closure
  template <typename Params>
  __device__ __host__ DebugPrintLogits(const Params& params, uint32_t batch_idx,
                                       uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sm_scale_log2 = params.sm_scale * math::log2e;
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    if (logits >= 5) {
      printf("Large logits at qo_idx=%d, kv_idx=%d, qo_head_idx=%d, kv_head_idx=%d: %.3f\n",
             qo_idx, kv_idx, qo_head_idx, kv_head_idx, float(logits));
    }
    return logits;
  });
};
"""
    jit_module = gen_customize_single_prefill_module(
        "fa2",  # backend
        "batch_prefill_debug_print_logits",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        128,  # head_dim_qk
        128,  # head_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["sm_scale"],  # additional_scalar_names
        ["double"],  # additional_scalar_dtypes
        "DebugPrintLogits",
        variant_decl,
    ).build_and_load()

    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)

    q = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1023, 32, 128, dtype=torch.float16, device="cuda")
    sm_scale = 1.0 / math.sqrt(128)
    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)

    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)


def test_sm90_debug_print_logits():
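    """FA3 variant that prints every in-range logit and applies the softmax
    scale inside the logits transform (OnlineSoftmax with WITH_SCALE=false)."""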
    if not is_sm90a_supported(torch.device("cuda")):
        pytest.skip("SM90A is not supported")

    torch.manual_seed(42)
    variant_decl = r"""
struct DebugPrintLogits : AttentionVariantBase {
  float sm_scale_log2;
  int qo_len, kv_len;

  // Init
  template <typename MainloopParams, typename BlockCoord>
  __device__ __host__ DebugPrintLogits(const MainloopParams& params, const BlockCoord& block_coord) {
    sm_scale_log2 = params.additional_params.sm_scale * math::log2e;
    auto [_, __, ___, ____, _____, qo_len_, kv_len_, batch_idx] =
        block_coord;

    qo_len = qo_len_;
    kv_len = kv_len_;
  }

  template <int NUM_ROWS_PER_THREAD>
  __device__ auto GetAttentionUpdater() {
    return OnlineSoftmax<NUM_ROWS_PER_THREAD, /*WITH_SCALE*/false>(sm_scale_log2);
  }

  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    if (qo_idx < qo_len && kv_idx < kv_len) {
      printf(
          "---> LOGITS DEBUG: "
          "qo_idx=%-5d "
          "kv_idx=%-5d "
          "sm_scale_log2=%-12.5f "
          "logits=%-12.5f "
          "\n",
          qo_idx,
          kv_idx,
          sm_scale_log2,
          static_cast<float>(logits));
    }
    logits *= sm_scale_log2;
    return logits;
  })
};
"""
    jit_module = gen_customize_single_prefill_module(
        "fa3",  # backend
        "debug_print_logits",  # uri
        torch.float16,  # dtype_q
        torch.float16,  # dtype_kv
        torch.float16,  # dtype_o
        128,  # head_dim_qk
        128,  # head_dim_vo
        [],  # additional_tensor_names
        [],  # additional_tensor_dtypes
        ["sm_scale"],  # additional_scalar_names
        ["double"],  # additional_scalar_dtypes
        "DebugPrintLogits",
        variant_decl,
    ).build_and_load()

    f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module)

    q = torch.randn(16, 2, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
    sm_scale = 1.0 / math.sqrt(128)
    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)

    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    test_single_decode_mask()
    test_flash_sigmoid()
    test_dump_logits()
    test_debug_print_logits()
    test_sm90_debug_print_logits()
    test_batch_decode_flash_sigmoid(False)
    test_batch_decode_flash_sigmoid(True)
    test_batch_prefill_flash_sigmoid()
    test_batch_prefill_sm90_flash_sigmoid()