#pragma once #include #include #include #include #include #include "pytorch_extension_utils.h" #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__();}) using namespace flashinfer; using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; using IdType = cutlass_dtype_t<{{ idtype }}>; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; struct RaggedParams { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; // The QKV matrices. DTypeQ* q_ptr; DTypeKV* k_ptr; DTypeKV* v_ptr; DTypeO* o_ptr; float* lse_ptr; IdType* qo_tile_indices; IdType* qo_indptr; IdType* kv_indptr; IdType* qo_lens; IdType* kv_lens; IdType* head_indices; IdType* work_indptr; IdType* batch_indices; struct AdditionalParams { {{ additional_params_decl }} } additional_params; int64_t q_stride_n; int64_t k_stride_n; int64_t v_stride_n; int64_t o_stride_n; int64_t q_stride_h; int64_t k_stride_h; int64_t v_stride_h; int64_t o_stride_h; int64_t nnz_qo; int64_t nnz_kv; int head_dim; int num_qo_heads; int num_kv_heads; int group_size; int window_left; bool causal; }; struct PagedParams { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; // The QKV matrices. DTypeQ* q_ptr; DTypeKV* k_ptr; DTypeKV* v_ptr; DTypeO* o_ptr; float* lse_ptr; IdType* qo_tile_indices; IdType* qo_indptr; IdType* kv_indptr; IdType* kv_indices; IdType* qo_lens; IdType* kv_lens; IdType* head_indices; IdType* work_indptr; IdType* batch_indices; struct AdditionalParams { {{ additional_params_decl }} } additional_params; int64_t q_stride_n; int64_t k_stride_n; int64_t v_stride_n; int64_t o_stride_n; int64_t q_stride_h; int64_t k_stride_h; int64_t v_stride_h; int64_t o_stride_h; int64_t nnz_qo; int head_dim; int num_qo_heads; int num_kv_heads; int group_size; int page_size; int window_left; bool causal; }; {{ variant_decl }}