sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_prefill_ragged_kernel...

15 lines
587 B
Django/Jinja

#include <flashinfer/attention/prefill.cuh>
#include "batch_prefill_config.inc"
namespace flashinfer {
// True when the mask mode this template is rendered with equals MaskMode::kCustom.
// NOTE(review): nothing in the visible template text reads this constant; it is
// presumably referenced by the rendered `variant_name` expression -- confirm
// before removing it.
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

// Explicit instantiations of BatchPrefillWithRaggedKVCacheDispatched, one per
// CTA_TILE_Q value listed in the loop below (16, 64, 128). The remaining template
// arguments and the output dtype of `tmp_v` are substituted at render time.
// NOTE(review): RaggedParams is not declared in this file; presumably it comes
// from "batch_prefill_config.inc" -- confirm.
{% for cta_tile_q in [16, 64, 128] %}
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
    /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
    {{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
{% endfor %}
}  // namespace flashinfer