sglang_v0.5.2/flashinfer_0.3.1/csrc/batch_decode_mla_pybind.cu

#include "mla_config.inc"
#include "pytorch_extension_utils.h"
at::Tensor BatchDecodeWithPagedKVCachePlanMLA(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t page_size,
bool enable_cuda_graph);
void BatchDecodeWithPagedKVCacheRunMLA(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor plan_info_vec,
at::Tensor q_nope, at::Tensor q_pe, at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache,
at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
at::Tensor o, double sm_scale, int64_t window_left, double logits_soft_cap, double rope_scale,
double rope_theta, std::optional<at::Tensor> maybe_lse, bool enable_pdl);
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", BatchDecodeWithPagedKVCachePlanMLA);
m.def("run", BatchDecodeWithPagedKVCacheRunMLA);
}
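
The fragment above only declares and registers the two entry points; their definitions come from the MLA kernels configured by mla_config.inc. Below is a minimal host-side sketch of how the plan/run pair might be driven from C++, assuming the two declarations above are in scope. Every shape, dtype, workspace size, device placement, and the ckv/kpe head-dim split is an illustrative placeholder, not a value taken from this file.

// Hypothetical standalone caller; all sizes and placements below are placeholders.
#include <optional>
#include <torch/torch.h>

void DecodeMLASketch() {
  auto u8_cuda  = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
  auto i32_cuda = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  auto f16_cuda = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);

  int64_t batch_size = 2, num_qo_heads = 16, page_size = 1;
  int64_t ckv_dim = 512, kpe_dim = 64;  // assumed MLA head-dim split

  // Workspace buffers; the page-locked buffer lives in pinned host memory.
  auto float_ws  = torch::empty({32 * 1024 * 1024}, u8_cuda);
  auto int_ws    = torch::empty({8 * 1024 * 1024}, u8_cuda);
  auto pinned_ws = torch::empty(
      {8 * 1024 * 1024}, torch::TensorOptions().dtype(torch::kUInt8).pinned_memory(true));

  // One KV page per request in this toy layout: indptr = [0, 1, ..., batch_size].
  auto kv_indptr        = torch::arange(batch_size + 1, i32_cuda);
  auto kv_indices       = torch::arange(batch_size, i32_cuda);
  auto kv_last_page_len = torch::ones({batch_size}, i32_cuda);

  // Planning pass: returns the plan-info vector consumed by the run pass.
  at::Tensor plan_info = BatchDecodeWithPagedKVCachePlanMLA(
      float_ws, int_ws, pinned_ws, kv_indptr, batch_size, num_qo_heads, page_size,
      /*enable_cuda_graph=*/false);

  // Queries split into the no-positional-encoding part and the rope part.
  auto q_nope = torch::randn({batch_size, num_qo_heads, ckv_dim}, f16_cuda);
  auto q_pe   = torch::randn({batch_size, num_qo_heads, kpe_dim}, f16_cuda);

  // Paged caches laid out as [num_pages, page_size, dim]; num_pages == batch_size here.
  auto ckv_cache = torch::randn({batch_size, page_size, ckv_dim}, f16_cuda);
  auto kpe_cache = torch::randn({batch_size, page_size, kpe_dim}, f16_cuda);

  auto o = torch::empty_like(q_nope);
  double sm_scale = 0.1;  // placeholder; the real value depends on the model's head dims

  BatchDecodeWithPagedKVCacheRunMLA(
      float_ws, int_ws, plan_info, q_nope, q_pe, ckv_cache, kpe_cache,
      kv_indptr, kv_indices, kv_last_page_len, o, sm_scale,
      /*window_left=*/-1, /*logits_soft_cap=*/0.0, /*rope_scale=*/1.0,
      /*rope_theta=*/1e4, /*maybe_lse=*/std::nullopt, /*enable_pdl=*/false);
}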