/* * Copyright (c) 2025 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FLASHINFER_GEMM_GROUPWISE_SM100_CUH_ #define FLASHINFER_GEMM_GROUPWISE_SM100_CUH_ #include #include "../allocator.h" #include "../cutlass_utils.cuh" #include "../utils.cuh" namespace flashinfer { namespace gemm { using namespace cute; template cudaError_t CutlassGroupwiseScaledGEMMSM100(void* float_buffer, size_t float_buffer_size_in_bytes, DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr, float* SFB_ptr, DTypeOut* C_ptr, int m, int n, int k, int l, cudaStream_t stream) { using ElementA = DTypeIn; // Element type for A matrix operand using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A // matrix in units of elements (up to 16 bytes) // B matrix configuration using ElementB = DTypeIn; // Element type for B matrix operand using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A // matrix in units of elements (up to 16 bytes) // C/D matrix configuration using ElementC = DTypeOut; // Element type for C and D matrix operands using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A // matrix in units of elements (up to 16 bytes) using ElementD = ElementC; using LayoutD = LayoutC; constexpr int AlignmentD = AlignmentC; // MMA type using ElementAccumulator = float; // Element Accumulator will also be our scale factor type using ElementCompute = float; using MmaTileShape_MNK = Shape, _128, _128>; using ClusterShape_MNK = Shape, _1, _1>; // NOTE(Zihao):: UMMA::Major::MN, UMMA::Major::MN is the fastest configuration. using ScaleConfig = std::conditional_t< ScaleMajorK, cutlass::detail::Sm100BlockwiseScaleConfig, cutlass::detail::Sm100BlockwiseScaleConfig>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, MmaTileShape_MNK, ClusterShape_MNK, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA, cute::tuple, AlignmentA, ElementB, cute::tuple, AlignmentB, ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, CollectiveEpilogue, void>; // Default to ClusterLaunchControl (CLC) based tile scheduler using Gemm = cutlass::gemm::device::GemmUniversalAdapter; using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, l)); auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l)); auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l)); auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l)); auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l)); auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l)); typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, l}, { A_ptr, stride_A, B_ptr, stride_B, SFA_ptr, layout_SFA, SFB_ptr, layout_SFB, }, { {}, // epilogue.thread C_ptr, stride_C, C_ptr, stride_C, }}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha = 1.0f; fusion_args.beta = 0.0f; Gemm gemm; size_t workspace_size = Gemm::get_workspace_size(arguments); AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes); auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 16, "sm100_groupwise_gemm_float_workspace"); CUTLASS_CHECK(gemm.can_implement(arguments)); CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr)); CUTLASS_CHECK(gemm.run(stream)); return cudaSuccess; } } // namespace gemm } // namespace flashinfer #endif // FLASHINFER_GEMM_GROUPWISE_SM100_CUH_