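# Kernel instantiation generator for the TRT-LLM CUTLASS GEMM / MoE grouped GEMM launchers.
# It enumerates supported kernel configurations per SM architecture (SM80/SM90/SM100/SM120),
# renders explicit template instantiations for them, and writes the results out as
# *.generated.cu files, grouped to keep per-file compile time and memory in check.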
import enum
import os
from itertools import chain, product

from .cutlass_library import (
    enum_auto,
    DataTypeNames,
    DataTypeSize,
    DataType,
    DataTypeTag,
    GemmKind,
    GemmKindNames,
    KernelScheduleType,
    KernelScheduleTag,
    KernelScheduleSuffixes,
    EpilogueScheduleType,
    EpilogueScheduleTag,
    EpilogueScheduleSuffixes,
)
from ..cpp_ext import is_cuda_version_at_least


################################################################################
# Epilogue Tag enum and string utils
class TrtLlm_EpilogueTag(enum.Enum):
    epilogue_op_default = enum_auto()
    epilogue_op_bias = enum_auto()
    epilogue_op_silu = enum_auto()
    epilogue_op_gelu = enum_auto()


class TrtLlm_EpilogueFusion(enum.Enum):
    epilogue_fusion_none = enum_auto()
    epilogue_fusion_finalize = enum_auto()


EpiTagNames = {
    TrtLlm_EpilogueTag.epilogue_op_default: "lc",  # linear combination
    TrtLlm_EpilogueTag.epilogue_op_bias: "lc_bias",  # linear combination with bias addition
    TrtLlm_EpilogueTag.epilogue_op_silu: "silu",  # silu or swiglu
    TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu",  # gelu or geglu
}

EpiTag = {
    TrtLlm_EpilogueTag.epilogue_op_default: "tensorrt_llm::cutlass_extensions::EpilogueOpDefault",
    TrtLlm_EpilogueTag.epilogue_op_bias: "tensorrt_llm::cutlass_extensions::EpilogueOpBias",
    TrtLlm_EpilogueTag.epilogue_op_silu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu",
    TrtLlm_EpilogueTag.epilogue_op_gelu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu",
}

EpiFusion = {
    TrtLlm_EpilogueFusion.epilogue_fusion_none: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
    TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
}

EpiFusionSuffixes = {
    None: "",
    TrtLlm_EpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE",
    TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE",
}


################################################################################
# Quantization Operation and string utils
class TrtLlm_QuantOp(enum.Enum):
    per_column_scale_only = enum_auto()
    finegrained_scale_only = enum_auto()
    finegrained_scale_and_zeros = enum_auto()
    none = enum_auto()


QuantOpNames = {
    TrtLlm_QuantOp.per_column_scale_only: "cs",
    TrtLlm_QuantOp.finegrained_scale_only: "fgs",
    TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz",
    TrtLlm_QuantOp.none: "noquant",
}

QuantOpTag = {
    TrtLlm_QuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY",
    TrtLlm_QuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
    TrtLlm_QuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS",
    TrtLlm_QuantOp.none: "void",
}

################################################################################
# The activations, biases, scales and zeros are instantiated using CUDA types,
# not CUTLASS types. This map materializes the name of the CUDA type.


class e2m1_type:  # WAR until we have upgraded everything to a supported version
    pass


e2m1 = e2m1_type()


def GetDataTypeBits(type):
    if isinstance(type, e2m1_type):
        return 4
    return DataTypeSize[type]


def GetDataTypeNames(type, is_mx_fpx=None):
    mxprefix = ""
    if is_mx_fpx is not None:
        mxprefix = "mx_" if is_mx_fpx else "nv_"
    if isinstance(type, e2m1_type):
        return mxprefix + "e2m1"
    return mxprefix + DataTypeNames[type]


CudaTypeName = {
    e2m1: "SafeFP4",
    DataType.e4m3: "__nv_fp8_e4m3",
    DataType.bf16: "__nv_bfloat16",
    DataType.f16: "half",
    DataType.f32: "float",
    DataType.e2m1: "__nv_fp4_e2m1",
    DataType.ue8m0: "cutlass::float_ue8m0_t",
    DataType.u4: "cutlass::uint4b_t",
}


################################################################################
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
class TrtLlm_GemmLauncher:
    def __init__(
        self,
        gemm_kind,
        arch,
        act_type,
        weight_type,
        scalezero_type,
        bias_type,
        output_type,
        quant_op,
        epi_tag,
        cta_shape,
        warp_shape,
        stages,
        cga_shape,
        mainloop_schedule,
        epi_schedule,
        epi_fusion=None,
        is_mx_fpx=False,
    ):
        self.gemm_kind = gemm_kind
        self.arch = arch
        self.act_type = act_type
        self.weight_type = weight_type
        self.scalezero_type = scalezero_type
        self.bias_type = bias_type
        self.output_type = output_type
        self.quant_op = quant_op
        self.epi_tag = epi_tag
        self.cta_shape = cta_shape
        self.warp_shape = warp_shape
        self.stages = stages
        self.cga_shape = cga_shape
        self.mainloop_schedule = mainloop_schedule
        self.epi_schedule = epi_schedule
        self.epi_fusion = epi_fusion
        self.is_mx_fpx = is_mx_fpx

    # Encodes the full configuration as a readable kernel name:
    # <gemm_kind>_sm<arch>_<act>_<weight>_<scale>_<bias>_<out>_<quant>_<epi>_<ctaM>x<ctaN>x<ctaK>_<warpM>x<warpN>x<warpK>_<stages>,
    # with a cluster-shape/schedule suffix appended on SM90 and newer.
    def __repr__(self):
        kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format(
            GemmKindNames[self.gemm_kind],
            self.arch,
            GetDataTypeNames(self.act_type, self.is_mx_fpx),
            GetDataTypeNames(self.weight_type, self.is_mx_fpx),
            GetDataTypeNames(self.scalezero_type),
            GetDataTypeNames(self.bias_type),
            GetDataTypeNames(self.output_type),
            QuantOpNames[self.quant_op],
            EpiTagNames[self.epi_tag],
            self.cta_shape[0],
            self.cta_shape[1],
            self.cta_shape[2],
            self.warp_shape[0],
            self.warp_shape[1],
            self.warp_shape[2],
            self.stages,
        )

        hopper_suffix = "_{}x{}x{}{}{}{}".format(
            self.cga_shape[0],
            self.cga_shape[1],
            self.cga_shape[2],
            KernelScheduleSuffixes[self.mainloop_schedule],
            EpilogueScheduleSuffixes[self.epi_schedule],
            EpiFusionSuffixes[self.epi_fusion],
        )

        if self.arch >= 90:
            return kernel_prefix + hopper_suffix
        elif self.arch > 100:
            raise ValueError(f"SM{self.arch} not supported yet.")
        return kernel_prefix


################################################################################
def tuple_to_cute_shape(shape):
    # e.g. (64, 128, 64) -> "cute::Shape<cute::Int<64>, cute::Int<128>, cute::Int<64>>"
    return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>"


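# Renders the explicit template instantiation text for a single TMA warp-specialized
# (SM90 and newer) launcher. Plain GEMM targets sm90_generic_mixed_gemm_kernelLauncher;
# grouped (MoE) GEMM targets either the mixed-input launcher or the
# INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM macro, wrapped in dtype-specific ENABLE_* guards.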
def instantiate_operation_tma_warp_specialized(operation):
    act_tag = CudaTypeName[operation.act_type]
    scale_zero_tag = CudaTypeName[operation.scalezero_type]
    bias_tag = CudaTypeName[operation.bias_type]
    out_tag = CudaTypeName[operation.output_type]

    quant_op = QuantOpTag[operation.quant_op]
    epi_tag = EpiTag[operation.epi_tag]

    cute_cta_shape = tuple_to_cute_shape(operation.cta_shape)
    cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)

    kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
    epi_sched = EpilogueScheduleTag[operation.epi_schedule]

    if operation.gemm_kind == GemmKind.Gemm:
        weight_tag = DataTypeTag[operation.weight_type]
        instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
{kernel_sched}, {epi_sched}> (
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
);
"""
    elif operation.gemm_kind == GemmKind.Grouped:
        if operation.act_type != operation.weight_type and (
            operation.act_type != DataType.e4m3 or operation.weight_type != e2m1
        ):
            # Mixed MoE GEMM
            weight_tag = CudaTypeName[operation.weight_type]
            instantiation = f"""
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}>inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
"""
        else:
            # Similar to MixedInput above, we must modify the tags for grouped gemm
            # because the CUTLASS library does not have the updated schedules.
            assert operation.mainloop_schedule in [
                KernelScheduleType.TmaWarpSpecializedCooperative,
                KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
            ]
            assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
            kernel_sched = kernel_sched.replace("::Kernel", "::KernelGrouped")
            epi_sched += "Grouped"

            # arch_tag = f"cutlass::arch::Sm{operation.arch}"
            arch_tag = f"Sm{operation.arch}"
            weight_tag = CudaTypeName[operation.weight_type]
            assert operation.epi_fusion is not None
            epi_fusion = EpiFusion[operation.epi_fusion]

            epi_fusion = epi_fusion.split(":")[-1]
            epi_tag = epi_tag.split(":")[-1]

            guard_map = {
                e2m1: "defined(ENABLE_FP4)",
                DataType.e4m3: "defined(ENABLE_FP8)",
                DataType.bf16: "defined(ENABLE_BF16)",
            }
            guard_act = guard_map.get(operation.act_type, "1")
            guard_weight = guard_map.get(operation.weight_type, "1")
            # TODO Revert this once the compiler bug is fixed so we can use a template instead of a macro again
            # instantiation = f"""
            # template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
            #     {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
            #     (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
            # """
            instantiation = f"""
#if {guard_act} && {guard_weight}\n
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n
#endif
"""
    return instantiation


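# Renders the explicit template instantiation for the SM80 (Ampere) fused MoE GEMM launcher.
# SM80 configurations use GemmSm80LauncherConfig, which carries a single dtype and an explicit
# stage count instead of the TMA warp-specialized fields.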
def instantiate_operation_sm80(operation):
    act_tag = DataTypeTag[operation.dtype]
    weight_tag = DataTypeTag[operation.dtype]
    epi_tag = EpiTag[operation.epi_tag]

    instantiation = f"""
template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);
"""
    return instantiation


def instantiate_operation(operation):
    if operation.arch == 80:
        return instantiate_operation_sm80(operation)
    elif operation.arch >= 90:
        return instantiate_operation_tma_warp_specialized(operation)


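# Assembles one generated .cu translation unit: the launcher .inl includes followed by every
# explicit instantiation, wrapped in the tensorrt_llm::kernels::cutlass_kernels namespaces.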
def get_file_content(launcher_inl_files, operations):
    assert operations
    include_list = list()
    for file in launcher_inl_files:
        include_list.append(f'#include "{file}"')
    includes = "\n".join(include_list)

    insts_list = list()
    for op in operations:
        insts_list.append(instantiate_operation(op))
    instantiations = "\n".join(insts_list)

    file_content = f"""{includes}
namespace tensorrt_llm
{{
namespace kernels
{{
namespace cutlass_kernels
{{

{instantiations}

}} // namespace cutlass_kernels
}} // namespace kernels
}} // namespace tensorrt_llm
"""
    return file_content


def clean_leftover_files(output_dir, generated_files):
    """Remove leftover generated files that weren't created in this run."""
    for root, _dirs, files in os.walk(output_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file_path not in generated_files:
                os.remove(file_path)


def write_file(launcher_inl_files, operations, output_file):
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # Avoid changing the modified time if the file content is already up to date.
    content = get_file_content(launcher_inl_files, operations)
    try:
        with open(output_file, mode="r") as f:
            if f.read() == content:
                return
    except FileNotFoundError:
        pass
    with open(output_file, mode="w") as f:
        f.write(content)


from operator import mul, truediv


def elementwise(x, y, f):
    return tuple(f(a, b) for (a, b) in zip(x, y))


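# Filters SM100 (Blackwell) grouped GEMM configurations. For SM100 ops the generator stores
# the CGA-level tile in op.cta_shape, so dividing by the cluster shape here recovers the
# per-CTA (epilogue) tile that the constraints below are written against.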
def is_gemm_op_valid_sm100(op):
    # TODO These are much more restricted than theory dictates; investigate if more can be enabled in future
    tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
    cga_m, cga_n, _ = op.cga_shape

    # Default shapes
    # This is the epilogue tile size. For two-CTA mode this is actually size 128/256 for the MMA
    if tile_m not in [64, 128]:
        return False

    # FP4 has some much more limited sizes
    if op.act_type == e2m1 or op.weight_type == e2m1:
        # TODO 128x256x256 FP4 compiles but crashes
        # if tile_n % 64 != 0 or tile_n < 128:
        #     return False
        if tile_n not in [64, 128, 256] or tile_m != 128:
            return False

    # Shapes for fp8 small N shapes
    if (
        op.act_type == DataType.e4m3
        and (tile_n == 16 or tile_n == 8)
        and (cga_m == 1 and cga_n == 1)
    ):
        # TODO Double check why this is disabled in the CUTLASS backend. @yuhan
        if tile_m == 128 and tile_n == 8:
            return False
        else:
            return True

    # Default alignment requirements
    if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
        return False

    # Two CTA mode needs bigger tile n alignment
    if cga_m % 2 == 0 and tile_n % 64 != 0:
        return False

    return True


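# SM90 tile/cluster compatibility: a cluster dimension of 2 is only allowed when the
# corresponding CTA tile dimension is at least 128.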
def is_gemm_op_valid(op):
    tile_m, tile_n, _ = op.cta_shape
    cga_m, cga_n, _ = op.cga_shape

    if cga_m == 1 and cga_n == 1:
        return True

    if cga_m == 2 and cga_n == 1 and tile_m >= 128:
        return True

    if cga_m == 1 and cga_n == 2 and tile_n >= 128:
        return True

    if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128:
        return True

    return False


def is_grouped_gemm_op_valid(op):
    if not is_gemm_op_valid(op):
        return False

    if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default:
        return False

    if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
        return False

    if op.mainloop_schedule not in [
        KernelScheduleType.TmaWarpSpecializedCooperative,
        KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
    ]:
        return False

    return True


def is_op_valid(op):
    if op.arch >= 100:
        return is_gemm_op_valid_sm100(op)

    if op.gemm_kind == GemmKind.Gemm:
        return is_gemm_op_valid(op)
    if op.gemm_kind == GemmKind.Grouped:
        return is_grouped_gemm_op_valid(op)


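# The generate_* functions below enumerate candidate (dtype, quantization, epilogue, CTA tile,
# cluster shape, schedule) combinations per SM architecture, filtering most of them through
# is_op_valid before they become TrtLlm_GemmLauncher configs.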
################################################################################
def generate_sm90_mixed_gemm_operations():
    arch = 90

    # For legacy reasons, we use unsigned types for the weights. The instantiated template
    # will remap those back to the signed type.
    # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
    supported_dtypes = [
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16),
        (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16),
        (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
        (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16),
    ]

    quant_ops = [
        TrtLlm_QuantOp.per_column_scale_only,
        TrtLlm_QuantOp.finegrained_scale_only,
        TrtLlm_QuantOp.finegrained_scale_and_zeros,
    ]

    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]

    M_TILES = [64, 128]
    N_TILES = [16, 32, 64, 128, 256]
    cta_shapes_mn = product(M_TILES, N_TILES)

    warp_shape = [4, 1, 1]
    stages = 0  # auto

    cga_shapes = product([1, 2], [1, 2], [1])

    partial_args = product(
        supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes
    )

    operations = list()
    for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0])
        cta_shape_mnk = cta_shape_mn + (cta_shape_k,)

        use_coop = cta_shape_mn[0] == 128
        mainloop_schedule = (
            KernelScheduleType.TmaWarpSpecializedCooperative
            if use_coop
            else KernelScheduleType.TmaWarpSpecializedPingpong
        )
        epi_schedule = (
            EpilogueScheduleType.TmaWarpSpecializedCooperative
            if use_coop
            else EpilogueScheduleType.TmaWarpSpecialized
        )

        fpA_intB_operation = TrtLlm_GemmLauncher(
            GemmKind.Gemm,
            arch,
            *dtype_combo,
            quant_op,
            epi_tag,
            cta_shape_mnk,
            warp_shape,
            stages,
            cga_shape,
            mainloop_schedule,
            epi_schedule,
        )

        if is_op_valid(fpA_intB_operation):
            operations.append(fpA_intB_operation)

    return operations


def generate_sm90_grouped_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []
    arch = 90
    supported_dtypes = [DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3]
    quant_ops = [TrtLlm_QuantOp.none]
    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
    M_TILES = [128]  # Currently M tile must be 128 for Grouped GEMM
    N_TILES = [16, 32, 64, 128, 256]
    cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)]

    warp_shape = [0, 0, 0]  # ignored except for naming
    stages = 0  # auto

    epi_fusions = [
        TrtLlm_EpilogueFusion.epilogue_fusion_none,
        TrtLlm_EpilogueFusion.epilogue_fusion_finalize,
    ]

    cga_shapes = product([1, 2], [1, 2], [1])

    partial_args = product(
        supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes
    )

    operations = list()
    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args:
        max_k_bits = 128 * 8
        cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
        cta_shape_mnk = cta_shape_mn + (cta_shape_k,)

        mainloop_schedule = (
            KernelScheduleType.TmaWarpSpecializedCooperative
            if dtype != DataType.e4m3
            else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
        )
        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized

        otypes = [dtype]
        if dtype == DataType.e4m3:
            otypes = [DataType.f16, DataType.bf16]

        for otype in otypes:
            moe_gemm_operation = TrtLlm_GemmLauncher(
                GemmKind.Grouped,
                arch,
                dtype,
                dtype,
                dtype,
                dtype,
                otype,
                quant_op,
                epi_tag,
                cta_shape_mnk,
                warp_shape,
                stages,
                cga_shape,
                mainloop_schedule,
                epi_schedule,
                epi_fusion,
            )

            if is_op_valid(moe_gemm_operation):
                operations.append(moe_gemm_operation)
    return operations


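# Mixed-dtype grouped GEMMs for SM90: W4A8 (FP8 activations with packed int4 weights) and,
# when the CUDA toolkit is new enough, FP4-weight variants with f16/bf16 activations.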
def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []
    arch = 90

    # act_type, weight_type, scalezero_type, bias_type, output_type
    supported_dtypes_int4 = [
        (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
        (DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16),
    ]

    if is_cuda_version_at_least("12.8"):
        supported_dtypes_fp4 = [
            (DataType.f16, DataType.e2m1, DataType.ue8m0, DataType.f16, DataType.f16),
            (
                DataType.bf16,
                DataType.e2m1,
                DataType.ue8m0,
                DataType.bf16,
                DataType.bf16,
            ),
        ]
    else:
        supported_dtypes_fp4 = []

    quant_ops = [TrtLlm_QuantOp.finegrained_scale_only]

    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]

    M_TILES = [64, 128]
    N_TILES = [16, 32, 64, 128]
    K_TILES = [128, 256, 512]
    cta_shapes_mnk_int4 = list(product(M_TILES, N_TILES, K_TILES))

    M_TILES = [64, 128]
    N_TILES = [16, 32, 64]
    K_TILES = [128, 256]
    cta_shapes_mnk_fp4 = list(product(M_TILES, N_TILES, K_TILES))
    cta_shapes_mnk_fp4.append((128, 128, 128))

    warp_shape = [0, 0, 0]  # ignored except for naming
    stages = 0  # auto

    cga_shapes = list(product([1, 2], [1, 2], [1]))

    partial_args_int4 = product(
        supported_dtypes_int4, quant_ops, epi_tags, cta_shapes_mnk_int4, cga_shapes
    )
    partial_args_fp4 = product(
        supported_dtypes_fp4, quant_ops, epi_tags, cta_shapes_mnk_fp4, cga_shapes
    )
    partial_args = chain(partial_args_int4, partial_args_fp4)

    operations = list()
    for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
        use_coop = cta_shape_mnk[0] >= 128
        mainloop_schedules = (
            [
                KernelScheduleType.TmaWarpSpecializedCooperative,
                KernelScheduleType.TmaWarpSpecializedPingpong,
            ]
            if use_coop
            else [KernelScheduleType.TmaWarpSpecializedPingpong]
        )
        epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
        for mainloop_schedule in mainloop_schedules:
            if (
                cta_shape_mnk[0] == 128
                and cta_shape_mnk[1] == 128
                and mainloop_schedule
                == KernelScheduleType.TmaWarpSpecializedCooperative
            ):
                continue
            moe_gemm_operation = TrtLlm_GemmLauncher(
                GemmKind.Grouped,
                arch,
                *dtype_combo,
                quant_op,
                epi_tag,
                cta_shape_mnk,
                warp_shape,
                stages,
                cga_shape,
                mainloop_schedule,
                epi_schedule,
            )
            operations.append(moe_gemm_operation)
    return operations


def generate_sm90_operations(is_arch_enabled):
    operations = generate_sm90_mixed_gemm_operations()
    operations.extend(generate_sm90_grouped_gemm_operations(is_arch_enabled))
    operations.extend(generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled))
    return operations


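# Picks the K tile so that it spans 128 bytes (128 * 8 bits) of the given dtype
# (e.g. 64 for 16-bit types, 128 for FP8, 256 for FP4), with explicit overrides for the
# small-N FP8 shapes.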
def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
    max_k_bits = 128 * 8
    cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
    if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8):
        cta_shape_k = 256
    if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16):
        cta_shape_k = 128
    return cta_shape_mn + (cta_shape_k,)


def generate_sm120_grouped_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []
    arch = 120
    supported_dtypes = [e2m1]
    quant_ops = [TrtLlm_QuantOp.none]
    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
    cta_shapes_mnk = [
        [128, 128, 128],
        [128, 128, 256],
        [256, 128, 128],
        [128, 256, 128],
    ]

    warp_shape = [0, 0, 0]  # ignored except for naming
    stages = 0  # auto

    epi_fusions = [
        TrtLlm_EpilogueFusion.epilogue_fusion_none,
        # TrtLlm_EpilogueFusion.epilogue_fusion_finalize
    ]

    cga_shapes = [[1, 1, 1]]

    partial_args = product(
        supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mnk, cga_shapes
    )

    operations = list()
    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)

        # Ignored
        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized

        otypes = [dtype]
        if dtype in [DataType.e4m3, e2m1]:
            otypes = [DataType.f16, DataType.bf16]

        for otype in otypes:
            moe_gemm_operation = TrtLlm_GemmLauncher(
                GemmKind.Grouped,
                arch,
                dtype,
                dtype,
                dtype,
                dtype,
                otype,
                quant_op,
                epi_tag,
                cga_tile_shape_mnk,
                warp_shape,
                stages,
                cga_shape,
                mainloop_schedule,
                epi_schedule,
                epi_fusion,
            )

            operations.append(moe_gemm_operation)
    return operations


def generate_sm120_operations(is_arch_enabled):
    operations = generate_sm120_grouped_gemm_operations(is_arch_enabled)
    return operations


def generate_sm100_grouped_gemm_operations(is_arch_enabled):
    if not is_arch_enabled:
        return []
    arch = 100
    supported_dtypes = [
        DataType.f16,
        DataType.bf16,
        DataType.f32,
        DataType.e4m3,
        e2m1,
        (DataType.e4m3, e2m1),
    ]
    quant_ops = [TrtLlm_QuantOp.none]
    epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
    cta_shapes_m = [64, 128]
    cta_shapes_n = [8, 16, 32, 64, 128, 256]
    cta_shapes_mn = product(cta_shapes_m, cta_shapes_n)

    warp_shape = [0, 0, 0]  # ignored except for naming
    stages = 0  # auto

    epi_fusions = [
        TrtLlm_EpilogueFusion.epilogue_fusion_none,
        # TrtLlm_EpilogueFusion.epilogue_fusion_finalize
    ]

    cga_shapes = list(product([1, 2], [1, 2], [1]))

    partial_args = product(
        supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes
    )

    operations = list()
    for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args:
        # A tuple entry means (activation_type, weight_type); a single entry is used for both.
        if isinstance(dtype, tuple):
            dtype, weight_type = dtype
        else:
            weight_type = dtype

        cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
        cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)

        # Ignored
        mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
        epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized

        otypes = [dtype]
        if dtype in [DataType.e4m3, e2m1]:
            otypes = [DataType.f16, DataType.bf16]

        for otype in otypes:
            moe_gemm_operation = TrtLlm_GemmLauncher(
                GemmKind.Grouped,
                arch,
                dtype,
                weight_type,
                otype,
                otype,
                otype,
                quant_op,
                epi_tag,
                cga_tile_shape_mnk,
                warp_shape,
                stages,
                cga_shape,
                mainloop_schedule,
                epi_schedule,
                epi_fusion,
                is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1),
            )

            if is_op_valid(moe_gemm_operation):
                operations.append(moe_gemm_operation)
    return operations


def generate_sm100_operations(is_arch_enabled):
    operations = generate_sm100_grouped_gemm_operations(is_arch_enabled)
    return operations


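# Lightweight config for the SM80 fused MoE GEMM launchers. These kernels take a single
# dtype and an explicit pipeline stage count, and need none of the scale/zero, schedule,
# or cluster-shape fields used by the TMA warp-specialized path.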
class GemmSm80LauncherConfig:
    def __init__(self, gemm_kind, arch, dtype, epi_tag, cta_shape, stage):
        self.gemm_kind = gemm_kind
        self.arch = arch
        self.dtype = dtype
        self.epi_tag = epi_tag
        self.cta_shape = cta_shape
        self.stage = stage


def generate_sm80_fused_grouped_gemm_operations():
    arch = 80
    supported_dtypes = [DataType.f16, DataType.bf16]
    epi_tags = [
        TrtLlm_EpilogueTag.epilogue_op_silu,
        TrtLlm_EpilogueTag.epilogue_op_gelu,
    ]
    cta_shapes_mnk = [
        (16, 128, 64),
        (16, 256, 64),
        (32, 128, 64),
        (64, 128, 64),
        (128, 128, 64),
    ]

    stages = [2, 3, 4]

    partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages)

    operations = list()
    for dtype, epi_tag, cta_shape_mnk, stage in partial_args:
        item = GemmSm80LauncherConfig(
            GemmKind.Grouped, arch, dtype, epi_tag, cta_shape_mnk, stage
        )
        operations.append(item)
    return operations


def generate_sm80_operations(is_arch_enabled):
    operations = generate_sm80_fused_grouped_gemm_operations()
    return operations


def generate_gemm_operations(output_dir, architectures):
    arches = architectures.split(";")
    # Get the absolute path of the provided directory
    output_dir = os.path.abspath(output_dir)

    fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
    moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
    # moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
    moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
    # moe_mixed_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
    sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
    # sm80_moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"

    inl_map = {
        (GemmKind.Gemm, 90): [fpA_intB_inl],
        (GemmKind.Grouped, 90): [moe_gemm_inl],
        (GemmKind.Grouped, 100): [moe_gemm_inl],
        (GemmKind.Grouped, 120): [moe_gemm_inl],
        (GemmKind.Grouped, 80): [sm80_moe_gemm_inl],
    }

    def has_arch(sm):
        return f"{sm}" in arches or f"{sm}-real" in arches

    # The goal here is to group kernels with common instantiations together in order to reduce
    # template instantiation overheads. Template instantiation dominates the time in a
    # compilation unit, so it is the most important factor to improve.
    operations = []
    operations += generate_sm120_operations(has_arch(120) or has_arch(121))
    operations += generate_sm100_operations(has_arch(100))
    operations += generate_sm90_operations(has_arch(90))
    operations += generate_sm80_operations(has_arch(80) or has_arch(89))

    def should_skip(op):
        return False  # All kernels have a public implementation

    # The mixed dtype grouped gemm for w4afp8 has a different launcher
    def is_mixed_dtype_grouped(op):
        if isinstance(op, GemmSm80LauncherConfig):
            return False
        # Only w4afp8 and not wfp4afp8
        return (
            (op.act_type != op.weight_type)
            and (op.gemm_kind == GemmKind.Grouped)
            and (op.act_type != DataType.e4m3 or op.weight_type != e2m1)
        )

    # Fix OOM errors in CI. If more than GROUP_SIZE operations share a key, they are split
    # into multiple sub-groups (and hence multiple generated files).
    GROUP_SIZE = 8
    op_groups = dict()
    for op in operations:
        if should_skip(op):
            continue
        # This dict key is used to group kernels with common instantiations together.
        # Similar implementations should live in the same file so the compiler can share the cutlass state.
        # Without this we see significant memory consumption, and separating them also does not reduce the
        # compilation time because most time is spent parsing the same cutlass files.
        # We separate by: architecture, leading dimension of the CTA shape, FP4 (i.e. block scaled MMA), mixed input.
        # TODO Do a more scientific analysis of this
        dict_key = (
            op.gemm_kind,
            op.arch,
            op.cta_shape[0],
            op.arch >= 100 and (op.weight_type == e2m1 or op.is_mx_fpx),
            is_mixed_dtype_grouped(op),
        )
        op_group = op_groups.get(dict_key, [])
        if len(op_group) == 0 or len(op_group[-1]) >= GROUP_SIZE:
            op_group.append([op])
        else:
            op_group[-1].append(op)
        op_groups[dict_key] = op_group

    file_list = []
    for key, value in op_groups.items():
        gemm_kind, arch, m, block_scale, is_mixed = key
        for i, op_sub_group in enumerate(value):
            out_file = os.path.join(
                output_dir,
                GemmKindNames[gemm_kind],
                str(arch),
                f"cutlass_kernel_file_{GemmKindNames[gemm_kind]}_sm{arch}_M{m}{'_BS' if block_scale else ''}{'_Mixed' if is_mixed else ''}_group{i}.generated.cu",
            )
            inl_file = [moe_mixed_gemm_inl] if is_mixed else inl_map[key[:2]]
            write_file(inl_file, op_sub_group, out_file)
            file_list.append(out_file)

    # Clean up any leftover files from previous runs
    clean_leftover_files(output_dir, set(file_list))
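# Example (illustrative): generate_gemm_operations("./generated", "80-real;90-real;100")
# writes one .generated.cu file per kernel group under ./generated/<gemm_kind>/<arch>/.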