sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_attention_customize_c...

117 lines
3.0 KiB
Django/Jinja

#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/profiler.cuh>
using namespace flashinfer;
// Extra function parameters; the text comes from the additional_func_params
// template variable when this file is rendered.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
// Extra statements that populate the additional parameters; the text comes from
// the additional_params_setter template variable when this file is rendered.
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
// PROFILER_PARAMS_SETTER: when FLASHINFER_ENABLE_PROFILER is defined it expands
// to one assignment that stores profiler_buffer.data_ptr() (cast to uint64_t*)
// into params[i].profiler_buffer. The names params, i and profiler_buffer are
// not macro arguments, so they must be visible wherever the macro is expanded.
// Without the profiler the macro expands to nothing.
#ifdef FLASHINFER_ENABLE_PROFILER
#define PROFILER_PARAMS_SETTER \
params[i].profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
#else
#define PROFILER_PARAMS_SETTER
#endif
// Rendered declaration of a user-supplied attention variant (may be empty;
// depends on the variant_decl template variable).
{{ variant_decl }}
// Built-in attention variant: plain scaled logits with optional tanh soft-capping.
//
// Template parameter:
//   UseLogitsSoftCap - when true the logits transform applies
//                      tanh(logits * sm_scale / logits_soft_cap); when false the
//                      transform returns the logits unchanged.
//
// Members set by the constructor:
//   sm_scale_log2           - log2e * logits_soft_cap when soft-capping is on;
//                             sm_scale * log2e otherwise.
//   soft_cap_pre_tanh_scale - sm_scale / logits_soft_cap when soft-capping is on;
//                             0 otherwise (never read on that path).
template <bool UseLogitsSoftCap>
struct StandardAttention : AttentionVariantBase {
  float sm_scale_log2;
  float soft_cap_pre_tanh_scale;
  static constexpr bool use_logits_soft_cap = UseLogitsSoftCap;
  PROFILER_CLOSURE_PARAMS_DECL

  // Derives the per-variant scales from `params`.
  // `batch_idx` and `smem_ptr` are part of the variant constructor signature but
  // are not used by this variant.
  template <typename Params>
  __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
                                        uint8_t* smem_ptr) {
    if constexpr (UseLogitsSoftCap) {
      // Params may declare logits_soft_cap as double (PersistentParams does).
      // Narrow it once so the scale computation stays in single precision
      // instead of silently promoting the multiply to double.
      const float soft_cap = static_cast<float>(params.logits_soft_cap);
      soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(soft_cap);
      sm_scale_log2 = math::log2e * soft_cap;
    } else {
      // Keep the member in a defined state even though this path never reads it.
      soft_cap_pre_tanh_scale = 0.f;
      sm_scale_log2 = params.sm_scale * math::log2e;
    }
  }

  // Per-element logits hook: applies the tanh soft cap when enabled and returns
  // the logits untouched otherwise.
  REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, {
    if constexpr (UseLogitsSoftCap) {
      logits = float(math::tanh(logits * soft_cap_pre_tanh_scale));
    }
    return logits;
  })
};
// Dispatch helper used to run a callable under a compile-time mask mode.
// It wraps the body in DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...), declares a
// type alias (named by the AttentionVariant argument) for the rendered
// variant_name type, and then invokes the trailing variadic argument as a
// nullary callable. `mask_mode` is not a macro parameter, so a variable of that
// name must be visible at the expansion site. The dtype, head-dim,
// pos-encoding and Params arguments are not referenced in the macro body;
// presumably they exist to match the argument list callers pass - TODO confirm.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
using AttentionVariant = {{ variant_name }}; \
__VA_ARGS__(); \
})
// Compile-time configuration for this rendered instance. Every right-hand side
// below is a template placeholder that is substituted when the file is rendered.
// Element types used for the q pointer, the k/v pointers and the output pointers
// in PersistentParams.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
// Integer type used for the index / indptr / length arrays in PersistentParams.
using IdType = {{ idtype }};
// Head dimensions; presumably QK is the query/key head size and VO the
// value/output head size - TODO confirm against the kernel.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
// Positional-encoding mode constant; its type is deduced from the rendered value.
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
// Parameter bundle for the batch-attention kernel configured by this file.
// It is a plain struct of raw pointers and scalars; this file does not show who
// allocates or owns the pointed-to buffers.
// NOTE(review): field order and types are presumably relied on by host-side
// setup code and kernels outside this file - verify before reordering.
struct PersistentParams {
// Re-export the file-level element and index types as member aliases.
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;
// Presumably the query, key, value and output tensors - TODO confirm layout.
DTypeQ* q;
DTypeKV* k;
DTypeKV* v;
DTypeO* o;
// Partial and final output / log-sum-exp buffers (judging by the names;
// confirm against the kernel that consumes them).
DTypeO* partial_o;
float* partial_lse;
DTypeO* final_o;
float* final_lse;
// Index arrays of type IdType; their exact meaning is defined by the planner
// and kernel, which are not visible here.
IdType* q_indptr;
IdType* kv_indptr;
IdType* partial_indptr;
IdType* kv_indices;
IdType* q_len;
IdType* kv_len;
IdType* q_start;
IdType* kv_start;
IdType* kv_end;
IdType* kv_head_idx_arr;
IdType* work_indptr;
IdType* len_kv_chunk;
// for state reduction
IdType* merge_indptr;
IdType* merge_o_indices;
IdType* num_packed_qo_len;
// Head count and fast-division helpers (uint_fastdiv) for the GQA group size
// and the page size.
uint32_t num_kv_heads;
uint_fastdiv gqa_group_size;
uint_fastdiv page_size;
// Strides for q, k and v; presumably counted in elements - TODO confirm.
uint32_t q_stride_n;
uint32_t q_stride_h;
uint32_t k_stride_page;
uint32_t k_stride_h;
uint32_t k_stride_n;
uint32_t v_stride_page;
uint32_t v_stride_h;
uint32_t v_stride_n;
// Softmax scale read by StandardAttention's constructor.
float sm_scale;
// Soft-cap value read by StandardAttention when UseLogitsSoftCap is true.
// NOTE(review): declared double while sm_scale is float, so arithmetic that
// mixes it with float constants is promoted to double.
double logits_soft_cap;
// Extra fields injected when the file is rendered (may be empty).
{{ additional_params_decl }}
// Profiler fields supplied by the PROFILER_PARAMS_DECL macro.
PROFILER_PARAMS_DECL
};