{#
sglang_v0.5.2/flashinfer_0.3.1/csrc/pod_customize_config.jinja

43 lines
1.8 KiB
Django/Jinja
#}

#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/attention/mask.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/default_prefill_params.cuh>
// Compile-time configuration for one rendered instance of this template.
// Every value written as a Jinja placeholder below is substituted by the
// code generator; the remaining values are fixed in the template itself.

// Lets the declarations below name flashinfer symbols without qualification.
// NOTE(review): PosEncodingMode, SinglePrefillParams and BatchPrefillPagedParams
// are presumably declared by the flashinfer headers included above - confirm.
using namespace flashinfer;

// Element types of the query, key/value and output data, plus the index type.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ idtype }};

// Head dimensions; the QK and VO dimensions are configured independently.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};

// Two parallel groups of settings, suffixed _P and _D.
// NOTE(review): given the PrefillParams / DecodeParams aliases below, _P
// presumably means the prefill part and _D the decode part - confirm.
// Declared with `auto`, so each constant takes the type of the rendered value.
constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }};
constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }};
constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }};
constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }};
constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }};
constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }};

// Unsuffixed settings are hard-coded here rather than taken from the template
// context: no positional encoding and no logits soft cap.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
constexpr bool USE_LOGITS_SOFT_CAP = false;

// Parameter struct types for the two parts. Note that DecodeParams is an
// alias of the batch *prefill* paged params type, additionally parameterized
// on IdType.
using PrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
// DISPATCH_context(MASK_MODE_P, MASK_MODE_D, <six more names>, callable)
//
// Expands to two nested DISPATCH_MASK_MODE invocations, the outer one for the
// _P mask mode and the inner one for the _D mask mode, and calls the trailing
// variadic argument with no arguments inside the innermost body.
//
// Things to know when using or changing this macro:
//  * `mask_mode_p` and `mask_mode_d` are NOT macro parameters; they must be
//    identifiers visible in the scope where the macro is expanded.
//  * MASK_MODE_P / MASK_MODE_D are forwarded as the second argument of
//    DISPATCH_MASK_MODE. NOTE(review): presumably the name under which the
//    selected mask mode is exposed to the body as a compile-time constant -
//    confirm against the DISPATCH_MASK_MODE definition, which is not in this
//    file.
//  * DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D
//    and USE_LOGITS_SOFT_CAP are accepted but never referenced in the
//    expansion; code inside the callable that uses those names refers to the
//    file-level definitions, not to the macro arguments.
//  * The expansion already ends with `;`, so a call site written as
//    `DISPATCH_context(...);` produces an extra empty statement.
#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \
USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \
DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \
DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \
__VA_ARGS__(); \
}); \
});