{# Jinja template: renders a CUDA C++ attention configuration header. #}
#pragma once
|
|
#include <flashinfer/page.cuh>
|
|
#include <flashinfer/math.cuh>
|
|
#include <flashinfer/layout.cuh>
|
|
#include <flashinfer/utils.cuh>
|
|
#include <flashinfer/pos_enc.cuh>
|
|
#include <flashinfer/fastdiv.cuh>
|
|
#include <flashinfer/attention/variant_helper.cuh>
|
|
#include <flashinfer/profiler.cuh>
|
|
|
|
using namespace flashinfer;
|
|
|
|
// Hooks whose replacement text is supplied by the template renderer.
// NOTE(review): judging by the names these are the extra launcher parameters
// and the statements that copy them into the params struct -- confirm at the
// call sites, which are outside this file.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// PROFILER_PARAMS_SETTER expands to an assignment of the profiler buffer
// pointer into params[i] when profiling is compiled in, and to nothing
// otherwise. The expansion site must have `params`, `i` and `profiler_buffer`
// in scope.
#if defined(FLASHINFER_ENABLE_PROFILER)
#define PROFILER_PARAMS_SETTER \
  params[i].profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
#else
#define PROFILER_PARAMS_SETTER
#endif
|
|
|
|
{{ variant_decl }}
|
|
|
|
template <bool UseLogitsSoftCap>
|
|
struct StandardAttention : AttentionVariantBase {
|
|
float sm_scale_log2;
|
|
float soft_cap_pre_tanh_scale;
|
|
static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
|
|
PROFILER_CLOSURE_PARAMS_DECL
|
|
|
|
template <typename Params>
|
|
__device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
|
|
uint8_t* smem_ptr) {
|
|
if constexpr (UseLogitsSoftCap) {
|
|
soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap);
|
|
sm_scale_log2 = math::log2e * params.logits_soft_cap;
|
|
}else{
|
|
sm_scale_log2 = params.sm_scale * math::log2e;
|
|
}
|
|
}
|
|
REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
|
|
if constexpr (UseLogitsSoftCap) {
|
|
logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
|
|
}
|
|
return logits;
|
|
})
|
|
};
|
|
|
|
// Dispatches on the `mask_mode` variable visible at the expansion site via
// DISPATCH_MASK_MODE, binding it to the name passed as MASK_MODE. Inside the
// dispatched scope the name passed as AttentionVariant is aliased to the
// rendered variant type, and the trailing argument is invoked with no
// arguments. The remaining macro parameters are not referenced in the body.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, \
                         HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
    using AttentionVariant = {{ variant_name }}; \
    __VA_ARGS__(); \
  })
|
|
|
|
using DTypeQ = {{ dtype_q }};
|
|
using DTypeKV = {{ dtype_kv }};
|
|
using DTypeO = {{ dtype_o }};
|
|
using IdType = {{ idtype }};
|
|
|
|
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
|
|
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
|
|
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
|
|
|
|
struct PersistentParams {
|
|
using DTypeQ = DTypeQ;
|
|
using DTypeKV = DTypeKV;
|
|
using DTypeO = DTypeO;
|
|
using IdType = IdType;
|
|
|
|
DTypeQ* q;
|
|
DTypeKV* k;
|
|
DTypeKV* v;
|
|
DTypeO* o;
|
|
DTypeO* partial_o;
|
|
float* partial_lse;
|
|
DTypeO* final_o;
|
|
float* final_lse;
|
|
|
|
IdType* q_indptr;
|
|
IdType* kv_indptr;
|
|
IdType* partial_indptr;
|
|
IdType* kv_indices;
|
|
IdType* q_len;
|
|
IdType* kv_len;
|
|
IdType* q_start;
|
|
IdType* kv_start;
|
|
IdType* kv_end;
|
|
IdType* kv_head_idx_arr;
|
|
IdType* work_indptr;
|
|
IdType* len_kv_chunk;
|
|
|
|
// for state reduction
|
|
IdType* merge_indptr;
|
|
IdType* merge_o_indices;
|
|
IdType* num_packed_qo_len;
|
|
|
|
uint32_t num_kv_heads;
|
|
uint_fastdiv gqa_group_size;
|
|
uint_fastdiv page_size;
|
|
|
|
uint32_t q_stride_n;
|
|
uint32_t q_stride_h;
|
|
uint32_t k_stride_page;
|
|
uint32_t k_stride_h;
|
|
uint32_t k_stride_n;
|
|
uint32_t v_stride_page;
|
|
uint32_t v_stride_h;
|
|
uint32_t v_stride_n;
|
|
|
|
float sm_scale;
|
|
double logits_soft_cap;
|
|
{{ additional_params_decl }}
|
|
|
|
PROFILER_PARAMS_DECL
|
|
};
|