/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH
#define FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH

// NOTE(review): the system include targets were lost in the copy this was restored from;
// the headers below are the ones required by the names used in this file
// (cooperative_groups, size_t, uint8_t, CUDA qualifiers) - confirm against upstream.
#include <cooperative_groups.h>
#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

#include "../profiler.cuh"

namespace flashinfer {

namespace cg = cooperative_groups;

// Define profiler event types for persistent kernels
enum class PersistentProfileEventType {
  kRunner1 = 0U,
  kRunner2 = 1U,
  kReduction = 2U,
};

// Per-kernel profiler state; its members are supplied by the
// PROFILER_CLOSURE_PARAMS_DECL macro from profiler.cuh.
struct ProfilerClosure {
  PROFILER_CLOSURE_PARAMS_DECL
};

// Helper metafunction to find maximum threads among multiple BlockPersistentRunners.
// Each runner type must expose `KTraits::NUM_THREADS`.
template <typename... Runners>
struct max_threads;

// Base case: a single runner contributes its own thread count.
template <typename Runner>
struct max_threads<Runner> {
  static constexpr size_t value = Runner::KTraits::NUM_THREADS;
};

// Recursive case: keep whichever of the first two runners has more threads and
// continue comparing it against the remaining runners.
template <typename Runner1, typename Runner2, typename... RestRunners>
struct max_threads<Runner1, Runner2, RestRunners...> {
  static constexpr size_t value = Runner1::KTraits::NUM_THREADS > Runner2::KTraits::NUM_THREADS
                                      ? max_threads<Runner1, RestRunners...>::value
                                      : max_threads<Runner2, RestRunners...>::value;
};

/*!
 * \brief Two runners version: runs two persistent runners back to back, synchronizes
 *   the whole grid, then runs the reduction runner on the partial results.
 *
 * \tparam BlockPersistentRunner1 First runner; provides `Params`, `KTraits` and `Run`.
 * \tparam BlockPersistentRunner2 Second runner; provides `Params`, `KTraits` and `Run`.
 * \tparam BlockReductionRunner Runner whose `Run` merges partial outputs into final outputs.
 * \param params_1 Parameters of the first runner. The reduction stage also reads its
 *   partial/final output pointers, merge indices and head configuration from here.
 * \param params_2 Parameters of the second runner.
 *
 * \note Launch bounds are the larger of the two persistent runners' `KTraits::NUM_THREADS`.
 * \note The kernel calls `cg::this_grid().sync()`, so it must be launched cooperatively
 *   (e.g. via `cudaLaunchCooperativeKernel`).
 * \note All three stages reuse the same dynamic shared-memory buffer `smem`; the launch
 *   must provide enough dynamic shared memory for the largest user.
 */
template <class BlockPersistentRunner1, class BlockPersistentRunner2, class BlockReductionRunner>
__global__ __launch_bounds__(max_threads<BlockPersistentRunner1, BlockPersistentRunner2>::value)
void PersistentKernelTemplate(
    const __grid_constant__ typename BlockPersistentRunner1::Params params_1,
    const __grid_constant__ typename BlockPersistentRunner2::Params params_2) {
  extern __shared__ uint8_t smem[];

#ifdef FLASHINFER_ENABLE_PROFILER
  ProfilerClosure profiler_closure;
  // no volatile as this is scope.CTA, and only threadIdx == 0 is modifying
  PROFILER_INIT(params_1, smem, profiler_closure, 0, 1, (threadIdx.x == 0));
#endif

  // Both views alias the start of the same dynamic shared-memory buffer; the runners
  // execute one after the other, so they take turns using it.
  // NOTE(review): the cast target types were lost in the copy this was restored from;
  // `KTraits::SharedStorage` is the presumed type - confirm against the runner definitions.
  auto& smem_storage_1 =
      reinterpret_cast<typename BlockPersistentRunner1::KTraits::SharedStorage&>(smem);
  auto& smem_storage_2 =
      reinterpret_cast<typename BlockPersistentRunner2::KTraits::SharedStorage&>(smem);
  auto grid = cg::this_grid();

#ifndef FLASHINFER_ENABLE_PROFILER
  BlockPersistentRunner1::Run(params_1, &smem_storage_1);
  BlockPersistentRunner2::Run(params_2, &smem_storage_2);

  // Grid-wide barrier: every block must finish both runners before any block
  // starts the reduction stage.
  grid.sync();
  BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse,
                            params_1.final_lse, *(params_1.num_packed_qo_len),
                            params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr,
                            params_1.merge_o_indices, smem);
#else
  BlockPersistentRunner1::Run(params_1, &smem_storage_1, profiler_closure);
  BlockPersistentRunner2::Run(params_2, &smem_storage_2, profiler_closure);

  // Grid-wide barrier: every block must finish both runners before any block
  // starts the reduction stage.
  grid.sync();
  BlockReductionRunner::Run(params_1.partial_o, params_1.final_o, params_1.partial_lse,
                            params_1.final_lse, *(params_1.num_packed_qo_len),
                            params_1.gqa_group_size, params_1.num_kv_heads, params_1.merge_indptr,
                            params_1.merge_o_indices, smem, profiler_closure);
#endif
}

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_PERSISTENT_TEMPLATE_CUH