#pragma once #include #include #include #include #include #include #include #include #include #include using namespace flashinfer; using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ idtype }}; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }}; constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }}; constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }}; constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }}; constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }}; constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; constexpr bool USE_LOGITS_SOFT_CAP = false; using PrefillParams = SinglePrefillParams; using DecodeParams = BatchPrefillPagedParams; #define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ DISPATCH_MASK_MODE(mask_mode_p, MASK_MODE_P, { \ DISPATCH_MASK_MODE(mask_mode_d, MASK_MODE_D, { \ __VA_ARGS__(); \ }); \ });