// sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_mla_config.jinja (34 lines, 1.1 KiB, Django/Jinja)

#pragma once
#include <flashinfer/page.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
#include <flashinfer/fastdiv.cuh>
#include <flashinfer/attention/variant_helper.cuh>
#include <flashinfer/attention/mla_params.cuh>
using namespace flashinfer;
// Optional profiler plumbing, selected at compile time by FLASHINFER_ENABLE_PROFILER.
//
// ADDITIONAL_FUNC_PARAMS   - extra text for a function parameter list. When the
//                            profiler is enabled it expands to
//                            ", at::Tensor profiler_buffer"; the leading comma
//                            means it is presumably appended after an existing
//                            parameter (verify against the call sites).
// ADDITIONAL_PARAMS_SETTER - a statement that stores the tensor's data pointer,
//                            cast to uint64_t*, into params.profiler_buffer. The
//                            expansion site must have both `params` and
//                            `profiler_buffer` in scope.
//
// When the profiler is disabled both macros expand to nothing.
#ifdef FLASHINFER_ENABLE_PROFILER
#define ADDITIONAL_FUNC_PARAMS , at::Tensor profiler_buffer
// NOTE: the definition below continues onto the next line; keep the two lines adjacent.
#define ADDITIONAL_PARAMS_SETTER \
params.profiler_buffer = static_cast<uint64_t*>(profiler_buffer.data_ptr());
#else
// Profiler disabled: no extra parameter and no setter statement.
#define ADDITIONAL_FUNC_PARAMS
#define ADDITIONAL_PARAMS_SETTER
#endif
// Compile-time configuration for this rendered instance of the template.
// Each right-hand side is a Jinja placeholder that is replaced with concrete
// text when the template is rendered, so the generated header fixes one
// specific combination of element types and head dimensions.
// Type aliases filled from the template variables dtype_q, dtype_kv, dtype_o
// and dtype_idx respectively.
using DTypeQ = {{ dtype_q }};
using DTypeKV = {{ dtype_kv }};
using DTypeO = {{ dtype_o }};
using IdType = {{ dtype_idx }};
// Integer constants filled from the template variables head_dim_ckv and
// head_dim_kpe. Presumably the compressed-KV and positional-encoding head
// dimensions of MLA attention -- TODO confirm against the kernel code.
constexpr int HEAD_DIM_CKV = {{ head_dim_ckv }};
constexpr int HEAD_DIM_KPE = {{ head_dim_kpe }};
// DISPATCH_context: runs a caller-supplied callable inside a mask-mode dispatch.
//
// Expands to DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { ... }). DISPATCH_MASK_MODE
// is not defined in this file (presumably provided by one of the included
// FlashInfer headers -- verify). `mask_mode` is not a macro parameter, so a
// name `mask_mode` must be in scope wherever this macro is expanded.
//
// Inside the dispatched block the macro:
//   1. declares the alias named by `Params` as
//      MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>, using the macro arguments
//      passed in those positions;
//   2. invokes the variadic argument with no arguments (`__VA_ARGS__()`), so
//      callers are expected to pass something callable, e.g. a lambda. Any
//      value it returns is discarded.
//
// The HEAD_DIM_CKV and HEAD_DIM_KPE macro parameters are accepted but not
// referenced in the expansion. Do not insert comments between the lines of
// the definition: each line ends in a continuation backslash.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \
using Params = MLAParams<DTypeQ, DTypeKV, DTypeO, IdType>; \
__VA_ARGS__(); \
})