15 lines
587 B
Django/Jinja
15 lines
587 B
Django/Jinja
#include <flashinfer/attention/prefill.cuh>
|
|
#include "batch_prefill_config.inc"
|
|
|
|
namespace flashinfer {
|
|
|
|
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;
|
|
|
|
{% for cta_tile_q in [16, 64, 128] %}
|
|
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<
|
|
/*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
|
|
{{ variant_name }}, RaggedParams>(RaggedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
|
|
{% endfor %}
|
|
|
|
}; // namespace flashinfer
|