{# Source: sglang_v0.5.2/flashinfer_0.3.1/csrc/pod_kernel_inst.jinja
   (33 lines, 1.2 KiB, Django/Jinja template) #}

#include <flashinfer/attention/default_prefill_params.cuh>
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/pod.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/page.cuh>
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"
#include "pod_config.inc"
using namespace flashinfer;
// Explicit instantiation of PODWithKVCacheTensorDispatched for one
// JIT-specialized configuration. This file is a Jinja template: every
// double-brace placeholder below is substituted by the renderer before the
// result is compiled, and the rendered file emits one explicit instantiation.
namespace flashinfer {

// True when the corresponding rendered mask mode is MaskMode::kCustom.
// These constants are not referenced anywhere else in this file.
// NOTE(review): they are presumably named inside the rendered variant_name_p /
// variant_name_d template arguments, so check the generator before removing them.
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;

// The positional-encoding mode is hard-coded to kNone here rather than being
// supplied by the renderer.
// NOTE(review): if the generator can request a different positional encoding
// (for example through the attention variant), confirm that fixing kNone here
// is intended.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

// Template arguments, in order:
//   head_dim_qk, head_dim_vo, POS_ENCODING_MODE, use_fp16_qk_reduction,
//   mask_mode_p, the literal 16, mask_mode_d, variant_name_p, variant_name_d,
//   and the PrefillParams / DecodeParams types.
// NOTE(review): 16 is presumably a CTA tile size; confirm it against the
// parameter name in pod.cuh.
// NOTE(review): PrefillParams / DecodeParams are presumably defined by the
// generated pod_config.inc.
template cudaError_t PODWithKVCacheTensorDispatched<
    {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE,
    {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16,
    {{ mask_mode_d }}, {{ variant_name_p }},
    {{ variant_name_d }}, PrefillParams, DecodeParams>(
    PrefillParams prefill_params, {{ dtype_o }}* tmp,
    DecodeParams decode_params, {{ dtype_o }}* tmp_v,
    float* tmp_s, bool enable_pdl, cudaStream_t stream);

}  // namespace flashinfer