sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_prefill_sm90_customiz...

119 lines
2.8 KiB
Django/Jinja

#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
#include "pytorch_extension_utils.h"
// Code-generation hooks filled in by the Jinja renderer. ADDITIONAL_FUNC_PARAMS
// presumably expands to the extra user-defined arguments of the generated entry
// point, and ADDITIONAL_PARAMS_SETTER to the statements that copy them into
// params.additional_params. NOTE(review): confirm against the including .cu template.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}
// Dispatch hook expected by the generic launcher.
// - Only MASK_MODE and AttentionVariant are referenced in the expansion; the other
//   parameters are unused here (presumably kept so the macro matches the launcher's
//   call signature -- NOTE(review): confirm against the caller).
// - The expansion switches on a variable named `mask_mode` that must already be in
//   scope at the call site.
// - Because AttentionVariant is a macro parameter, the `using` inside the expansion
//   aliases whatever identifier the caller passes in that slot to the rendered
//   variant type.
// - The trailing argument(s) in __VA_ARGS__ are invoked with `()`, so the caller is
//   expected to pass a callable (e.g. a lambda).
// Do not insert comments between the two lines below: the trailing backslash joins them.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \
DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__();})
using namespace flashinfer;
// Element and index types selected at code-generation time. cutlass_dtype_t
// presumably maps each rendered type name to its CUTLASS-side equivalent.
using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>;
using DTypeO = cutlass_dtype_t<{{ dtype_o }}>;
using IdType = cutlass_dtype_t<{{ idtype }}>;
// Compile-time head dimensions; presumably the QK^T reduction dim and the V/O dim,
// which may differ.
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
// Feature flags baked in at generation time. They are declared `auto`, so their
// type is whatever literal the renderer substitutes (expected true/false -- confirm).
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
// Parameter bundle for the ragged-KV prefill path. Unlike PagedParams it carries
// nnz_kv and has no kv_indices / page_size, which suggests K/V are addressed
// directly through kv_indptr (NOTE(review): inferred from the field set -- confirm).
// Member order defines layout and aggregate-initialization order; keep it stable.
struct RaggedParams {
// Re-export the file-level types as members, presumably so templated kernel
// code can refer to Params::DTypeQ etc.
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;
// Q/K/V input and O output tensors (K and V share DTypeKV), plus a float
// buffer presumably holding per-row log-sum-exp values.
DTypeQ* q_ptr;
DTypeKV* k_ptr;
DTypeKV* v_ptr;
DTypeO* o_ptr;
float* lse_ptr;
// Index arrays of type IdType; presumably sequence offsets/lengths and
// work-partitioning metadata produced by a planning step -- confirm.
IdType* qo_tile_indices;
IdType* qo_indptr;
IdType* kv_indptr;
IdType* qo_lens;
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;
// User-defined extra parameters; the member list is rendered by the generator.
struct AdditionalParams {
{{ additional_params_decl }}
} additional_params;
// 64-bit element strides; the _n / _h suffixes presumably denote the token and
// head dimensions respectively.
int64_t q_stride_n;
int64_t k_stride_n;
int64_t v_stride_n;
int64_t o_stride_n;
int64_t q_stride_h;
int64_t k_stride_h;
int64_t v_stride_h;
int64_t o_stride_h;
// Presumably total query/output and key/value token counts across the batch.
int64_t nnz_qo;
int64_t nnz_kv;
int head_dim;
int num_qo_heads;
int num_kv_heads;
// Presumably num_qo_heads / num_kv_heads (grouped-query attention) -- confirm.
int group_size;
// Presumably the left sliding-window extent, relevant when USE_SLIDING_WINDOW.
int window_left;
bool causal;
};
// Parameter bundle for the paged-KV prefill path. Compared with RaggedParams it
// adds kv_indices and page_size and omits nnz_kv, which suggests K/V are stored in
// fixed-size pages looked up through kv_indices (NOTE(review): inferred -- confirm).
// Member order defines layout and aggregate-initialization order; keep it stable.
struct PagedParams {
// Re-export the file-level types as members, presumably so templated kernel
// code can refer to Params::DTypeQ etc.
using DTypeQ = DTypeQ;
using DTypeKV = DTypeKV;
using DTypeO = DTypeO;
using IdType = IdType;
// Q/K/V input and O output tensors (K and V share DTypeKV), plus a float
// buffer presumably holding per-row log-sum-exp values.
DTypeQ* q_ptr;
DTypeKV* k_ptr;
DTypeKV* v_ptr;
DTypeO* o_ptr;
float* lse_ptr;
// Index arrays of type IdType; presumably sequence offsets/lengths, the page
// table (kv_indices) and work-partitioning metadata from a planning step -- confirm.
IdType* qo_tile_indices;
IdType* qo_indptr;
IdType* kv_indptr;
IdType* kv_indices;
IdType* qo_lens;
IdType* kv_lens;
IdType* head_indices;
IdType* work_indptr;
IdType* batch_indices;
// User-defined extra parameters; the member list is rendered by the generator.
struct AdditionalParams {
{{ additional_params_decl }}
} additional_params;
// 64-bit element strides; the _n / _h suffixes presumably denote the token and
// head dimensions respectively.
int64_t q_stride_n;
int64_t k_stride_n;
int64_t v_stride_n;
int64_t o_stride_n;
int64_t q_stride_h;
int64_t k_stride_h;
int64_t v_stride_h;
int64_t o_stride_h;
// Presumably the total query/output token count across the batch.
int64_t nnz_qo;
int head_dim;
int num_qo_heads;
int num_kv_heads;
// Presumably num_qo_heads / num_kv_heads (grouped-query attention) -- confirm.
int group_size;
// Presumably the number of tokens per KV page.
int page_size;
// Presumably the left sliding-window extent, relevant when USE_SLIDING_WINDOW.
int window_left;
bool causal;
};
// Attention-variant definition rendered by the generator. It presumably declares the
// type named by variant_name, which DISPATCH_context aliases as AttentionVariant.
{{ variant_decl }}