// sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/gemm/gemm_groupwise_sm100.cuh
//
// 156 lines
// 7.3 KiB
// Plaintext

/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GEMM_GROUPWISE_SM100_CUH_
#define FLASHINFER_GEMM_GROUPWISE_SM100_CUH_
#include <type_traits>
#include "../allocator.h"
#include "../cutlass_utils.cuh"
#include "../utils.cuh"
namespace flashinfer {
namespace gemm {
using namespace cute;
// Builds and launches a CUTLASS SM100 GEMM whose A and B operands each carry a float
// scale-factor tensor (SFA / SFB) laid out by cutlass::detail::Sm100BlockwiseScaleConfig.
//
// Operand layouts are fixed here: A is row-major, B is column-major, C and D are row-major.
// Accumulation and epilogue compute use float, and the epilogue is run with alpha = 1, beta = 0.
//
// Template parameters:
//   ScaleGranularityM/N/K  Scale granularities forwarded to Sm100BlockwiseScaleConfig.
//   ScaleMajorK            true  -> both scale tensors use UMMA::Major::K,
//                          false -> both scale tensors use UMMA::Major::MN.
//   MmaSM                  Multiplies the M extent of the MMA tile (MmaSM * 128) and is the
//                          M extent of the cluster shape.
//   DTypeIn                Element type of A and B.
//   DTypeOut               Element type of C and D.
//
// Arguments:
//   float_buffer, float_buffer_size_in_bytes
//                          Scratch region handed to AlignedAllocator; the CUTLASS workspace is
//                          carved out of it with 16-byte alignment.
//   A_ptr                  A operand, addressed with a packed stride for shape (m, k, l).
//   B_ptr                  B operand, addressed with a packed stride for shape (n, k, l).
//   SFA_ptr, SFB_ptr       Scale factors for A and B, addressed with the layouts produced by
//                          ScaleConfig::tile_atom_to_shape_SFA/SFB for (m, n, k, l).
//   C_ptr                  Passed as both the C and the D pointer, with packed strides for
//                          shape (m, n, l).
//   m, n, k, l             Problem extents passed to CUTLASS as {m, n, k, l}.
//   stream                 Stream passed to Gemm::run.
//
// Returns cudaSuccess once Gemm::run has returned. CUTLASS status codes are routed through
// CUTLASS_CHECK, which is defined outside this header.
// NOTE(review): confirm how CUTLASS_CHECK and AlignedAllocator report failures to callers.
template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK,
          int MmaSM, typename DTypeIn, typename DTypeOut>
cudaError_t CutlassGroupwiseScaledGEMMSM100(void* float_buffer, size_t float_buffer_size_in_bytes,
                                            DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr,
                                            float* SFB_ptr, DTypeOut* C_ptr, int m, int n, int k,
                                            int l, cudaStream_t stream) {
  // A matrix configuration
  using ElementA = DTypeIn;                   // Element type for A matrix operand
  using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand
  // Alignment of A in units of elements: the number of elements in one 128-bit (16-byte) access.
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration
  using ElementB = DTypeIn;                      // Element type for B matrix operand
  using LayoutB = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  // Alignment of B in units of elements: the number of elements in one 128-bit (16-byte) access.
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // C/D matrix configuration
  using ElementC = DTypeOut;                  // Element type for C and D matrix operands
  using LayoutC = cutlass::layout::RowMajor;  // Layout type for C and D matrix operands
  // Alignment of C in units of elements: the number of elements in one 128-bit (16-byte) access.
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  using ElementD = ElementC;
  using LayoutD = LayoutC;
  constexpr int AlignmentD = AlignmentC;

  // MMA type
  using ElementAccumulator = float;  // Element Accumulator will also be our scale factor type
  using ElementCompute = float;
  using MmaTileShape_MNK = Shape<cute::Int<MmaSM * 128>, _128, _128>;
  using ClusterShape_MNK = Shape<cute::Int<MmaSM>, _1, _1>;

  // NOTE(Zihao):: UMMA::Major::MN, UMMA::Major::MN is the fastest configuration.
  using ScaleConfig = std::conditional_t<
      ScaleMajorK,
      cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
                                                 ScaleGranularityK, UMMA::Major::K, UMMA::Major::K>,
      cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
                                                 ScaleGranularityK, UMMA::Major::MN,
                                                 UMMA::Major::MN>>;

  using LayoutSFA =
      decltype(ScaleConfig::deduce_layoutSFA());  // Layout type for SFA matrix operand
  using LayoutSFB =
      decltype(ScaleConfig::deduce_layoutSFB());  // Layout type for SFB matrix operand

  // The D operand is described with its own LayoutD alias (same type as LayoutC).
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, MmaTileShape_MNK, ClusterShape_MNK,
      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC,
      LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

  // The mainloop stage count is chosen automatically after carving out the epilogue's
  // shared storage.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA,
      cute::tuple<LayoutA, LayoutSFA>, AlignmentA, ElementB, cute::tuple<LayoutB, LayoutSFB>,
      AlignmentB, ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      void>;  // Default to ClusterLaunchControl (CLC) based tile scheduler

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  // Packed (densely stored) strides for every operand.
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, l));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
  auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));

  // Scale-factor layouts are derived from the full (m, n, k, l) problem shape.
  auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l));
  auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l));

  typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                                     {m, n, k, l},
                                     {
                                         A_ptr,
                                         stride_A,
                                         B_ptr,
                                         stride_B,
                                         SFA_ptr,
                                         layout_SFA,
                                         SFB_ptr,
                                         layout_SFB,
                                     },
                                     {
                                         {},  // epilogue.thread
                                         C_ptr,     // C pointer
                                         stride_C,  // C stride
                                         C_ptr,     // D pointer: same buffer as C
                                         stride_D,  // D stride
                                     }};
  auto& fusion_args = arguments.epilogue.thread;
  fusion_args.alpha = 1.0f;
  fusion_args.beta = 0.0f;

  Gemm gemm;

  // Carve the CUTLASS workspace out of the caller-provided scratch buffer.
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes);
  auto workspace_ptr = float_allocator.aligned_alloc<void>(workspace_size, 16,
                                                           "sm100_groupwise_gemm_float_workspace");

  CUTLASS_CHECK(gemm.can_implement(arguments));
  CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr));
  CUTLASS_CHECK(gemm.run(stream));
  return cudaSuccess;
}
} // namespace gemm
} // namespace flashinfer
#endif // FLASHINFER_GEMM_GROUPWISE_SM100_CUH_