// sglang_v0.5.2/flashinfer_0.3.1/csrc/trtllm_gemm_runner.cu
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <cuda.h>

#include <algorithm>
#include <string>
#include <vector>

#include "flashinfer/trtllm/common.h"
#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h"
#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h"
#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h"
#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h"
#include "pytorch_extension_utils.h"
namespace {
static thread_local gemm::gemm::GemmInterface::ModuleCache globalTrtllmGenGemmModuleCache;
} // namespace
namespace flashinfer {
struct TrtllmGenGemmRunnerOptions {
  gemm::trtllm::gen::Dtype eltType;
  gemm::trtllm::gen::Dtype outputType;
  bool transposeMmaOutput{false};
  gemm::trtllm::gen::SfLayout sfLayoutB;
};
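
// Picks one of four hand-chosen FP8 kernels based on the GEMM shape (the M
// dimension is currently unused by this heuristic). The thresholds are
// empirical; as a worked example, N = 28672, K = 4096 gives N/K = 7, which
// skips both ratio branches and selects KERNEL_NAME_LARGE_N because
// N >= 20000.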
int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K,
                          const gemm::gemm::GemmInterface& interface) {
  static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO =
      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_TN_transOut_"
      "noShflA_dsFp8_schedP2x2x1x3_sm100f";
  static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO =
      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_"
      "transOut_noShflA_dsFp8_schedS_sm100f";
  static constexpr const char* KERNEL_NAME_LARGE_N =
      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_"
      "transOut_noShflA_dsFp8_schedP2x2x1x3_sm100f";
  static constexpr const char* KERNEL_NAME_DEFAULT =
      "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_TN_"
      "transOut_noShflA_dsFp8_schedS_sm100f";

  double const n_k_ratio = static_cast<double>(N) / static_cast<double>(K);

  std::string kernel_name;
  if (n_k_ratio >= 32.0) {
    kernel_name = KERNEL_NAME_HIGH_N_K_RATIO;
  } else if (n_k_ratio <= 2.0) {
    kernel_name = KERNEL_NAME_LOW_N_K_RATIO;
  } else if (N >= 20000) {
    kernel_name = KERNEL_NAME_LARGE_N;
  } else {
    kernel_name = KERNEL_NAME_DEFAULT;
  }

  auto const& configs = interface.getGemmConfigs();
  size_t const num_configs = interface.getNumGemmConfigs();
  for (size_t i = 0; i < num_configs; ++i) {
    if (std::string(configs[i].mFunctionName) == kernel_name) {
      return static_cast<int64_t>(i);
    }
  }
  TORCH_CHECK(false, "Kernel not found: ", kernel_name);
}
class TrtllmGenGemmRunner {
 public:
  explicit TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options) : mOptions(options) {
    // Collect every kernel config that matches the requested precision,
    // output transpose, and scale-factor layout.
    auto const gemm = gemm::gemm::GemmInterface();
    auto const configs = gemm.getGemmConfigs();
    mPassingConfigIndices.clear();
    for (size_t i = 0; i < gemm.getNumGemmConfigs(); ++i) {
      // Named configOptions to avoid shadowing the constructor parameter.
      auto const& configOptions = configs[i].mOptions;
      if (configOptions.mDtypeA == mOptions.eltType &&
          configOptions.mDtypeC == mOptions.outputType &&
          configOptions.mTransposeMmaOutput == mOptions.transposeMmaOutput &&
          configOptions.mSfLayoutB == mOptions.sfLayoutB) {
        mPassingConfigIndices.push_back(i);
      }
    }
    TORCH_CHECK(mPassingConfigIndices.size() > 0,
                "No valid tactic found for the given options (precision, transpose, sf layout)");
  }
  int64_t getWorkspaceSizeInBytes(int64_t m, int64_t n, int64_t k, int64_t tactic) {
    auto gemm = gemm::gemm::GemmInterface();
    auto const configs = gemm.getGemmConfigs();
    TORCH_CHECK(tactic >= 0 && tactic < static_cast<int64_t>(gemm.getNumGemmConfigs()),
                "Invalid tactic in getWorkspaceSizeInBytes");
    auto const config = configs[tactic];

    gemm::gemm::GemmData gemmData;
    gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
    gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
    gemmData.mProblemDimensions.mK = k;
    gemmData.mProblemDimensions.mRank = 0;
    gemmData.mProblemDimensions.mWorldSize = 1;

    return gemm.getWorkspaceSizeInBytes(config, gemmData);
  }
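
  // Note on operand order: when mOptions.transposeMmaOutput is set, run()
  // swaps the A/B operands and their scale factors, so the underlying kernel
  // computes the transposed problem; the M/N dimensions are swapped to match.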
  void run(int64_t m, int64_t n, int64_t k, void const* a, void const* aScale, void const* b,
           void const* bScale, void* c, void* cScale, void* cScalePtr, void* workspace,
           CUstream stream, int32_t device_index, int64_t tactic) {
    auto gemm = gemm::gemm::GemmInterface();
    auto const configs = gemm.getGemmConfigs();
    TORCH_CHECK(tactic >= 0 && tactic < static_cast<int64_t>(gemm.getNumGemmConfigs()),
                "Invalid tactic id in run");
    auto const& config = configs[tactic];
    TORCH_CHECK(config.mOptions.mSfLayoutB == mOptions.sfLayoutB, "Invalid sf layout in run");

    gemm::gemm::GemmData gemmData;
    // Dims
    gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
    gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
    gemmData.mProblemDimensions.mK = k;
    gemmData.mProblemDimensions.mRank = 0;
    gemmData.mProblemDimensions.mWorldSize = 1;
    // Inputs (A and B are swapped when the MMA output is transposed)
    gemmData.mInputBuffers.mPtrA = mOptions.transposeMmaOutput ? b : a;
    gemmData.mInputBuffers.mPtrSfA = mOptions.transposeMmaOutput ? bScale : aScale;
    gemmData.mInputBuffers.mPtrB = mOptions.transposeMmaOutput ? a : b;
    gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? aScale : bScale;
    gemmData.mInputBuffers.mPtrScaleC = cScale;
    // Outputs
    gemmData.mOutputBuffers.mPtrC = c;
    gemmData.mOutputBuffers.mPtrSfC = cScalePtr;

    TORCH_CHECK(gemm.isValidConfig(config, gemmData), "unsupported tactic id in run");

    // Query (and cache per thread) the SM count for the target device.
    const int32_t multiProcessorCount = [device_index]() {
      static thread_local int32_t cached_multi_processor_count = -1;
      static thread_local int cached_device_index = -1;
      if (device_index == cached_device_index && cached_multi_processor_count != -1) {
        return cached_multi_processor_count;
      } else {
        int32_t count;
        cudaError_t cudaStatus =
            cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_index);
        TORCH_CHECK(cudaStatus == cudaSuccess,
                    "Failed to get device attribute: ", cudaGetErrorString(cudaStatus));
        cached_multi_processor_count = count;
        cached_device_index = device_index;
        return count;
      }
    }();

    TORCH_CHECK(gemm.run(config, workspace, gemmData, static_cast<void*>(stream),
                         multiProcessorCount, true, globalTrtllmGenGemmModuleCache) == 0,
                "Error occurred when running GEMM!");
  }
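
  // Returns the tactics that are valid for this problem shape, best-first:
  // larger tileK wins, then more split-K slices, then configs that unroll the
  // MMA loop 2x. For example, a valid config with tileK = 128 is returned
  // before any config with tileK = 64, regardless of its other fields.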
  std::vector<int64_t> getValidTactics(int64_t m, int64_t n, int64_t k) const {
    auto const gemm = gemm::gemm::GemmInterface();
    auto const configs = gemm.getGemmConfigs();

    gemm::gemm::GemmData gemmData;
    // Dims
    gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
    gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
    gemmData.mProblemDimensions.mK = k;
    gemmData.mProblemDimensions.mRank = 0;
    gemmData.mProblemDimensions.mWorldSize = 1;

    std::vector<int64_t> sortedIndices = mPassingConfigIndices;
    std::sort(sortedIndices.begin(), sortedIndices.end(), [&configs](int64_t idx0, int64_t idx1) {
      auto const& optionsA = configs[idx0].mOptions;
      auto const& optionsB = configs[idx1].mOptions;
      // Sort by tileK sizes first
      if (optionsA.mTileK != optionsB.mTileK) {
        return optionsA.mTileK > optionsB.mTileK;
      }
      // Then by splitK sizes
      if (optionsA.mNumSlicesForSplitK != optionsB.mNumSlicesForSplitK) {
        return optionsA.mNumSlicesForSplitK > optionsB.mNumSlicesForSplitK;
      }
      // Then prefer configs that unroll the MMA loop 2x
      if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma) {
        return optionsA.mUseUnrollLoop2xForMma;
      }
      return false;
    });

    bool foundLoop2xMma = false;
    std::vector<int64_t> validTactics;
    for (auto const& configIndex : sortedIndices) {
      auto const& config = configs[configIndex];
      if (gemm.isValidConfig(config, gemmData)) {
        validTactics.push_back(configIndex);
        // Once a loop2x-MMA tactic has been seen, stop scanning after the
        // first valid tactic without loop2x MMA: everything later ranks lower.
        if (!foundLoop2xMma) {
          if (config.mOptions.mUseUnrollLoop2xForMma) {
            foundLoop2xMma = true;
          }
        } else {
          if (!config.mOptions.mUseUnrollLoop2xForMma) {
            break;
          }
        }
      }
    }
    return validTactics;
  }
  int64_t selectHeuristic(int64_t m, int64_t n, int64_t k) const {
    if (mOptions.eltType == gemm::trtllm::gen::Dtype::E4m3) {
      return select_kernel_fp8(m, n, k, gemm::gemm::GemmInterface());
    } else if (mOptions.eltType == gemm::trtllm::gen::Dtype::E2m1) {
      auto sortedIndices = getValidTactics(m, n, k);
      TORCH_CHECK(!sortedIndices.empty(), "No valid tactic found");
      // getValidTactics() returns tactics sorted by priority, so the first
      // entry is the best choice.
      return sortedIndices[0];
    } else {
      TORCH_CHECK(false, "Unsupported eltType");
    }
  }

 private:
  TrtllmGenGemmRunnerOptions mOptions;
  std::vector<int64_t> mPassingConfigIndices;
};
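
// Illustrative sketch of how the runner above is intended to be driven; this
// is a comment rather than compiled code, and the pointer/size names are
// placeholders, not symbols defined in this file:
//
//   TrtllmGenGemmRunner runner({.eltType = gemm::trtllm::gen::Dtype::E4m3,
//                               .outputType = gemm::trtllm::gen::Dtype::Bfloat16,
//                               .transposeMmaOutput = true,
//                               .sfLayoutB = gemm::trtllm::gen::SfLayout::R128c4});
//   int64_t tactic = runner.selectHeuristic(m, n, k);
//   int64_t bytes = runner.getWorkspaceSizeInBytes(m, n, k, tactic);
//   // ...allocate a device workspace of at least `bytes` bytes, then:
//   runner.run(m, n, k, aPtr, aScalePtr, bPtr, bScalePtr, outPtr,
//              /*cScale=*/nullptr, /*cScalePtr=*/nullptr, workspacePtr,
//              stream, deviceIndex, tactic);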
void trtllm_gemm(at::Tensor workspace_buffer, at::Tensor a, at::Tensor b, at::Tensor a_scale,
                 at::Tensor b_scale, at::optional<at::Tensor> globalScale, at::Tensor out,
                 bool use_8x4_sf_layout, int64_t tactic) {
  TORCH_CHECK(a.device() == b.device(), "a and b must be on the same device");
  TORCH_CHECK(a.device() == out.device(), "a and out must be on the same device");
  TORCH_CHECK(a.is_cuda() && a.is_contiguous(), "a must be a contiguous CUDA tensor");
  TORCH_CHECK(b.is_cuda() && b.is_contiguous(), "b must be a contiguous CUDA tensor");
  TORCH_CHECK(out.is_cuda() && out.is_contiguous(), "out must be a contiguous CUDA tensor");
  TORCH_CHECK(workspace_buffer.is_cuda() && workspace_buffer.is_contiguous(),
              "workspace_buffer must be a contiguous CUDA tensor");
  TORCH_CHECK(workspace_buffer.sizes().size() == 1, "workspace_buffer must be a 1D CUDA tensor");
  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
  TORCH_CHECK(b.dim() == 2, "b must be a matrix");
  TORCH_CHECK(a.scalar_type() == b.scalar_type(), "a and b must have the same scalar type");
  TORCH_CHECK(
      a.scalar_type() == at::ScalarType::Float8_e4m3fn || a.scalar_type() == at::ScalarType::Byte,
      "a must be a Float8 tensor or a Byte tensor of packed e2m1 values");

  bool is_fp8 = a.scalar_type() == at::ScalarType::Float8_e4m3fn;
  if (is_fp8) {
    TORCH_CHECK(!globalScale.has_value(), "globalScale must not be provided for fp8 inputs");
  } else {
    TORCH_CHECK(a_scale.is_cuda() && a_scale.is_contiguous(),
                "a_scale must be a contiguous CUDA tensor");
    TORCH_CHECK(b_scale.is_cuda() && b_scale.is_contiguous(),
                "b_scale must be a contiguous CUDA tensor");
    if (globalScale.has_value()) {
      TORCH_CHECK(globalScale.value().is_cuda() && globalScale.value().is_contiguous(),
                  "globalScale must be a contiguous CUDA tensor");
    }
  }

  int32_t m = a.size(0);
  // For e2m1, each byte packs two 4-bit values, so the logical K is twice the
  // stored column count.
  int32_t k = is_fp8 ? a.size(1) : a.size(1) * 2;
  int32_t n = b.size(0);
  TORCH_CHECK(b.size(1) == a.size(1), "Matrix dimensions don't match for multiplication");
  TORCH_CHECK(out.size(0) == m && out.size(1) == n, "Output tensor has wrong dimensions");

  auto runner = flashinfer::TrtllmGenGemmRunner(flashinfer::TrtllmGenGemmRunnerOptions{
      .eltType = is_fp8 ? gemm::trtllm::gen::Dtype::E4m3 : gemm::trtllm::gen::Dtype::E2m1,
      .outputType = gemm::trtllm::gen::Dtype::Bfloat16,
      .transposeMmaOutput = true,
      .sfLayoutB = use_8x4_sf_layout ? gemm::trtllm::gen::SfLayout::R8c4
                                     : gemm::trtllm::gen::SfLayout::R128c4,
  });

  if (tactic == -1) {
    tactic = runner.selectHeuristic(m, n, k);
  }

  auto stream = at::cuda::getCurrentCUDAStream(a.device().index());
  auto runKernel = [&](void* workspace) {
    runner.run(m, n, k, a.data_ptr(), a_scale.data_ptr(), b.data_ptr(), b_scale.data_ptr(),
               out.data_ptr(), globalScale.has_value() ? globalScale.value().data_ptr() : nullptr,
               nullptr, workspace, stream, a.device().index(), tactic);
  };

  // Fall back to a freshly allocated buffer when the caller-provided
  // workspace is too small for the selected tactic.
  int64_t const required_workspace_size = runner.getWorkspaceSizeInBytes(m, n, k, tactic);
  int64_t const provided_workspace_size =
      workspace_buffer.numel() * workspace_buffer.element_size();
  if (provided_workspace_size < required_workspace_size) {
    auto new_workspace = at::detail::empty_cuda({required_workspace_size}, at::ScalarType::Char,
                                                a.device(), std::nullopt);
    runKernel(new_workspace.data_ptr());
  } else {
    runKernel(workspace_buffer.data_ptr());
  }
}
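
// Integer dtype codes exchanged with the caller; the values are assumed to
// match the convention used by the host-side wrapper that invokes these ops.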
enum class Dtype : int64_t {
  E2m1 = 0,
  E4m3 = 1,
  Bfloat16 = 2,
};
std::vector<int64_t> trtllm_gemm_tactics(int64_t m, int64_t n, int64_t k, int64_t input_dtype,
                                         int64_t output_dtype, bool use_8x4_sf_layout) {
  TORCH_CHECK(input_dtype == static_cast<int64_t>(Dtype::E4m3) ||
                  input_dtype == static_cast<int64_t>(Dtype::E2m1),
              "Unsupported input dtype");
  TORCH_CHECK(output_dtype == static_cast<int64_t>(Dtype::Bfloat16), "Unsupported output dtype");

  auto runner = flashinfer::TrtllmGenGemmRunner(flashinfer::TrtllmGenGemmRunnerOptions{
      .eltType = input_dtype == static_cast<int64_t>(Dtype::E4m3) ? gemm::trtllm::gen::Dtype::E4m3
                                                                  : gemm::trtllm::gen::Dtype::E2m1,
      .outputType = gemm::trtllm::gen::Dtype::Bfloat16,
      .transposeMmaOutput = true,
      .sfLayoutB = use_8x4_sf_layout ? gemm::trtllm::gen::SfLayout::R8c4
                                     : gemm::trtllm::gen::SfLayout::R128c4,
  });
  return runner.getValidTactics(m, n, k);
}
namespace trtllm_cubin_loader {
#include <flashinfer/cubin_loader.h>
}
} // namespace flashinfer
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("trtllm_gemm", &flashinfer::trtllm_gemm);
  m.def("trtllm_gemm_tactics", &flashinfer::trtllm_gemm_tactics);
}
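
// Once this extension is loaded, the ops registered above are reachable as
// torch.ops.<TORCH_EXTENSION_NAME>.trtllm_gemm and
// torch.ops.<TORCH_EXTENSION_NAME>.trtllm_gemm_tactics. A typical (assumed)
// autotuning pattern is to call trtllm_gemm_tactics once per shape, benchmark
// the returned tactic ids with trtllm_gemm, and cache the best one; passing
// tactic = -1 falls back to the built-in selectHeuristic().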