{# Django/Jinja template; renders to a CUDA C++ configuration header. #}
#pragma once
#include <flashinfer/attention/hopper/attention_updater.cuh>
#include <flashinfer/attention/hopper/variant_helper.cuh>
#include <flashinfer/math.cuh>
#include <flashinfer/layout.cuh>
#include <flashinfer/cutlass_utils.cuh>
#include "pytorch_extension_utils.h"

// Template-supplied token sequences. Each macro expands to whatever text the
// template engine substitutes for the corresponding variable. The sites where
// these macros are expanded are not part of this file.
#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }}
#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }}

// Dispatch helper. Expands to DISPATCH_MASK_MODE (not defined in this file;
// presumably provided by one of the included headers), passing an identifier
// named `mask_mode` that must be visible at the expansion site. Inside the
// dispatched block, AttentionVariant is aliased to the type named by the
// variant_name template variable, and the trailing macro argument is invoked
// as a callable taking no arguments.
//
// Of the named parameters, only MASK_MODE and AttentionVariant appear in the
// expansion; the remaining ones are accepted but unused here.
//
// NOTE: the line-continuation backslash must be followed directly by the body
// line. Any line placed in between becomes the macro body instead.
#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, ...) \
  DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { using AttentionVariant = {{ variant_name }}; __VA_ARGS__();})

using namespace flashinfer;
|
|
|
|
using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
|
|
using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>;
|
|
using DTypeO = cutlass_dtype_t<{{ dtype_o }}>;
|
|
using IdType = cutlass_dtype_t<{{ idtype }}>;
|
|
|
|
constexpr int HEAD_DIM_QK = {{ head_dim_qk }};
|
|
constexpr int HEAD_DIM_VO = {{ head_dim_vo }};
|
|
constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }};
|
|
constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }};
|
|
|
|
struct RaggedParams {
|
|
using DTypeQ = DTypeQ;
|
|
using DTypeKV = DTypeKV;
|
|
using DTypeO = DTypeO;
|
|
using IdType = IdType;
|
|
// The QKV matrices.
|
|
DTypeQ* q_ptr;
|
|
DTypeKV* k_ptr;
|
|
DTypeKV* v_ptr;
|
|
DTypeO* o_ptr;
|
|
float* lse_ptr;
|
|
|
|
IdType* qo_tile_indices;
|
|
IdType* qo_indptr;
|
|
IdType* kv_indptr;
|
|
IdType* qo_lens;
|
|
IdType* kv_lens;
|
|
IdType* head_indices;
|
|
IdType* work_indptr;
|
|
IdType* batch_indices;
|
|
|
|
struct AdditionalParams {
|
|
{{ additional_params_decl }}
|
|
} additional_params;
|
|
|
|
int64_t q_stride_n;
|
|
int64_t k_stride_n;
|
|
int64_t v_stride_n;
|
|
int64_t o_stride_n;
|
|
int64_t q_stride_h;
|
|
int64_t k_stride_h;
|
|
int64_t v_stride_h;
|
|
int64_t o_stride_h;
|
|
int64_t nnz_qo;
|
|
int64_t nnz_kv;
|
|
|
|
int head_dim;
|
|
int num_qo_heads;
|
|
int num_kv_heads;
|
|
int group_size;
|
|
int window_left;
|
|
|
|
bool causal;
|
|
};
|
|
|
|
struct PagedParams {
|
|
using DTypeQ = DTypeQ;
|
|
using DTypeKV = DTypeKV;
|
|
using DTypeO = DTypeO;
|
|
using IdType = IdType;
|
|
// The QKV matrices.
|
|
DTypeQ* q_ptr;
|
|
DTypeKV* k_ptr;
|
|
DTypeKV* v_ptr;
|
|
DTypeO* o_ptr;
|
|
float* lse_ptr;
|
|
|
|
IdType* qo_tile_indices;
|
|
IdType* qo_indptr;
|
|
IdType* kv_indptr;
|
|
IdType* kv_indices;
|
|
IdType* qo_lens;
|
|
IdType* kv_lens;
|
|
IdType* head_indices;
|
|
IdType* work_indptr;
|
|
IdType* batch_indices;
|
|
|
|
struct AdditionalParams {
|
|
{{ additional_params_decl }}
|
|
} additional_params;
|
|
|
|
int64_t q_stride_n;
|
|
int64_t k_stride_n;
|
|
int64_t v_stride_n;
|
|
int64_t o_stride_n;
|
|
int64_t q_stride_h;
|
|
int64_t k_stride_h;
|
|
int64_t v_stride_h;
|
|
int64_t o_stride_h;
|
|
int64_t nnz_qo;
|
|
|
|
int head_dim;
|
|
int num_qo_heads;
|
|
int num_kv_heads;
|
|
int group_size;
|
|
int page_size;
|
|
int window_left;
|
|
|
|
bool causal;
|
|
};
|
|
|
|
// Template-supplied declarations. Presumably these define the type that
// DISPATCH_context aliases as AttentionVariant -- confirm against the
// generator that fills in this template.
{{ variant_decl }}