{# 73 lines
2.0 KiB
Django/Jinja
73 lines
2.0 KiB
Django/Jinja #}
#pragma once
|
|
#include <flashinfer/math.cuh>
|
|
#include <flashinfer/layout.cuh>
|
|
#include <flashinfer/utils.cuh>
|
|
#include <flashinfer/pos_enc.cuh>
|
|
#include <flashinfer/fastdiv.cuh>
|
|
#include <flashinfer/attention/variant_helper.cuh>
|
|
|
|
// Expands to the generator-rendered `additional_func_params` text. Presumably this is
// a list of extra parameters spliced into a generated function signature; confirm at
// the site that expands ADDITIONAL_FUNC_PARAMS.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
|
|
// Expands to the generator-rendered `additional_params_setter` text. Presumably these
// are statements that copy the extra parameters declared via {{ additional_params_decl }}
// into a Params instance; confirm at the site that expands ADDITIONAL_PARAMS_SETTER.
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
|
|
|
|
|
|
// DISPATCH_context(...): binds the compile-time configuration for this generated
// kernel and then invokes the trailing callable (__VA_ARGS__) with no arguments.
//
// What the visible body does:
//   * Expands DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { ... }). That macro comes
//     from the included flashinfer headers; it is assumed to switch on the runtime
//     value `mask_mode` (which must be in scope at the expansion site) and bind the
//     matching value to the constexpr name MASK_MODE inside the brace block (confirm
//     against DISPATCH_MASK_MODE's definition). The brace block may contain commas
//     from the rendered variant name, which assumes DISPATCH_MASK_MODE is variadic.
//   * Defines `use_custom_mask` as (MASK_MODE == MaskMode::kCustom). The body does
//     not read it itself; it is in scope for the rendered {{ variant_name }} and
//     for the callable.
//   * Aliases AttentionVariant to the rendered {{ variant_name }}.
//
// The visible body names only the MASK_MODE and AttentionVariant parameters. The
// rendered {{ variant_name }} text may mention other parameter names (for example
// USE_SLIDING_WINDOW); those are then replaced by the corresponding macro arguments.
//
// Keep every line of the replacement list ending in a backslash, and do not put
// `//` comments inside it: a line comment would swallow the continuation.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \
                         POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP,          \
                         USE_FP16_QK_REDUCTION, AttentionVariant, Params, ...)                \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {                                                  \
    constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom;                          \
    using AttentionVariant = {{ variant_name }};                                              \
    __VA_ARGS__();                                                                            \
  })
|
|
|
|
|
|
using namespace flashinfer;

// Compile-time configuration for this generated kernel; every {{ ... }} value is
// filled in by the JIT template generator.

// Element types of the query, key/value, and output tensors.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
// Index type, fixed to 32-bit signed in this configuration.
using IdType = int32_t;

// Head dimensions, one for the Q/K side and one for the V/output side.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};

// Feature flags. All three are spelled `bool` (previously two used `auto`) so the
// flags share one explicit type regardless of how the generator spells the literal.
constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }};
constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr bool USE_SLIDING_WINDOW = {{ use_sliding_window }};

// Positional-encoding mode. Left as `auto` because the spelling of the rendered value
// (presumably a PosEncodingMode enumerator — TODO confirm) is not visible here.
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
|
|
|
|
// Argument bundle for the generated single-prefill kernel.
//
// The getters ignore their batch index and return the scalar lengths stored here,
// which is consistent with a configuration that holds a single query/key-value
// sequence pair. The batch_idx parameter is presumably kept for interface parity
// with batched Params variants — confirm against the kernel templates.
struct Params {
  // Element/index types, re-exported so generic code can query Params::DTypeQ etc.
  using DTypeQ = DTypeQ;
  using DTypeKV = DTypeKV;
  using DTypeO = DTypeO;
  using IdType = int32_t;

  // Base pointers of the query, key, value and output tensors.
  DTypeQ* q;
  DTypeKV *k, *v;
  DTypeO* o;
  // Log-sum-exp output buffer (meaning assumed from the name — TODO confirm).
  float* lse;
  // Query-to-KV head grouping factor as a fast divisor (assumed from the name).
  uint_fastdiv group_size;

  {{ additional_params_decl }}

  // Sequence lengths and head counts.
  uint32_t qo_len, kv_len;
  uint32_t num_qo_heads, num_kv_heads;
  // Strides per sequence position (_n) and per head (_h); units assumed to be elements.
  uint32_t q_stride_n, q_stride_h;
  uint32_t k_stride_n, k_stride_h;
  uint32_t v_stride_n, v_stride_h;
  uint32_t head_dim;
  int32_t window_left;

  bool partition_kv;

  // Query/output length; the batch index is unused.
  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t /*batch_idx*/) const { return qo_len; }

  // Key/value length; the batch index is unused.
  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t /*batch_idx*/) const { return kv_len; }
};
|
|
|
|
// Generator-rendered attention-variant declaration(s). Presumably this declares the
// type that {{ variant_name }} names inside DISPATCH_context — confirm with the generator.
{{ variant_decl }}
|