#pragma once
// NOTE(review): the four include directives below have no header names in this
// copy of the template (the targets appear to have been stripped). Restore them
// from the upstream file before rendering; they are kept here as found.
#include
#include
#include
#include

// Text substituted by the template renderer. How the two macros are consumed is
// not visible in this file -- presumably by the code that includes the rendered
// header; confirm against the includer.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// Opens a block scope, declares a type alias for the rendered variant name, and
// then invokes the trailing variadic argument with no arguments.
// The alias is spelled with the macro parameter `AttentionVariant`, so its name
// is whatever token the caller passes in that position. The remaining named
// parameters do not appear in the expansion (unless the rendered variant name
// itself contains one of them as a token).
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO,  \
                         POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \
                         AttentionVariant, Params, ...)                              \
  {                                                                                  \
    using AttentionVariant = {{ variant_name }};                                     \
    __VA_ARGS__();                                                                   \
  }

using namespace flashinfer;

// Element types and compile-time settings, filled in by the template renderer.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = int32_t;
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};

// Parameter bundle: raw pointers plus scalar sizes/strides, with two length
// accessors. Field meanings beyond their names are not established in this file.
struct Params {
  // Re-export the file-scope aliases as member types. The right-hand sides are
  // qualified with `::` on purpose: writing `using DTypeQ = DTypeQ;` uses the
  // outer name and then redeclares it in the same class scope, which changes the
  // meaning of the name inside the class (ill-formed, no diagnostic required).
  using DTypeQ = ::DTypeQ;
  using DTypeKV = ::DTypeKV;
  using DTypeO = ::DTypeO;
  using IdType = int32_t;

  DTypeQ* q;
  DTypeKV* k;
  DTypeKV* v;
  DTypeO* o;
  float* lse;
  // Extra member declarations injected by the template renderer.
  {{ additional_params_decl }}

  uint32_t kv_len;
  uint32_t num_qo_heads;
  uint32_t num_kv_heads;
  uint32_t q_stride_n;
  uint32_t q_stride_h;
  uint32_t kv_stride_n;
  uint32_t kv_stride_h;
  int32_t window_left;
  uint32_t kv_chunk_size;

  // Always returns 1; `batch_idx` is accepted but not read.
  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
    return 1;
  }

  // Returns the stored `kv_len`; `batch_idx` is accepted but not read.
  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
    return kv_len;
  }
};

// Variant declaration text injected by the template renderer.
{{ variant_decl }}