/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_
#define FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_

#include <cmath>

#include "../cutlass_utils.cuh"
#include "../exception.h"
#include "cutlass/kernel_hardware_info.h"

// From 3rdparty/cutlass/examples/77_blackwell_fmha
#include "blackwell/device/sm100_mla.hpp"
#include "blackwell/kernel/sm100_mla_tile_scheduler.hpp"

namespace flashinfer {
namespace attention {

using namespace cute;
using namespace cutlass::fmha::kernel;

template <bool v>
struct IsPersistent {
  static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler = std::conditional_t<PersistenceOption::value,
                                           Sm100MlaPersistentTileScheduler,
                                           Sm100MlaIndividualTileScheduler>;

  using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
      TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
      /*kIsCpAsync=*/true>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

template <typename T>
typename T::Fmha::Arguments args_from_options(void* out_ptr, void* lse_ptr, void* q_absorbed_ptr,
                                              void* ckv_kpe_cache_ptr, void* seq_lens_ptr,
                                              void* page_table_ptr, int batches,
                                              int page_count_per_seq, int page_count_total,
                                              int page_size, int device_index) {
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = device_index;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // The scale is based on the non-absorbed head size; change as appropriate.
  // We can't determine this parameter from the info we have, so it is an input.
  int D_non_latent = 128;
  float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  StrideQ stride_Q = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                                      static_cast<int64_t>(H * (0 + D_latent + D_rope)));
  StrideK stride_C = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{},
                                      static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{},
                                      static_cast<int64_t>(0 + H * D_latent));
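  // Illustrative check of the strides above (comments only, not executed),
  // assuming the default TileShape Shape<_128, _128, Shape<_512, _64>>:
  // H = 128 query heads per tile, D_latent = 512, D_rope = 64, so each
  // Q/KV row holds D_latent + D_rope = 576 elements. Then:
  //   stride_Q = (576, 1, 128 * 576) = (576, 1, 73728)
  //   stride_C = (576, 1, page_size * 576), e.g. (576, 1, 36864) for 64-token pages
  //   stride_O = (512, 1, 128 * 512) = (512, 1, 65536)
  //   scale    = 1 / sqrt(128 + 64) = 1 / sqrt(192) ≈ 0.0722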
  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_ptr = reinterpret_cast<Element*>(q_absorbed_ptr);
  auto C_ptr = reinterpret_cast<Element*>(ckv_kpe_cache_ptr);
  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, C_ptr + D_latent,
       stride_C, reinterpret_cast<int*>(seq_lens_ptr), reinterpret_cast<int*>(page_table_ptr),
       stride_PT, page_count_total, page_size},
      {reinterpret_cast<ElementOut*>(out_ptr), stride_O,
       // static_cast<ElementAcc*>(lse_ptr), stride_LSE},
       static_cast<ElementAcc*>(nullptr), stride_LSE},
      hw_info,
      -1,       // split_kv
      nullptr,  // is_var_split_kv=false
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);
  return arguments;
}

template <typename Element>
cudaError_t runMla(void* workspace_ptr, void* out_ptr, void* lse_ptr, void* q_absorbed_ptr,
                   void* ckv_kpe_cache_ptr, void* seq_lens_ptr, void* page_table_ptr, int batches,
                   int page_count_per_seq, int page_count_total, int page_size, int device_index,
                   cudaStream_t stream) {
  using MlaSm100Type = MlaSm100<Element>;
  typename MlaSm100Type::Fmha fmha;
  auto arguments = args_from_options<MlaSm100Type>(
      out_ptr, lse_ptr, q_absorbed_ptr, ckv_kpe_cache_ptr, seq_lens_ptr, page_table_ptr, batches,
      page_count_per_seq, page_count_total, page_size, device_index);

  CUTLASS_CHECK(fmha.can_implement(arguments));
  CUTLASS_CHECK(fmha.initialize(arguments, workspace_ptr, stream));
  CUTLASS_CHECK(fmha.run(arguments, workspace_ptr, stream));

  return cudaSuccess;
}

}  // namespace attention
}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_
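// Minimal usage sketch (illustrative only, not part of this header). Device
// buffers such as d_workspace, d_out, d_q_absorbed, d_ckv_kpe, d_seq_lens,
// and d_page_table are hypothetical caller-allocated pointers; the workspace
// must be sized to the kernel's requirements (e.g. via the device handle's
// workspace query) before the call.
//
//   using namespace flashinfer::attention;
//   cudaStream_t stream;
//   cudaStreamCreate(&stream);
//   cudaError_t status = runMla<cutlass::half_t>(
//       d_workspace, d_out, /*lse_ptr=*/nullptr, d_q_absorbed, d_ckv_kpe,
//       d_seq_lens, d_page_table, /*batches=*/8, /*page_count_per_seq=*/64,
//       /*page_count_total=*/512, /*page_size=*/128, /*device_index=*/0,
//       stream);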