14 lines
366 B
Django/Jinja
14 lines
366 B
Django/Jinja
#include <flashinfer/attention/decode.cuh>
|
|
#include "batch_decode_config.inc"
|
|
|
|
// Make names from the flashinfer namespace usable unqualified at file scope.
using namespace flashinfer;
|
|
|
|
namespace flashinfer {

// Explicit instantiation of BatchDecodeWithPagedKVCacheDispatched for one
// concrete configuration. The placeholders are filled in when this template
// is rendered:
//   head_dim_qk, pos_encoding_mode, variant_name -> first three template
//       arguments of the dispatched function
//   dtype_o -> element type of the tmp_v buffer
// Params is not declared in this file; presumably it is provided by
// batch_decode_config.inc -- confirm against the generated config.
template cudaError_t BatchDecodeWithPagedKVCacheDispatched<
    {{ head_dim_qk }}, {{ pos_encoding_mode }}, {{ variant_name }}, Params>(
    Params params, {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl,
    cudaStream_t stream);

}  // namespace flashinfer
|