16 lines
475 B
Django/Jinja
16 lines
475 B
Django/Jinja
#include <flashinfer/attention/prefill.cuh>
|
|
#include "single_prefill_config.inc"
|
|
|
|
using namespace flashinfer;
|
|
|
|
// Rendered once per configuration: explicitly instantiates
// SinglePrefillWithKVCacheDispatched for the template arguments chosen by the
// code generator. Every Jinja placeholder below is pasted verbatim into C++,
// so each one must render to valid C++ (for example the boolean
// use_fp16_qk_reduction must render as true/false, not Python's True/False).
namespace flashinfer {

// True when the rendered mask mode is MaskMode::kCustom. The rendered value is
// parenthesized so the comparison applies to it as a whole even if the
// generator emits a compound expression.
// NOTE(review): not referenced directly in this file; presumably consumed by
// the rendered variant_name (e.g. as a template argument), so it has to stay
// declared before the instantiation below -- confirm against the generator.
constexpr auto use_custom_mask = ({{ mask_mode }}) == MaskMode::kCustom;

// Explicit instantiation definition (declaration without a body): makes this
// translation unit emit SinglePrefillWithKVCacheDispatched for exactly these
// template arguments, in order: head_dim_qk, head_dim_vo, pos_encoding_mode,
// use_fp16_qk_reduction, mask_mode, variant_name, Params.
//   Params : parameter struct type; presumably defined by
//            single_prefill_config.inc -- confirm.
//   params : passed by value to the dispatcher.
//   tmp    : pointer to the rendered output dtype; presumably scratch space
//            used by the kernel -- confirm in flashinfer/attention/prefill.cuh.
//   stream : CUDA stream handed to the dispatcher.
// Returns the dispatcher's cudaError_t.
template cudaError_t SinglePrefillWithKVCacheDispatched<
    {{ head_dim_qk }}, {{ head_dim_vo }}, {{ pos_encoding_mode }}, {{ use_fp16_qk_reduction }}, {{ mask_mode }}, {{ variant_name }}, Params>(
    Params params, {{ dtype_o }}* tmp,
    cudaStream_t stream);

}  // namespace flashinfer
|