#pragma once #include #include #include #include #include #include #include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__(); }) using namespace flashinfer; using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; using IdType = cutlass_dtype_t; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; struct Params { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; // The QKV matrices. DTypeQ* q_ptr; DTypeKV* k_ptr; DTypeKV* v_ptr; DTypeO* o_ptr; float* lse_ptr; // Additional params struct AdditionalParams { {{ additional_params_decl }}; } additional_params; int64_t q_stride_n; int64_t k_stride_n; int64_t v_stride_n; int64_t o_stride_n; int64_t q_stride_h; int64_t k_stride_h; int64_t v_stride_h; int64_t o_stride_h; int qo_len; int kv_len; int head_dim; int num_qo_heads; int num_kv_heads; int group_size; int window_left; bool causal; }; {{ variant_decl }}