sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/gemm/cutlass_gemm_configs.h

427 lines
14 KiB
C++

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_CUTLASS_GEMM_CONFIG_H_
#define FLASHINFER_CUTLASS_GEMM_CONFIG_H_
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include "cute/tensor.hpp"
namespace flashinfer {
namespace gemm {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K
// shape in the kernel layout details when doing weight only quantization.
//
// Tile configurations held in CutlassGemmConfig::tile_config_sm80. Each enumerator name spells
// out the CTA tile shape followed by the warp tile shape.
// Enumerator order matters: CutlassGemmConfig::getTileConfigAsInt()/toString() emit the
// underlying integer value, so inserting or reordering entries changes the reported IDs.
enum class CutlassTileConfig {
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// SiMT config
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=16
CtaShape16x128x64_WarpShape16x32x64,
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
// Warp configs for M=64
CtaShape64x128x64_WarpShape32x64x64,
CtaShape64x64x128_WarpShape32x64x64,
CtaShape64x128x64_WarpShape64x32x64,
// Warp configs for M=128
CtaShape128x64x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape64x64x64,
CtaShape128x128x64_WarpShape128x32x64,
CtaShape128x256x64_WarpShape64x64x64,
// Warp configs for M=256
CtaShape256x128x64_WarpShape64x64x64,
// TensorCore config CTA_N = 64, CTA_K = 128
CtaShape128x64x128_WarpShape64x32x128,
// TensorCore config CTA_N = 256, CTA_K = 64
CtaShape16x256x64_WarpShape16x64x64,
// TensorCore config CTA_N = 256, CTA_K = 128
CtaShape16x256x128_WarpShape16x64x128
};
// Split-K strategy of a GEMM config; stored in CutlassGemmConfig::split_k_style next to
// split_k_factor. The underlying integer value is what operator<< prints as
// "split_k_style_enum", so keep the enumerator order stable.
enum class SplitKStyle {
NO_SPLIT_K,
SPLIT_K_SERIAL,
STREAM_K, // Sm80+
// SPLIT_K_PARALLEL // Not supported yet
};
// CTA tile configurations held in CutlassGemmConfig::tile_config_sm90 (used when sm_version is
// 90). Names encode the CTA shape as MxNxK; the trailing "128B" presumably denotes a K extent of
// 128 bytes — TODO confirm against the kernel launchers.
// Enumerator order matters: the underlying integer value is emitted by
// CutlassGemmConfig::getTileConfigAsInt().
enum class CutlassTileConfigSM90 {
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// CTA configs for M=64
CtaShape64x16x128B,
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
// CTA configs for M=128
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
// CTA configs for M=256
CtaShape256x128x128B,
};
// CTA tile configurations held in CutlassGemmConfig::tile_config_sm100.
// CutlassGemmConfig::getTileConfigAsInt() reports this field for every sm_version >= 100 other
// than 120. Names encode the CTA shape as MxNxK; the trailing "B" presumably means the K extent
// is given in bytes — TODO confirm.
// Enumerator order matters: the underlying integer value is emitted as the tile shape ID.
enum class CutlassTileConfigSM100 {
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
/*
* Grouped GEMM
*/
// M=64
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
// M=128
CtaShape128x8x256B,
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
CtaShape128x128x256B,
CtaShape128x256x256B,
// M=256
CtaShape256x64x128B,
CtaShape256x128x128B,
CtaShape256x256x128B,
};
// CTA tile configurations held in CutlassGemmConfig::tile_config_sm120 (reported by
// getTileConfigAsInt() when sm_version is exactly 120). Names encode the CTA shape as MxNxK.
// Enumerator order matters: the underlying integer value is emitted as the tile shape ID.
enum class CutlassTileConfigSM120 {
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
CtaShape128x128x128B,
CtaShape128x128x64B,
CtaShape256x128x64B,
CtaShape128x256x64B,
CtaShape128x128x256B,
CtaShape256x128x128B,
};
// Main loop schedule requested by a GEMM config. The underlying integer value is printed by
// CutlassGemmConfig::toString() and operator<<, so keep the enumerator order stable.
enum class MainloopScheduleType {
  AUTO,  // Automatically selects between pingpong and cooperative schedules on Hopper. On older
         // architectures, this defaults to the "legacy" main loop schedule.
  PINGPONG,
  COOPERATIVE,
  WARPSPECIALIZED
};

// Returns a lowercase, human-readable name for `schedule` as a string literal.
// Values that do not correspond to a known enumerator yield "unknown schedule".
static auto get_mainloop_schedule_name(MainloopScheduleType schedule) {
  switch (schedule) {
    case MainloopScheduleType::AUTO:
      return "auto";
    case MainloopScheduleType::PINGPONG:
      return "pingpong";
    case MainloopScheduleType::COOPERATIVE:
      return "cooperative";
    case MainloopScheduleType::WARPSPECIALIZED:
      return "warpspecialized";
  }
  // Reached only for out-of-range values produced by a cast.
  return "unknown schedule";
}
// Epilogue schedule requested by a GEMM config; stored in CutlassGemmConfig::epilogue_schedule.
// AUTO is currently the only enumerator.
enum class EpilogueScheduleType {
AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop
// schedule for Hopper. For architectures older than hopper, the epilogue is always
// performed by the same thread block as the main loop.
};
// Tile shapes, named MxNxK. get_tile_shape<>() maps each enumerator to the corresponding
// cute::Shape type and get_tile_shape_name() maps it to an "MxNxK" string; when adding an
// enumerator here, extend both helpers as well.
enum class TileShape {
TileShape_64x16x128,
TileShape_64x32x128,
TileShape_64x64x128,
TileShape_64x128x128,
TileShape_64x256x128,
TileShape_64x512x128,
TileShape_128x16x128,
TileShape_128x32x128,
TileShape_128x64x128,
TileShape_128x128x128,
TileShape_128x256x128
};
// Maps a TileShape enumerator to a value of the matching compile-time cute::Shape<M, N, K>.
//
// Every enumerator is checked by its own `if constexpr` guard. Exactly one guard is true for a
// given Shape_MNK; all others are discarded, so the single surviving `return` determines the
// deduced return type.
template <TileShape Shape_MNK>
constexpr auto get_tile_shape() {
  using namespace cute;

  // Tiles with M = 64.
  if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) return cute::Shape<_64, _16, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) return cute::Shape<_64, _32, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) return cute::Shape<_64, _64, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) return cute::Shape<_64, _128, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) return cute::Shape<_64, _256, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) return cute::Shape<_64, _512, _128>{};

  // Tiles with M = 128.
  if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) return cute::Shape<_128, _16, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) return cute::Shape<_128, _32, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) return cute::Shape<_128, _64, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) return cute::Shape<_128, _128, _128>{};
  if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) return cute::Shape<_128, _256, _128>{};
}
// Returns the "MxNxK" string literal for a TileShape enumerator.
// Values that do not correspond to a known enumerator yield "Unknown shape".
static auto get_tile_shape_name(TileShape Shape_MNK) {
  switch (Shape_MNK) {
    case TileShape::TileShape_64x16x128:
      return "64x16x128";
    case TileShape::TileShape_64x32x128:
      return "64x32x128";
    case TileShape::TileShape_64x64x128:
      return "64x64x128";
    case TileShape::TileShape_64x128x128:
      return "64x128x128";
    case TileShape::TileShape_64x256x128:
      return "64x256x128";
    case TileShape::TileShape_64x512x128:
      return "64x512x128";
    case TileShape::TileShape_128x16x128:
      return "128x16x128";
    case TileShape::TileShape_128x32x128:
      return "128x32x128";
    case TileShape::TileShape_128x64x128:
      return "128x64x128";
    case TileShape::TileShape_128x128x128:
      return "128x128x128";
    case TileShape::TileShape_128x256x128:
      return "128x256x128";
  }
  // Reached only for out-of-range values produced by a cast.
  return "Unknown shape";
}
// Thread-block cluster shapes, named MxNxK. The underlying integer value is printed by
// CutlassGemmConfig::toString() and operator<< as the cluster shape ID, so keep the enumerator
// order stable.
enum class ClusterShape {
  ClusterShape_1x1x1,
  ClusterShape_2x1x1,
  ClusterShape_1x2x1,
  ClusterShape_2x2x1,
  ClusterShape_1x4x1,
  ClusterShape_4x2x1,
  ClusterShape_2x4x1,
  ClusterShape_4x4x1,
  ClusterShape_1x8x1,
  ClusterShape_8x1x1
};

// Returns the "MxNxK" string literal for a ClusterShape enumerator.
//
// Every enumerator is covered, including 1x4x1, 4x2x1, 2x4x1 and 4x4x1, which previously fell
// through to "Unknown shape" even though they are valid values. "Unknown shape" is now returned
// only for out-of-range values produced by a cast.
static auto get_cluster_shape_name(ClusterShape Shape_MNK) {
  switch (Shape_MNK) {
    case ClusterShape::ClusterShape_1x1x1:
      return "1x1x1";
    case ClusterShape::ClusterShape_2x1x1:
      return "2x1x1";
    case ClusterShape::ClusterShape_1x2x1:
      return "1x2x1";
    case ClusterShape::ClusterShape_2x2x1:
      return "2x2x1";
    case ClusterShape::ClusterShape_1x4x1:
      return "1x4x1";
    case ClusterShape::ClusterShape_4x2x1:
      return "4x2x1";
    case ClusterShape::ClusterShape_2x4x1:
      return "2x4x1";
    case ClusterShape::ClusterShape_4x4x1:
      return "4x4x1";
    case ClusterShape::ClusterShape_1x8x1:
      return "1x8x1";
    case ClusterShape::ClusterShape_8x1x1:
      return "8x1x1";
  }
  return "Unknown shape";
}
template <ClusterShape Shape_MNK>
constexpr auto get_cluster_shape() {
using namespace cute;
if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x1x1) {
return cute::Shape<_1, _1, _1>{};
} else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x1x1) {
return cute::Shape<_2, _1, _1>{};
} else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x2x1) {
return cute::Shape<_1, _2, _1>{};
} else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) {
return cute::Shape<_2, _2, _1>{};
} else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) {
return cute::Shape<_1, _8, _1>{};
} else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) {
return cute::Shape<_8, _1, _1>{};
}
}
// Describes one CUTLASS GEMM configuration ("tactic").
//
// A config is either a "compatible" one (sm_version 80, tile_config_sm80 plus split-k and stage
// settings) or a TMA warp-specialized one (sm_version 90/100/120, an architecture-specific tile
// config plus mainloop/epilogue schedule and cluster shape). Which of the tile_config_* fields
// is meaningful is decided by sm_version; see getTileConfigAsInt().
struct CutlassGemmConfig {
  // Single-bit flags; each enumerator occupies its own bit so values can be OR-ed together.
  enum CandidateConfigTypeParam : int {
    NONE = 0,
    WEIGHT_ONLY = 1u << 0,
    SIMT_ONLY = 1u << 1,
    INT8_ONLY = 1u << 2,
    HOPPER = 1u << 3,
    BLACKWELL = 1u << 4,
    GROUPED_GEMM = 1u << 5,
    FP8_ONLY = 1u << 6,
    FP4_ONLY = 1u << 7
  };

  // Options for the pre-sm90 path. -1 is the default for split_k_factor and stages.
  CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic;
  SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
  int split_k_factor = -1;
  int stages = -1;

  // config options for sm90 and newer
  CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
  CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic;
  CutlassTileConfigSM120 tile_config_sm120 = CutlassTileConfigSM120::ChooseWithHeuristic;
  MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
  EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
  ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;

  bool enableCudaKernel = false;
  int sm_version = 80;  // Use 80 as a catch all for <90
  bool is_tma_warp_specialized = false;

  // Leaves every field at its in-class default (heuristic tile choice, sm_version 80).
  CutlassGemmConfig() = default;

  // Builds a pre-sm90 config; sm_version is set to 80.
  CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor,
                    int stages)
      : tile_config_sm80(tile_config),
        split_k_style(split_k_style),
        split_k_factor(split_k_factor),
        stages(stages),
        sm_version(80) {}

  // Builds a TMA warp-specialized config with sm_version 90.
  CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
                    EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
      : tile_config_sm90(tile_config_sm90),
        mainloop_schedule(mainloop_schedule),
        epilogue_schedule(epilogue_schedule),
        cluster_shape(cluster_shape),
        sm_version(90),
        is_tma_warp_specialized(true) {}

  // Builds a TMA warp-specialized config with sm_version 100.
  CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100,
                    MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule,
                    ClusterShape cluster_shape)
      : tile_config_sm100(tile_config_sm100),
        mainloop_schedule(mainloop_schedule),
        epilogue_schedule(epilogue_schedule),
        cluster_shape(cluster_shape),
        sm_version(100),
        is_tma_warp_specialized(true) {}

  // Builds a TMA warp-specialized config with sm_version 120.
  CutlassGemmConfig(CutlassTileConfigSM120 tile_config_sm120,
                    MainloopScheduleType mainloop_schedule, EpilogueScheduleType epilogue_schedule,
                    ClusterShape cluster_shape)
      : tile_config_sm120(tile_config_sm120),
        mainloop_schedule(mainloop_schedule),
        epilogue_schedule(epilogue_schedule),
        cluster_shape(cluster_shape),
        sm_version(120),
        is_tma_warp_specialized(true) {}

  // Returns the integer value of the tile config selected by sm_version:
  //   < 90           -> tile_config_sm80
  //   == 90          -> tile_config_sm90
  //   == 120         -> tile_config_sm120
  //   other >= 100   -> tile_config_sm100 (this includes 110)
  // Any other version (91..99) trips the assert and, with asserts disabled, returns -1.
  int getTileConfigAsInt() const {
    if (sm_version < 90) {
      return static_cast<int>(tile_config_sm80);
    }
    if (sm_version == 90) {
      return static_cast<int>(tile_config_sm90);
    }
    if (sm_version == 120) {
      return static_cast<int>(tile_config_sm120);
    }
    if (sm_version >= 100) {
      return static_cast<int>(tile_config_sm100);
    }
    assert(false && "Invalid SM version");
    return -1;
  }

  // Renders the config as a multi-line, tab-indented description terminated by a newline.
  // Enum-typed fields are printed as their underlying integer values.
  std::string toString() const {
    char const* const cuda_kernel_flag = enableCudaKernel ? "true" : "false";

    std::ostringstream tactic;
    tactic << "Cutlass GEMM Tactic";

    if (is_tma_warp_specialized) {
      assert(sm_version >= 90 && "Invalid cutlass GEMM config");
      tactic << "\n\tstyle=TMA Warp Specialized";
      tactic << "\n\tsm: " << sm_version;
      tactic << "\n\ttile shape ID: " << getTileConfigAsInt();
      tactic << "\n\tcluster shape ID: " << static_cast<int>(cluster_shape);
      tactic << "\n\tmainloop sched: " << static_cast<int>(mainloop_schedule);
      tactic << "\n\tepi sched: " << static_cast<int>(epilogue_schedule);
      tactic << "\n\tenable cuda kernel: " << cuda_kernel_flag;
    } else if (tile_config_sm80 != flashinfer::gemm::CutlassTileConfig::ChooseWithHeuristic) {
      assert(sm_version < 90 && "Invalid cutlass GEMM config");
      tactic << "\n\tstyle=compatible";
      tactic << "\n\ttile shape ID: " << static_cast<int>(tile_config_sm80);
      tactic << "\n\tstages: " << stages;
      tactic << "\n\tsplit k: " << split_k_factor;
      tactic << "\n\tenable cuda kernel: " << cuda_kernel_flag;
    } else if (enableCudaKernel) {
      // Only the CUDA-kernel flag is set; the flag text is necessarily "true" here.
      tactic << "\n\tenable cuda kernel: " << cuda_kernel_flag;
    } else {
      tactic << "\n\tundefined";
    }

    tactic << "\n";
    return tactic.str();
  }
};
// Streams a single-line, comma-separated summary of `config` and returns `out`.
// Enum-typed fields are written as their underlying integer values. For TMA warp-specialized
// configs the tile value is labelled "tile_config_sm90_enum" regardless of sm_version; the
// value itself comes from getTileConfigAsInt().
inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) {
  char const* const cuda_kernel_flag = config.enableCudaKernel ? "true" : "false";

  if (!config.is_tma_warp_specialized) {
    out << "tile_config_enum: " << config.getTileConfigAsInt();
    out << ", split_k_style_enum: " << static_cast<int>(config.split_k_style);
    out << ", split_k_factor: " << config.split_k_factor;
    out << ", stages: " << config.stages;
    out << ", enable_cuda_kernel: " << cuda_kernel_flag;
    return out;
  }

  out << "tile_config_sm90_enum: " << config.getTileConfigAsInt();
  out << ", mainloop_schedule_enum: " << static_cast<int>(config.mainloop_schedule);
  out << ", epilogue_schedule_enum: " << static_cast<int>(config.epilogue_schedule);
  out << ", cluster_shape_enum: " << static_cast<int>(config.cluster_shape);
  out << ", enable_cuda_kernel: " << cuda_kernel_flag;
  return out;
}
} // namespace gemm
} // namespace flashinfer
#endif // FLASHINFER_CUTLASS_GEMM_CONFIG_H_