sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/gemm/gemm_groupwise_sm120.cuh

190 lines
7.6 KiB
Plaintext

/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GEMM_GEMM_GROUPWISE_SM120_CUH_
#define FLASHINFER_GEMM_GEMM_GROUPWISE_SM120_CUH_
#include <flashinfer/allocator.h>
#include <cassert>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <typeinfo>
#include "cutlass/cutlass.h"
#include "cutlass/detail/blockwise_scale_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
namespace flashinfer {
namespace gemm {
// SM120 uses Cooperative schedule with 128x128x128 tile shape
template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK,
typename DTypeIn, typename DTypeOut>
cudaError_t CutlassGroupwiseScaledGEMMSM120(void* float_buffer, size_t float_buffer_size_in_bytes,
DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr,
float* SFB_ptr, DTypeOut* D_ptr, int m, int n, int k,
int l, cudaStream_t stream) {
// SM120 only supports these specific scale granularities
static_assert(ScaleGranularityM == 1 || ScaleGranularityM == 128,
"SM120 only supports ScaleGranularityM = 1 or 128");
static_assert(ScaleGranularityN == 128, "SM120 only supports ScaleGranularityN = 128");
static_assert(ScaleGranularityK == 128, "SM120 only supports ScaleGranularityK = 128");
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM121_SUPPORTED)
using namespace cute;
using ElementA = DTypeIn;
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = DTypeIn;
using LayoutB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = DTypeOut;
using LayoutC = cutlass::layout::RowMajor;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;
using ElementAccumulator = float;
using ElementCompute = float;
// SM120 uses fixed 128x128x128 tile shape (Cooperative schedule)
// TODO (yongwww): add PingPong schedule (64x128x128)
using MmaTileShape_MNK = Shape<_128, _128, _128>;
using ClusterShape_MNK = Shape<_1, _1, _1>;
// SM120's Sm120BlockwiseScaleConfig takes UMMA::Major parameters based on ScaleMajorK
using ScaleConfig = std::conditional_t<
ScaleMajorK,
cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
ScaleGranularityK, UMMA::Major::K, UMMA::Major::K>,
cutlass::detail::Sm120BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
ScaleGranularityK, UMMA::Major::MN,
UMMA::Major::MN>>;
// Use decltype like SM100 does for consistency
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, MmaTileShape_MNK, ClusterShape_MNK,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC,
LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
// SM120 uses automatic stage count with epilogue carveout
using StageCount = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementA,
cute::tuple<LayoutA, LayoutSFA>, AlignmentA, ElementB, cute::tuple<LayoutB, LayoutSFB>,
AlignmentB, ElementAccumulator, MmaTileShape_MNK, ClusterShape_MNK, StageCount,
cutlass::gemm::KernelScheduleSm120Blockwise>::CollectiveOp;
using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, CollectiveMainloop,
CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, l));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));
auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, l));
auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, l));
// For beta=0 case, C and D can be the same buffer
DTypeOut* C_ptr = D_ptr; // Use D as both input and output when beta=0
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, l},
{A_ptr, stride_A, B_ptr, stride_B, SFA_ptr, layout_SFA, SFB_ptr, layout_SFB},
{{}, C_ptr, stride_C, D_ptr, stride_D}}; // C and D point to same buffer when beta=0
// Set alpha and beta for the epilogue
arguments.epilogue.thread.alpha = 1.0f;
arguments.epilogue.thread.beta = 0.0f;
// Check device compute capability first
int device_id = 0;
cudaGetDevice(&device_id);
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device_id);
Gemm gemm;
cutlass::Status status = gemm.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
return cudaErrorNotSupported;
}
size_t workspace_size = Gemm::get_workspace_size(arguments);
if (workspace_size > float_buffer_size_in_bytes) {
return cudaErrorInsufficientDriver;
}
// Pass workspace pointer only if needed
void* kernel_workspace = nullptr;
if (workspace_size > 0) { // Only provide a pointer if workspace is actually needed
kernel_workspace = float_buffer;
}
status = gemm.initialize(arguments, kernel_workspace);
if (status != cutlass::Status::kSuccess) {
// Don't continue if initialization failed
return cudaErrorNotSupported;
}
status = gemm.run(stream);
if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}
// Sync to ensure kernel completes
cudaError_t cuda_err = cudaStreamSynchronize(stream);
if (cuda_err != cudaSuccess) {
return cuda_err;
}
return cudaSuccess;
#else
return cudaErrorNotSupported;
#endif
}
} // namespace gemm
} // namespace flashinfer
#endif // FLASHINFER_GEMM_GEMM_GROUPWISE_SM120_CUH_