#pragma once #include #include #include #include #include #include #include #include using namespace flashinfer; #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} #ifdef FLASHINFER_ENABLE_PROFILER #define PROFILER_PARAMS_SETTER \ params[i].profiler_buffer = static_cast(profiler_buffer.data_ptr()); #else #define PROFILER_PARAMS_SETTER #endif {{ variant_decl }} template struct StandardAttention : AttentionVariantBase { float sm_scale_log2; float soft_cap_pre_tanh_scale; static constexpr bool use_logits_soft_cap = UseLogitsSoftCap; PROFILER_CLOSURE_PARAMS_DECL template __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx, uint8_t* smem_ptr) { if constexpr (UseLogitsSoftCap) { soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap); sm_scale_log2 = math::log2e * params.logits_soft_cap; }else{ sm_scale_log2 = params.sm_scale * math::log2e; } } REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { if constexpr (UseLogitsSoftCap) { logits = float(math::tanh(logits * soft_cap_pre_tanh_scale)); } return logits; }) }; #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, AttentionVariant, Params, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ using AttentionVariant = {{ variant_name }}; \ __VA_ARGS__(); \ }) using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ idtype }}; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; struct PersistentParams { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; DTypeQ* q; DTypeKV* k; DTypeKV* v; DTypeO* o; DTypeO* partial_o; float* partial_lse; DTypeO* final_o; float* final_lse; IdType* q_indptr; IdType* kv_indptr; IdType* partial_indptr; IdType* kv_indices; IdType* q_len; IdType* kv_len; IdType* q_start; IdType* kv_start; IdType* kv_end; IdType* kv_head_idx_arr; IdType* work_indptr; IdType* len_kv_chunk; // for state reduction IdType* merge_indptr; IdType* merge_o_indices; IdType* num_packed_qo_len; uint32_t num_kv_heads; uint_fastdiv gqa_group_size; uint_fastdiv page_size; uint32_t q_stride_n; uint32_t q_stride_h; uint32_t k_stride_page; uint32_t k_stride_h; uint32_t k_stride_n; uint32_t v_stride_page; uint32_t v_stride_h; uint32_t v_stride_n; float sm_scale; double logits_soft_cap; {{ additional_params_decl }} PROFILER_PARAMS_DECL };