sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/gemm/group_gemm_sm90.cuh

161 lines
7.4 KiB
Plaintext

/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
#define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
#include "../allocator.h"
#include "../cutlass_utils.cuh"
#include "../utils.cuh"
namespace flashinfer {
namespace group_gemm {
using namespace cute;
// Dispatch on the weight matrix layout: expands __VA_ARGS__ with the alias
// WEIGHT_LAYOUT bound to cutlass::layout::ColumnMajor when is_column_major is
// true, and to cutlass::layout::RowMajor otherwise. The body is instantiated
// twice (once per branch), so __VA_ARGS__ must compile under both layouts.
// NOTE(review): not wrapped in do { } while (0) — avoid using this macro as
// the sole statement of an unbraced if/else at a call site.
#define DISPATCH_WEIGHT_LAYOUT(is_column_major, WEIGHT_LAYOUT, ...) \
if (is_column_major) { \
using WEIGHT_LAYOUT = cutlass::layout::ColumnMajor; \
__VA_ARGS__ \
} else { \
using WEIGHT_LAYOUT = cutlass::layout::RowMajor; \
__VA_ARGS__ \
}
/**
 * Panic wrapper for unwinding CUTLASS errors.
 *
 * Evaluates `status` exactly once; on any result other than
 * cutlass::Status::kSuccess, prints the CUTLASS status string together with
 * the call-site line number to stderr and terminates the process.
 *
 * Wrapped in do { } while (0) so the macro expands to a single statement:
 * the previous bare `{ ... }` form left a stray empty statement after the
 * call-site semicolon and mis-parsed when used as the sole statement of an
 * unbraced if/else.
 */
#define CUTLASS_CHECK(status)                                                                    \
  do {                                                                                           \
    cutlass::Status error = status;                                                              \
    if (error != cutlass::Status::kSuccess) {                                                    \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
                << std::endl;                                                                    \
      exit(EXIT_FAILURE);                                                                        \
    }                                                                                            \
  } while (0)
/*!
 * \brief Runs a grouped (segment) GEMM on SM90 (Hopper) via CUTLASS 3.x
 *        ptr-array TMA warp-specialized kernels: for each of the
 *        `batch_size` problems, computes y[i] = x[i] * w[i].
 *
 * \tparam DTypeIn  Input element type for both x and w (1-byte types are
 *                  treated as fp8 and use the FP8 fast-accum schedule).
 * \tparam DTypeOut Output element type for y; accumulation is in float.
 *
 * \param float_buffer / float_buffer_size_in_bytes
 *        Device scratch buffer used for the CUTLASS kernel workspace.
 * \param int_buffer / int_buffer_size_in_bytes
 *        Unused by this SM90 path; kept for interface parity with other
 *        backends.
 * \param all_problems Device array of per-group problem shapes
 *        (ProblemShape::UnderlyingProblemShape, i.e. (m, n, k) per group).
 * \param batch_size   Number of grouped GEMM problems.
 * \param x, w, y      Device arrays of per-group pointers (DTypeIn** for x/w,
 *                     DTypeOut** for y); y is used as both epilogue source C
 *                     and destination D with alpha = 1, beta = 0.
 * \param x_stride, w_stride, y_stride
 *        Device arrays of per-group CUTLASS strides matching the kernel's
 *        InternalStrideA/B/C/D types.
 * \param weight_column_major Layout of w; row-major weights are rejected for
 *        fp8 inputs (FLASHINFER_ERROR is raised in that case).
 * \param stream CUDA stream on which the kernel is launched.
 *
 * \return cudaErrorNotSupported on pre-SM90 devices, any error from
 *         cudaGetDevice, otherwise cudaSuccess. CUTLASS-level failures
 *         terminate the process via CUTLASS_CHECK.
 */
template <typename DTypeIn, typename DTypeOut>
cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes,
                                      void* int_buffer, size_t int_buffer_size_in_bytes,
                                      void* all_problems, int64_t batch_size, void* x, void* w,
                                      void* y, void* x_stride, void* w_stride, void* y_stride,
                                      bool weight_column_major, cudaStream_t stream) {
  // Grouped TMA kernels require Hopper (compute capability 9.0+).
  auto compute_capacity = GetCudaComputeCapability();
  if (compute_capacity.first < 9) {
    std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0"
              << std::endl;
    return cudaErrorNotSupported;
  }

  using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
  using ElementA = DTypeIn;
  using ElementB = DTypeIn;
  using ElementC = DTypeOut;

  DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, {
    if constexpr (std::is_same_v<WEIGHT_LAYOUT, cutlass::layout::RowMajor> &&
                  sizeof(DTypeIn) == 1) {
      // fp8 (1-byte) inputs only support TN GEMM here, i.e. column-major B.
      std::ostringstream err_msg;
      err_msg << "Row-major layout is not supported for fp8 data type";
      FLASHINFER_ERROR(err_msg.str());
    } else {
      using LayoutA = cutlass::layout::RowMajor;
      constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
      using LayoutB = WEIGHT_LAYOUT;
      constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
      using LayoutC = cutlass::layout::RowMajor;
      constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

      constexpr bool is_fp8 = sizeof(DTypeIn) == 1;
      using ElementAccumulator = float;  // always accumulate in fp32
      using ArchTag = cutlass::arch::Sm90;
      using OperatorClass = cutlass::arch::OpClassTensorOp;
      // Larger tile / cluster and the FP8 fast-accum cooperative schedule for
      // fp8; standard cooperative ptr-array schedule otherwise.
      using TileShape =
          typename std::conditional<is_fp8, Shape<_256, _128, _128>, Shape<_128, _128, _128>>::type;
      using ClusterShape =
          typename std::conditional<is_fp8, Shape<_2, _2, _1>, Shape<_2, _1, _1>>::type;
      using StageCountType = cutlass::gemm::collective::StageCountAuto;
      using KernelSchedule = typename std::conditional<
          is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
          cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type;
      using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized;

      using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
          ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC,
          EpilogueSchedule>::CollectiveOp;
      // Mainloop stage count is derived automatically after carving out the
      // epilogue's shared-memory footprint.
      using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;
      using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                                              CollectiveEpilogue>;
      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

      using StrideA = typename Gemm::GemmKernel::InternalStrideA;
      using StrideB = typename Gemm::GemmKernel::InternalStrideB;
      using StrideC = typename Gemm::GemmKernel::InternalStrideC;
      using StrideD = typename Gemm::GemmKernel::InternalStrideD;

      cutlass::KernelHardwareInfo hw_info;
      // Propagate failures instead of silently ignoring the return value
      // (an earlier sticky CUDA error would otherwise surface later with a
      // misleading message).
      cudaError_t device_status = cudaGetDevice(&hw_info.device_id);
      if (device_status != cudaSuccess) {
        return device_status;
      }
      hw_info.sm_count =
          cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

      // Epilogue: D = alpha * acc + beta * C with alpha = 1, beta = 0,
      // i.e. a plain store of the accumulator into y.
      typename Gemm::EpilogueOutputOp::Params params;
      params =
          typename Gemm::EpilogueOutputOp::Params(ElementAccumulator(1.f), ElementAccumulator(0.f));

      typename Gemm::Arguments arguments{
          cutlass::gemm::GemmUniversalMode::kGrouped,
          {int(batch_size), reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(all_problems),
           nullptr},
          {static_cast<const DTypeIn**>(x), reinterpret_cast<StrideA*>(x_stride),
           static_cast<const DTypeIn**>(w), reinterpret_cast<StrideB*>(w_stride)},
          {params, static_cast<const DTypeOut**>(y), reinterpret_cast<StrideC*>(y_stride),
           static_cast<DTypeOut**>(y), reinterpret_cast<StrideD*>(y_stride)},
          hw_info};

      Gemm gemm;
      // Carve the kernel workspace out of the caller-provided float buffer.
      size_t workspace_size = Gemm::get_workspace_size(arguments);
      AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes);
      auto workspace_ptr = float_allocator.aligned_alloc<void>(workspace_size, 64,
                                                               "sm90_group_gemm_float_workspace");

      CUTLASS_CHECK(gemm.can_implement(arguments));
      CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr));
      CUTLASS_CHECK(gemm.run(stream));
    }
  });

  return cudaSuccess;
}
} // namespace group_gemm
} // namespace flashinfer
#endif // FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_