// Templated configuration header. The double-curly-brace placeholders below are
// substituted before this text is compiled as C++ (looks like Jinja-style
// templating -- confirm against the code that renders this file).
//
// After substitution the header provides:
//   - DTypeQ, DTypeKV, DTypeO, IdType: type aliases filled in from the
//     dtype_q, dtype_kv, dtype_o and dtype_idx placeholders.
//   - HEAD_DIM_CKV, HEAD_DIM_KPE: constexpr int constants filled in from the
//     head_dim_ckv and head_dim_kpe placeholders.
//   - ADDITIONAL_FUNC_PARAMS, ADDITIONAL_PARAMS_SETTER: macros defined with an
//     empty expansion.
//   - DISPATCH_context(...): forwards to DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)
//     with a brace-enclosed body that aliases Params to MLAParams and then
//     invokes the trailing variadic argument as a nullary callable.
//     `mask_mode` is not a macro parameter, so it must be a name visible at the
//     expansion site. Of the macro parameters, only MASK_MODE, Params and the
//     variadic tail are referenced in the visible expansion; the others are
//     accepted but unused here.
//
// DISPATCH_MASK_MODE and MLAParams are not defined in this file; they are
// presumably supplied by the included headers -- verify.
//
// NOTE(review): every #include below has no header name, and MLAParams has no
// template argument list. This looks like angle-bracketed text was stripped
// from this copy of the file (and line breaks collapsed) -- restore from the
// authoritative version before relying on it.
// NOTE(review): `using namespace flashinfer;` at file scope exposes that
// namespace to every translation unit that includes this header; presumably
// intentional for generated code -- confirm.
#pragma once #include #include #include #include #include #include #include #include using namespace flashinfer; #define ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ dtype_idx }}; constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }}; constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }}; #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ using Params = MLAParams; \ __VA_ARGS__(); \ })