#pragma once
#include <flashinfer/attention/default_decode_params.cuh>
#include <flashinfer/attention/variants.cuh>
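
// NOTE: This header is a Jinja template; the {{ ... }} placeholders below are
// substituted by FlashInfer's JIT code generation when the module is compiled.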
// using namespace flashinfer;
// avoid "at::Layout" is ambiguous error
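// Scalar and index types for the kernel (e.g. half/nv_bfloat16 for data, int32_t
// for indices), filled in by the code generator.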
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
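
// Compile-time attention options and MLA head dimensions:
//   HEAD_DIM_CKV: head dimension of the compressed (latent) KV cache entries
//   HEAD_DIM_KPE: head dimension of the decoupled rotary key positional embedding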
constexpr bool USE_SLIDING_WINDOW = {{ use_sliding_window }};
constexpr bool USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
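
// Tile length along the query/output dimension used by the batch-decode MLA kernel.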
constexpr int QO_TILE_LEN = {{ qo_tile_len }};
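
// Parameter struct and attention variant consumed by the templated batch-decode MLA kernels.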
using Params = flashinfer::BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;
using AttentionVariant =
    flashinfer::DefaultAttention</*use_custom_mask=*/false, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, /*use_alibi=*/false>;