{#
sglang_v0.5.2/flashinfer_0.3.1/tvm_binding/batch_mla_config.jinja

28 lines
874 B
Django/Jinja
#}

#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/mla_params.cuh>
// Make every name in the flashinfer namespace visible unqualified to the
// rest of this generated header and to whatever includes it.
using namespace flashinfer;
// Both hook macros are deliberately defined with an empty expansion here.
// NOTE(review): presumably the including source splices them into a function
// signature / parameter-setup sequence and this configuration needs no
// extras -- confirm against the consumer of this header.
#define ADDITIONAL_FUNC_PARAMS
#define ADDITIONAL_PARAMS_SETTER
// Element/index type aliases. The right-hand sides are Jinja placeholders that
// are replaced with concrete C++ type names when the template is rendered.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
// Compile-time integer constants, likewise filled in at render time.
// NOTE(review): the names suggest the MLA compressed-KV and positional-key
// head dimensions -- confirm against flashinfer/attention/mla_params.cuh.
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
// DISPATCH_context: wraps the trailing variadic argument in a
// DISPATCH_MASK_MODE(...) block and invokes it with no arguments.
//
// Inside the block, the identifier passed as `Params` is declared as an alias
// of MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>, built from the type names
// passed as the first four macro arguments. The identifier passed as
// `MASK_MODE` is forwarded as the second argument of DISPATCH_MASK_MODE.
//
// Things to be aware of when using it:
//  * `mask_mode` (lowercase) is NOT a macro parameter; it is used verbatim, so
//    an identifier with that name must be in scope at the expansion site.
//  * The HEAD_DIM_CKV and HEAD_DIM_KPE parameters are accepted but never
//    referenced in the expansion.
//  * The variadic argument is expanded as `__VA_ARGS__();`, i.e. it must be an
//    expression callable with zero arguments.
//  * DISPATCH_MASK_MODE is not defined in this file.
//    NOTE(review): presumably supplied by one of the flashinfer headers
//    included above -- confirm.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
using Params = MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
__VA_ARGS__(); \
})