// Django/Jinja template: 43 lines, 1.8 KiB.
#pragma once
|
|
#include <flashinfer/page.cuh>
|
|
#include <flashinfer/math.cuh>
|
|
#include <flashinfer/layout.cuh>
|
|
#include <flashinfer/utils.cuh>
|
|
#include <flashinfer/pos_enc.cuh>
|
|
#include <flashinfer/fastdiv.cuh>
|
|
#include <flashinfer/attention/scheduler.cuh>
|
|
#include <flashinfer/attention/mask.cuh>
|
|
#include <flashinfer/attention/variant_helper.cuh>
|
|
#include <flashinfer/attention/default_prefill_params.cuh>
|
|
|
|
using namespace flashinfer;
|
|
|
|
using DTypeQ = {{ dtype_q }};
|
|
using DTypeKV = {{ dtype_kv }};
|
|
using DTypeO = {{ dtype_o }};
|
|
using IdType = {{ idtype }};
|
|
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
|
|
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
|
|
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
|
|
constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }};
|
|
constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }};
|
|
constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }};
|
|
|
|
constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }};
|
|
constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }};
|
|
constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }};
|
|
|
|
// Un-suffixed variants of the switches. Unlike the _P / _D constants these are
// not template placeholders: they are hard-coded to "no positional encoding"
// and "no logits soft cap" for every rendering of this file.
constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;
constexpr bool USE_LOGITS_SOFT_CAP = false;
|
|
|
|
using PrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
|
|
using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
|
|
|
|
// DISPATCH_context(MASK_MODE_P, MASK_MODE_D, ..., callable)
//
// Nests two DISPATCH_MASK_MODE expansions and, inside the innermost one,
// invokes the variadic tail as a zero-argument callable (`__VA_ARGS__()`).
//
//  * `mask_mode_p` and `mask_mode_d` are NOT macro parameters; they are taken
//    from the scope in which the macro is expanded and must be declared there.
//  * MASK_MODE_P / MASK_MODE_D are forwarded as the second argument of
//    DISPATCH_MASK_MODE.
//    NOTE(review): presumably DISPATCH_MASK_MODE binds that identifier to a
//    compile-time constant inside the block it is given; confirm in the header
//    that defines it.
//  * DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D
//    and USE_LOGITS_SOFT_CAP are accepted but never referenced in the
//    expansion; they only keep call sites uniform.
//  * The expansion already ends in `;`, so a call site that adds its own `;`
//    produces an extra empty statement. The semicolon is kept for
//    compatibility with existing callers.
//
// The definition must stay one unbroken run of backslash-continued lines:
// anything placed between them becomes part of the macro text.
#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK,             \
                         USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, \
                         ...)                                                             \
  DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, {                                          \
    DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, {                                        \
      __VA_ARGS__();                                                                      \
    });                                                                                   \
  });
|