sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_decode_mla_config.jinja

#pragma once
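// NOTE: this file is a Jinja template, not a standalone C++ header. The
// double-brace placeholders below are substituted when flashinfer's JIT
// layer renders the template for a concrete batch-decode MLA configuration.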
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
// `using namespace flashinfer;` is intentionally left commented out: importing
// the whole namespace makes `at::Layout` ambiguous when this header is
// compiled alongside PyTorch/ATen headers.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
constexpr bool USE_SLIDING_WINDOW = {{ use_sliding_window }};
constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};  // head dim of the compressed (latent) KV cache
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};  // head dim of the decoupled rotary (RoPE) key
constexpr int QO_TILE_LEN = {{ qo_tile_len }};    // tile length along the query/output dimension
using Params = flashinfer::BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;
using AttentionVariant =
    flashinfer::DefaultAttention</*use_custom_mask=*/false, USE_SLIDING_WINDOW,
                                 USE_LOGITS_SOFT_CAP, /*use_alibi=*/false>;
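
// Illustrative rendering (a sketch, not the output of any particular build):
// for a DeepSeek-style MLA configuration, the substituted header could look
// like the block below. The 512/64 head dims match DeepSeek-V2-style MLA;
// the dtypes and the QO_TILE_LEN value are hypothetical.
//
//   using DTypeQ = half;
//   using DTypeKV = half;
//   using DTypeO = half;
//   using IdType = int32_t;
//   constexpr bool USE_SLIDING_WINDOW = false;
//   constexpr bool USE_LOGITS_SOFT_CAP = false;
//   constexpr int HEAD_DIM_CKV = 512;
//   constexpr int HEAD_DIM_KPE = 64;
//   constexpr int QO_TILE_LEN = 16;  // hypothetical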