{# Jinja template: rendered into a C++/CUDA configuration header. #}
#pragma once
|
|
#include <flashinfer/page.cuh>
|
|
#include <flashinfer/math.cuh>
|
|
#include <flashinfer/layout.cuh>
|
|
#include <flashinfer/utils.cuh>
|
|
#include <flashinfer/pos_enc.cuh>
|
|
#include <flashinfer/fastdiv.cuh>
|
|
#include <flashinfer/attention/variant_helper.cuh>
|
|
#include <flashinfer/attention/mla_params.cuh>
|
|
|
|
using namespace flashinfer;
|
|
|
|
// Extension-point macros. Both are intentionally defined with an empty
// expansion in this configuration, so any code that splices them in gets
// nothing added.
// NOTE(review): presumably consumed by shared launcher/binding code that is
// not visible in this file -- confirm where they are expanded.
#define ADDITIONAL_FUNC_PARAMS
#define ADDITIONAL_PARAMS_SETTER
|
|
|
|
// Type aliases and compile-time constants filled in by the template renderer.
// Each `{{ ... }}` placeholder is substituted with a concrete C++ type or
// integer literal when this template is rendered; the file is not valid C++
// until then.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
// NOTE(review): presumably the head dimensions of the compressed-KV and
// positional-encoding parts of an MLA key -- confirm against the kernels.
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
|
|
|
|
// DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV,
//                  HEAD_DIM_KPE, Params, <callable>)
//
// Forwards to DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { ... }). Inside the
// dispatched block it:
//   1. declares `Params` as an alias for
//      MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>, and
//   2. invokes the trailing variadic argument(s) as a callable taking no
//      arguments (typically a lambda).
//
// Notes:
//   * `mask_mode` is NOT a macro parameter; a variable with that exact name
//     must be in scope wherever this macro is expanded.
//   * HEAD_DIM_CKV and HEAD_DIM_KPE are accepted as parameters but are not
//     referenced in the expansion.
//   * NOTE(review): DISPATCH_MASK_MODE is not defined in this file; it is
//     presumably provided by one of the included flashinfer headers and binds
//     MASK_MODE to a compile-time value -- confirm.
//   * Every continuation line must end with a backslash immediately followed
//     by the newline; nothing may sit between the continued lines.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
    using Params = MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
    __VA_ARGS__(); \
  })
|