10 lines
479 B
Django/Jinja
10 lines
479 B
Django/Jinja
#include <flashinfer/attention/persistent.cuh>
|
|
#include "batch_attention_config.inc"
|
|
|
|
namespace flashinfer {

// Explicit instantiation definition of BatchPagedAttentionPersistent for one
// generated configuration, so the specialization is compiled in this
// translation unit.
//
// This file is a Jinja template. The placeholders head_dim_qk, head_dim_vo,
// mask_mode and variant_name are substituted when the template is rendered;
// the two leading template arguments (labelled CTA_TILE_Q_1 and CTA_TILE_Q_2)
// are fixed at 128 and 16.
//
// The instantiated signature takes two PersistentParams values, two uint32_t
// block counts (num_blks_x, num_blks_y) and a cudaStream_t, and returns
// cudaError_t.
//
// NOTE(review): the primary template is presumably declared in
// flashinfer/attention/persistent.cuh and PersistentParams in
// batch_attention_config.inc -- confirm against those files.
template cudaError_t BatchPagedAttentionPersistent<
    /*CTA_TILE_Q_1=*/128, /*CTA_TILE_Q_2=*/16, {{head_dim_qk}}, {{head_dim_vo}}, {{mask_mode}},
    {{ variant_name }}, PersistentParams>(const PersistentParams params_1, const PersistentParams params_2,
                                          const uint32_t num_blks_x, const uint32_t num_blks_y, const cudaStream_t stream);

}  // namespace flashinfer
|