sglang_v0.5.2/flashinfer_0.3.1/csrc/single_prefill_sm90_customi...

68 lines
1.8 KiB
Django/Jinja

#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/pos_enc.cuh>
// Expands to text substituted by the template renderer. Judging by the name this is an
// extra function-parameter list for the generated entry point -- TODO confirm at the use site,
// which is not in this file.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
// Expands to text substituted by the template renderer. Presumably statements that fill
// Params::additional_params -- TODO confirm at the use site, which is not in this file.
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
// Dispatch wrapper: forwards to DISPATCH_MASK_MODE (not defined in this file), and inside the
// dispatched block aliases `AttentionVariant` to the variant type named by the template, then
// invokes the trailing variadic argument as a nullary callable.
// NOTE: `mask_mode` is NOT a macro parameter; it must be a name visible where the macro is
// expanded. Of the named parameters only MASK_MODE and AttentionVariant appear in the expansion;
// the remaining ones are accepted and ignored (presumably to keep a uniform call shape with
// other dispatch macros -- TODO confirm).
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__(); })
// Compile-time configuration for this rendered instance. Every value below that comes from a
// template placeholder is fixed when the template is rendered, not at C++ compile time.
using namespace flashinfer;
// Element types for query, key/value and output, each mapped through cutlass_dtype_t
// (declared outside this file; presumably in flashinfer/cutlass_utils.cuh -- TODO confirm).
using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>;
using DTypeO = cutlass_dtype_t<{{ dtype_o }}>;
// Index type is hard-wired to the int32_t mapping; it is not a template placeholder.
using IdType = cutlass_dtype_t<int32_t>;
// Head dimensions for the QK and VO sides, as integer constants.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
// Feature switches. Declared with `auto`, so their C++ type is whatever the rendered literal
// deduces to (presumably bool -- TODO confirm against the renderer).
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
// Plain parameter bundle for the generated single-prefill instance: data pointers, a
// template-defined block of extra parameters, strides, and problem sizes.
// NOTE(review): member order determines the struct layout; do not reorder without checking
// every place that constructs or copies this struct (none are visible in this file).
struct Params {
// Re-export the file-scope type aliases as member types so code holding only `Params`
// can name them (e.g. Params::DTypeQ).
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;
// Pointers to the query, key, value and output buffers.
DTypeQ* q_ptr;
DTypeKV* k_ptr;
DTypeKV* v_ptr;
DTypeO* o_ptr;
// Presumably the per-row log-sum-exp output buffer; whether it may be null is not
// established here -- TODO confirm.
float* lse_ptr;
// Extra, variant-specific parameters; the member list is substituted by the template renderer.
struct AdditionalParams {
{{ additional_params_decl }};
} additional_params;
// 64-bit strides. By naming, `_n` looks like the stride between sequence positions and
// `_h` the stride between heads; units (elements vs. bytes) are not shown here -- TODO confirm.
int64_t q_stride_n;
int64_t k_stride_n;
int64_t v_stride_n;
int64_t o_stride_n;
int64_t q_stride_h;
int64_t k_stride_h;
int64_t v_stride_h;
int64_t o_stride_h;
// Problem sizes: query/output length and key/value length.
int qo_len;
int kv_len;
// Runtime head dimension; its relation to HEAD_DIM_QK / HEAD_DIM_VO is not shown here.
int head_dim;
int num_qo_heads;
int num_kv_heads;
// Presumably num_qo_heads / num_kv_heads (grouped-query attention) -- verify against the caller.
int group_size;
// Presumably the left extent of the sliding attention window -- TODO confirm semantics and
// the "disabled" sentinel, if any.
int window_left;
// Runtime causal-masking flag; how it interacts with the dispatched MASK_MODE is not shown here.
bool causal;
};
// Template-substituted declarations follow. Presumably this defines the attention variant type
// that DISPATCH_context aliases as `AttentionVariant` -- TODO confirm against the renderer.
{{ variant_decl }}