#pragma once #include #include #include #include #include #include #include #define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} #define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} #define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \ DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, { \ DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \ using AttentionVariant = {{ variant_name }}; \ __VA_ARGS__(); \ })}) using namespace flashinfer; using DTypeQ = {{ dtype_q }}; using DTypeKV = {{ dtype_kv }}; using DTypeO = {{ dtype_o }}; using IdType = {{ idtype }}; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; struct RaggedParams { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; DTypeQ* q; DTypeKV* k; DTypeKV* v; IdType* q_indptr; IdType* kv_indptr; DTypeO* o; float* lse; uint_fastdiv group_size; IdType* maybe_q_rope_offset; IdType* maybe_k_rope_offset; {{ additional_params_decl }} uint32_t num_qo_heads; uint32_t num_kv_heads; uint32_t q_stride_n; uint32_t q_stride_h; uint32_t k_stride_n; uint32_t k_stride_h; uint32_t v_stride_n; uint32_t v_stride_h; int32_t window_left; IdType* request_indices; IdType* qo_tile_indices; IdType* kv_tile_indices; IdType* merge_indptr; IdType* o_indptr; IdType* kv_chunk_size_ptr; bool* block_valid_mask; uint32_t max_total_num_rows; uint32_t* total_num_rows; uint32_t padded_batch_size; bool partition_kv; __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; } __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; } }; struct PagedParams { using DTypeQ = DTypeQ; using DTypeKV = DTypeKV; using DTypeO = DTypeO; using IdType = IdType; DTypeQ* q; paged_kv_t paged_kv; IdType* q_indptr; DTypeO* o; float* lse; uint_fastdiv group_size; IdType* maybe_q_rope_offset; {{ additional_params_decl }} uint32_t num_qo_heads; IdType q_stride_n; IdType q_stride_h; int32_t window_left; IdType* request_indices; IdType* qo_tile_indices; IdType* kv_tile_indices; IdType* merge_indptr; IdType* o_indptr; bool* block_valid_mask; IdType* kv_chunk_size_ptr; uint32_t max_total_num_rows; uint32_t* total_num_rows; uint32_t padded_batch_size; bool partition_kv; __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; } __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { return paged_kv.get_length(batch_idx); } }; {{ variant_decl }}