33 lines
1.2 KiB
Django/Jinja
33 lines
1.2 KiB
Django/Jinja
#include <flashinfer/attention/default_prefill_params.cuh>
|
|
#include <flashinfer/attention/default_decode_params.cuh>
|
|
#include <flashinfer/attention/variants.cuh>
|
|
#include <flashinfer/attention/scheduler.cuh>
|
|
#include <flashinfer/attention/mask.cuh>
|
|
#include <flashinfer/attention/pod.cuh>
|
|
#include <flashinfer/pos_enc.cuh>
|
|
#include <flashinfer/utils.cuh>
|
|
#include <flashinfer/page.cuh>
|
|
|
|
#include "pytorch_conversion_utils.h"
|
|
#include "pytorch_extension_utils.h"
|
|
|
|
#include "pod_config.inc"
|
|
|
|
using namespace flashinfer;
|
|
|
|
namespace flashinfer {

// Compile-time flags: true when the mask mode rendered into this template is
// MaskMode::kCustom. The `_p` / `_d` suffixes presumably stand for the prefill
// and decode halves of the POD kernel (they line up with PrefillParams /
// DecodeParams below) -- TODO confirm. Neither flag is referenced again inside
// this namespace block.
constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom;
constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom;

// NOTE(review): the positional-encoding mode is hard-coded to kNone rather than
// being supplied by the template context like the other arguments. The original
// author flagged this declaration as uncertain -- confirm that no generated
// configuration needs a different PosEncodingMode.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

// Explicit instantiation of PODWithKVCacheTensorDispatched for one concrete
// configuration. Head dimensions, fp16 QK-reduction flag, mask modes, attention
// variants and the output dtype are all substituted by the template renderer,
// so each rendered file emits exactly one specialization.
//
// NOTE(review): the literal 16 is passed positionally between the two mask
// modes; presumably a tile-size parameter of the prefill side -- verify against
// the primary template's parameter list before changing it.
template cudaError_t PODWithKVCacheTensorDispatched<
    {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE,
    {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16,
    {{ mask_mode_d }}, {{ variant_name_p }},
    {{ variant_name_d }}, PrefillParams, DecodeParams>(
    PrefillParams prefill_params, {{ dtype_o }}* tmp,
    DecodeParams decode_params, {{ dtype_o }}* tmp_v,
    float* tmp_s, bool enable_pdl, cudaStream_t stream);

}  // namespace flashinfer
|