# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/jit/cutlass_gemm/generate_kernels.py
import enum
import os
from itertools import chain, product
from .cutlass_library import (
enum_auto,
DataTypeNames,
DataTypeSize,
DataType,
DataTypeTag,
GemmKind,
GemmKindNames,
KernelScheduleType,
KernelScheduleTag,
KernelScheduleSuffixes,
EpilogueScheduleType,
EpilogueScheduleTag,
EpilogueScheduleSuffixes,
)
from ..cpp_ext import is_cuda_version_at_least
################################################################################
# Epilogue Tag enum and string utils
class TrtLlm_EpilogueTag(enum.Enum):
epilogue_op_default = enum_auto()
epilogue_op_bias = enum_auto()
epilogue_op_silu = enum_auto()
epilogue_op_gelu = enum_auto()
class TrtLlm_EpilogueFusion(enum.Enum):
epilogue_fusion_none = enum_auto()
epilogue_fusion_finalize = enum_auto()
EpiTagNames = {
TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination
TrtLlm_EpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition
TrtLlm_EpilogueTag.epilogue_op_silu: "silu", # silu or swiglu
TrtLlm_EpilogueTag.epilogue_op_gelu: "gelu", # gelu or geglu
}
EpiTag = {
TrtLlm_EpilogueTag.epilogue_op_default: "tensorrt_llm::cutlass_extensions::EpilogueOpDefault",
TrtLlm_EpilogueTag.epilogue_op_bias: "tensorrt_llm::cutlass_extensions::EpilogueOpBias",
TrtLlm_EpilogueTag.epilogue_op_silu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultSilu",
TrtLlm_EpilogueTag.epilogue_op_gelu: "tensorrt_llm::cutlass_extensions::EpilogueOpDefaultFtGelu",
}
EpiFusion = {
TrtLlm_EpilogueFusion.epilogue_fusion_none: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE",
TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE",
}
EpiFusionSuffixes = {
None: "",
TrtLlm_EpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE",
TrtLlm_EpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE",
}
################################################################################
# Quantization Operation and string utils
class TrtLlm_QuantOp(enum.Enum):
per_column_scale_only = enum_auto()
finegrained_scale_only = enum_auto()
finegrained_scale_and_zeros = enum_auto()
none = enum_auto()
QuantOpNames = {
TrtLlm_QuantOp.per_column_scale_only: "cs",
TrtLlm_QuantOp.finegrained_scale_only: "fgs",
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz",
TrtLlm_QuantOp.none: "noquant",
}
QuantOpTag = {
TrtLlm_QuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY",
TrtLlm_QuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
TrtLlm_QuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS",
TrtLlm_QuantOp.none: "void",
}
################################################################################
# The activations, biases, scales and zeros are instantiated using CUDA types,
# not CUTLASS types. This map materializes the name of the CUDA type.
class e2m1_type: # WAR until we have upgraded everything to a supported version
pass
e2m1 = e2m1_type()
def GetDataTypeBits(type):
if isinstance(type, e2m1_type):
return 4
return DataTypeSize[type]
def GetDataTypeNames(type, is_mx_fpx=None):
mxprefix = ""
if is_mx_fpx is not None:
mxprefix = "mx_" if is_mx_fpx else "nv_"
if isinstance(type, e2m1_type):
return mxprefix + "e2m1"
return mxprefix + DataTypeNames[type]
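# Illustrative examples (assuming the standard cutlass_library name/size tables), not
# exercised by this script directly:
#   GetDataTypeBits(e2m1)                            -> 4
#   GetDataTypeBits(DataType.f16)                    -> 16
#   GetDataTypeNames(e2m1)                           -> "e2m1"
#   GetDataTypeNames(DataType.e4m3, is_mx_fpx=False) -> "nv_e4m3"
#   GetDataTypeNames(DataType.e4m3, is_mx_fpx=True)  -> "mx_e4m3"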
CudaTypeName = {
e2m1: "SafeFP4",
DataType.e4m3: "__nv_fp8_e4m3",
DataType.bf16: "__nv_bfloat16",
DataType.f16: "half",
DataType.f32: "float",
DataType.e2m1: "__nv_fp4_e2m1",
DataType.ue8m0: "cutlass::float_ue8m0_t",
DataType.u4: "cutlass::uint4b_t",
}
################################################################################
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
class TrtLlm_GemmLauncher:
def __init__(
self,
gemm_kind,
arch,
act_type,
weight_type,
scalezero_type,
bias_type,
output_type,
quant_op,
epi_tag,
cta_shape,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
epi_fusion=None,
is_mx_fpx=False,
):
self.gemm_kind = gemm_kind
self.arch = arch
self.act_type = act_type
self.weight_type = weight_type
self.scalezero_type = scalezero_type
self.bias_type = bias_type
self.output_type = output_type
self.quant_op = quant_op
self.epi_tag = epi_tag
self.cta_shape = cta_shape
self.warp_shape = warp_shape
self.stages = stages
self.cga_shape = cga_shape
self.mainloop_schedule = mainloop_schedule
self.epi_schedule = epi_schedule
self.epi_fusion = epi_fusion
self.is_mx_fpx = is_mx_fpx
def __repr__(self):
kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format(
GemmKindNames[self.gemm_kind],
self.arch,
GetDataTypeNames(self.act_type, self.is_mx_fpx),
GetDataTypeNames(self.weight_type, self.is_mx_fpx),
GetDataTypeNames(self.scalezero_type),
GetDataTypeNames(self.bias_type),
GetDataTypeNames(self.output_type),
QuantOpNames[self.quant_op],
EpiTagNames[self.epi_tag],
self.cta_shape[0],
self.cta_shape[1],
self.cta_shape[2],
self.warp_shape[0],
self.warp_shape[1],
self.warp_shape[2],
self.stages,
)
hopper_suffix = "_{}x{}x{}{}{}{}".format(
self.cga_shape[0],
self.cga_shape[1],
self.cga_shape[2],
KernelScheduleSuffixes[self.mainloop_schedule],
EpilogueScheduleSuffixes[self.epi_schedule],
EpiFusionSuffixes[self.epi_fusion],
)
if self.arch >= 90:
return kernel_prefix + hopper_suffix
elif self.arch > 100:
raise ValueError(f"SM{self.arch} not supported yet.")
return kernel_prefix
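# For reference, __repr__ produces names of the form (bracketed suffix only for arch >= 90):
#   <kind>_sm<arch>_<act>_<weight>_<scalezero>_<bias>_<out>_<quantop>_<epitag>
#   _<ctaM>x<ctaN>x<ctaK>_<warpM>x<warpN>x<warpK>_<stages>
#   [_<cgaM>x<cgaN>x<cgaK><mainloop suffix><epilogue suffix><fusion suffix>]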
################################################################################
def tuple_to_cute_shape(shape):
return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>"
def instantiate_operation_tma_warp_specialized(operation):
act_tag = CudaTypeName[operation.act_type]
scale_zero_tag = CudaTypeName[operation.scalezero_type]
bias_tag = CudaTypeName[operation.bias_type]
out_tag = CudaTypeName[operation.output_type]
quant_op = QuantOpTag[operation.quant_op]
epi_tag = EpiTag[operation.epi_tag]
cute_cta_shape = tuple_to_cute_shape(operation.cta_shape)
cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)
kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
epi_sched = EpilogueScheduleTag[operation.epi_schedule]
if operation.gemm_kind == GemmKind.Gemm:
weight_tag = DataTypeTag[operation.weight_type]
instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
{kernel_sched}, {epi_sched}> (
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
);
"""
elif operation.gemm_kind == GemmKind.Grouped:
if operation.act_type != operation.weight_type and (
operation.act_type != DataType.e4m3 or operation.weight_type != e2m1
):
# Mixed MoE GEMM
weight_tag = CudaTypeName[operation.weight_type]
instantiation = f"""
template void sm90_generic_mixed_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {cute_cta_shape}, {cute_cga_shape}, {kernel_sched}, {epi_sched}, {quant_op}> (
GroupedGemmInput<{act_tag}, {weight_tag}, {out_tag}, {out_tag}> inputs, TmaWarpSpecializedGroupedGemmInput hopper_inputs, int sm_count_, size_t* workspace_size);
"""
else:
# Similar to the mixed-input case above, we must modify the tags for grouped GEMM because the CUTLASS library does not have the updated grouped schedules
assert operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
]
assert operation.epi_schedule == EpilogueScheduleType.NoSmemWarpSpecialized
kernel_sched = kernel_sched.replace("::Kernel", "::KernelGrouped")
epi_sched += "Grouped"
# arch_tag = f"cutlass::arch::Sm{operation.arch}"
arch_tag = f"Sm{operation.arch}"
weight_tag = CudaTypeName[operation.weight_type]
assert operation.epi_fusion is not None
epi_fusion = EpiFusion[operation.epi_fusion]
epi_fusion = epi_fusion.split(":")[-1]
epi_tag = epi_tag.split(":")[-1]
guard_map = {
e2m1: "defined(ENABLE_FP4)",
DataType.e4m3: "defined(ENABLE_FP8)",
DataType.bf16: "defined(ENABLE_BF16)",
}
guard_act = guard_map.get(operation.act_type, "1")
guard_weight = guard_map.get(operation.weight_type, "1")
# TODO: Revert this once the compiler bug is fixed so we can use a template instead of a macro again
# instantiation = f"""
# template void tma_warp_specialized_generic_moe_gemm_kernelLauncher<{arch_tag}, {act_tag}, {weight_tag}, {out_tag},
# {epi_tag}, {epi_fusion}, {cute_cta_shape}, {cute_cga_shape}, false>
# (TmaWarpSpecializedGroupedGemmInput, int, int, cudaStream_t, int*, size_t*);
# """
instantiation = f"""
#if {guard_act} && {guard_weight}\n
INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM({arch_tag}, {act_tag}, {weight_tag}, {out_tag},
{epi_tag}, {epi_fusion}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.cga_shape[0]}, {operation.cga_shape[1]}, {operation.cga_shape[2]}, {"true" if operation.is_mx_fpx else "false"}, false);\n
#endif
"""
return instantiation
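# For illustration (a hypothetical SM90 grouped FP8 kernel with half output, a 128x64x128 CTA
# tile, 1x1x1 CGA and no fusion), the macro path above emits text roughly of the form:
#   #if defined(ENABLE_FP8) && defined(ENABLE_FP8)
#   INSTANTIATE_TMA_WARP_SPECIALIZED_MOE_GEMM(Sm90, __nv_fp8_e4m3, __nv_fp8_e4m3, half,
#       EpilogueOpDefault, NONE, 128, 64, 128, 1, 1, 1, false, false);
#   #endif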
def instantiate_operation_sm80(operation):
act_tag = DataTypeTag[operation.dtype]
weight_tag = DataTypeTag[operation.dtype]
epi_tag = EpiTag[operation.epi_tag]
instantiation = f"""
template void sm80_generic_fused_moe_gemm_kernelLauncher<{act_tag}, {weight_tag}, {operation.cta_shape[0]}, {operation.cta_shape[1]}, {operation.cta_shape[2]}, {operation.stage}, {epi_tag}>
({act_tag} const* A, {weight_tag} const* B, {act_tag} const* biases, bool bias_is_broadcast, {act_tag}* C, int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, int* kernel_occupancy);
"""
return instantiation
def instantiate_operation(operation):
if operation.arch == 80:
return instantiate_operation_sm80(operation)
elif operation.arch >= 90:
return instantiate_operation_tma_warp_specialized(operation)
def get_file_content(launcher_inl_files, operations):
assert operations
include_list = list()
for file in launcher_inl_files:
include_list.append(f'#include "{file}"')
includes = "\n".join(include_list)
insts_list = list()
for op in operations:
insts_list.append(instantiate_operation(op))
instantiations = "\n".join(insts_list)
file_content = f"""{includes}
namespace tensorrt_llm
{{
namespace kernels
{{
namespace cutlass_kernels
{{
{instantiations}
}} // namespace cutlass_kernels
}} // namespace kernels
}} // namespace tensorrt_llm
"""
return file_content
def clean_leftover_files(output_dir, generated_files):
"""Remove leftover generated files that weren't created in this run."""
for root, _dirs, files in os.walk(output_dir):
for file in files:
file_path = os.path.join(root, file)
if file_path not in generated_files:
os.remove(file_path)
def write_file(launcher_inl_files, operations, output_file):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
# Avoid changing modified time if file content is up to date
content = get_file_content(launcher_inl_files, operations)
try:
with open(output_file, mode="r") as f:
if f.read() == content:
return
except FileNotFoundError:
pass
with open(output_file, mode="w") as f:
f.write(content)
from operator import mul, truediv
def elementwise(x, y, f):
return tuple(f(a, b) for (a, b) in zip(x, y))
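# Example: elementwise((128, 256, 64), (2, 1, 1), truediv) == (64.0, 256.0, 64.0).
# Note the results are floats; the tile checks below still work because e.g. 64.0 == 64.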
def is_gemm_op_valid_sm100(op):
# TODO: These are much more restricted than theory dictates; investigate whether more can be enabled in the future
tile_m, tile_n, _ = elementwise(op.cta_shape, op.cga_shape, truediv)
cga_m, cga_n, _ = op.cga_shape
# Default shapes
# This is the epilogue tile size. In two-CTA mode the MMA tile is actually 128/256
if tile_m not in [64, 128]:
return False
# FP4 supports a much more limited set of tile sizes
if op.act_type == e2m1 or op.weight_type == e2m1:
# TODO 128x256x256 FP4 compiles but crashes
# if tile_n % 64 != 0 or tile_n < 128:
# return False
if tile_n not in [64, 128, 256] or tile_m != 128:
return False
# Special-case shapes for FP8 with small N
if (
op.act_type == DataType.e4m3
and (tile_n == 16 or tile_n == 8)
and (cga_m == 1 and cga_n == 1)
):
# TODO: double-check why this is disabled in the CUTLASS backend. @yuhan
if tile_m == 128 and tile_n == 8:
return False
else:
return True
# Default alignment requirements
if tile_n % 32 != 0 or tile_n < 32 or tile_n > 256:
return False
# Two-CTA mode needs a larger tile-N alignment
if cga_m % 2 == 0 and tile_n % 64 != 0:
return False
return True
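# Example of the SM100 rules above (values illustrative): a CGA-level tile of (256, 128, 256)
# with CGA (2, 1, 1) gives a per-CTA tile of (128, 128), which is accepted for FP4, whereas
# (128, 128, 256) with CGA (2, 1, 1) gives a per-CTA tile of (64, 128) and is rejected
# because FP4 requires tile_m == 128.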
def is_gemm_op_valid(op):
tile_m, tile_n, _ = op.cta_shape
cga_m, cga_n, _ = op.cga_shape
if cga_m == 1 and cga_n == 1:
return True
if cga_m == 2 and cga_n == 1 and tile_m >= 128:
return True
if cga_m == 1 and cga_n == 2 and tile_n >= 128:
return True
if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128:
return True
return False
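# Example: CTA tile (128, 64) with CGA (2, 1) is accepted (tile_m >= 128), while
# CTA tile (64, 64) with CGA (2, 1) is rejected because tile_m < 128.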
def is_grouped_gemm_op_valid(op):
if not is_gemm_op_valid(op):
return False
if op.epi_tag != TrtLlm_EpilogueTag.epilogue_op_default:
return False
if op.epi_schedule != EpilogueScheduleType.NoSmemWarpSpecialized:
return False
if op.mainloop_schedule not in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
]:
return False
return True
def is_op_valid(op):
if op.arch >= 100:
return is_gemm_op_valid_sm100(op)
if op.gemm_kind == GemmKind.Gemm:
return is_gemm_op_valid(op)
if op.gemm_kind == GemmKind.Grouped:
return is_grouped_gemm_op_valid(op)
################################################################################
def generate_sm90_mixed_gemm_operations():
arch = 90
# For legacy reasons, we use unsigned types for the weights. The instantiated template
# will remap those back to the signed type.
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
supported_dtypes = [
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16),
(DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16),
(DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
(DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16),
]
quant_ops = [
TrtLlm_QuantOp.per_column_scale_only,
TrtLlm_QuantOp.finegrained_scale_only,
TrtLlm_QuantOp.finegrained_scale_and_zeros,
]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]
M_TILES = [64, 128]
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = product(M_TILES, N_TILES)
warp_shape = [4, 1, 1]
stages = 0 # auto
cga_shapes = product([1, 2], [1, 2], [1])
partial_args = product(
supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes
)
operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
max_k_bits = 128 * 8
cta_shape_k = max_k_bits // GetDataTypeBits(dtype_combo[0])
cta_shape_mnk = cta_shape_mn + (cta_shape_k,)
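# Worked example for the K tile above: max_k_bits is 128 * 8 = 1024, so a 16-bit
# activation type (f16/bf16) gives cta_shape_k = 64 and an 8-bit e4m3 activation gives 128.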
use_coop = cta_shape_mn[0] == 128
mainloop_schedule = (
KernelScheduleType.TmaWarpSpecializedCooperative
if use_coop
else KernelScheduleType.TmaWarpSpecializedPingpong
)
epi_schedule = (
EpilogueScheduleType.TmaWarpSpecializedCooperative
if use_coop
else EpilogueScheduleType.TmaWarpSpecialized
)
fpA_intB_operation = TrtLlm_GemmLauncher(
GemmKind.Gemm,
arch,
*dtype_combo,
quant_op,
epi_tag,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
)
if is_op_valid(fpA_intB_operation):
operations.append(fpA_intB_operation)
return operations
def generate_sm90_grouped_gemm_operations(is_arch_enabled):
if not is_arch_enabled:
return []
arch = 90
supported_dtypes = [DataType.f16, DataType.bf16, DataType.f32, DataType.e4m3]
quant_ops = [TrtLlm_QuantOp.none]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [128] # Currently M tile must be 128 for Grouped GEMM
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = list(product(M_TILES, N_TILES)) + [(256, 128)]
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
TrtLlm_EpilogueFusion.epilogue_fusion_finalize,
]
cga_shapes = product([1, 2], [1, 2], [1])
partial_args = product(
supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes
)
operations = list()
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args:
max_k_bits = 128 * 8
cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
cta_shape_mnk = cta_shape_mn + (cta_shape_k,)
mainloop_schedule = (
KernelScheduleType.TmaWarpSpecializedCooperative
if dtype != DataType.e4m3
else KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum
)
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
otypes = [dtype]
if dtype == DataType.e4m3:
otypes = [DataType.f16, DataType.bf16]
for otype in otypes:
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped,
arch,
dtype,
dtype,
dtype,
dtype,
otype,
quant_op,
epi_tag,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
epi_fusion,
)
if is_op_valid(moe_gemm_operation):
operations.append(moe_gemm_operation)
return operations
def generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled):
if not is_arch_enabled:
return []
arch = 90
# act_type, weight_type, scalezero_type, bias_type, output_type
supported_dtypes_int4 = [
(DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.e4m3, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16),
]
if is_cuda_version_at_least("12.8"):
supported_dtypes_fp4 = [
(DataType.f16, DataType.e2m1, DataType.ue8m0, DataType.f16, DataType.f16),
(
DataType.bf16,
DataType.e2m1,
DataType.ue8m0,
DataType.bf16,
DataType.bf16,
),
]
else:
supported_dtypes_fp4 = []
quant_ops = [TrtLlm_QuantOp.finegrained_scale_only]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
M_TILES = [64, 128]  # M tiles for the int4 mixed-input grouped GEMM
N_TILES = [16, 32, 64, 128]
K_TILES = [128, 256, 512]
cta_shapes_mnk_int4 = list(product(M_TILES, N_TILES, K_TILES))
M_TILES = [64, 128]  # M tiles for the FP4 mixed-input grouped GEMM
N_TILES = [16, 32, 64]
K_TILES = [128, 256]
cta_shapes_mnk_fp4 = list(product(M_TILES, N_TILES, K_TILES))
cta_shapes_mnk_fp4.append((128, 128, 128))
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
cga_shapes = list(product([1, 2], [1, 2], [1]))
partial_args_int4 = product(
supported_dtypes_int4, quant_ops, epi_tags, cta_shapes_mnk_int4, cga_shapes
)
partial_args_fp4 = product(
supported_dtypes_fp4, quant_ops, epi_tags, cta_shapes_mnk_fp4, cga_shapes
)
partial_args = chain(partial_args_int4, partial_args_fp4)
operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mnk, cga_shape in partial_args:
use_coop = cta_shape_mnk[0] >= 128
mainloop_schedules = (
[
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong,
]
if use_coop
else [KernelScheduleType.TmaWarpSpecializedPingpong]
)
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
for mainloop_schedule in mainloop_schedules:
if (
cta_shape_mnk[0] == 128
and cta_shape_mnk[1] == 128
and mainloop_schedule
== KernelScheduleType.TmaWarpSpecializedCooperative
):
continue
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped,
arch,
*dtype_combo,
quant_op,
epi_tag,
cta_shape_mnk,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
)
operations.append(moe_gemm_operation)
return operations
def generate_sm90_operations(is_arch_enabled):
operations = generate_sm90_mixed_gemm_operations()
operations.extend(generate_sm90_grouped_gemm_operations(is_arch_enabled))
operations.extend(generate_sm90_mixed_type_grouped_gemm_operations(is_arch_enabled))
return operations
def calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype):
max_k_bits = 128 * 8
cta_shape_k = max_k_bits // GetDataTypeBits(dtype)
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 8):
cta_shape_k = 256
if dtype == DataType.e4m3 and (cta_shape_mn[1] == 16):
cta_shape_k = 128
return cta_shape_mn + (cta_shape_k,)
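# Examples (per the rules above): bf16 (16-bit) gives K = 1024 // 16 = 64; e2m1 (4-bit)
# gives K = 256; e4m3 with N == 8 overrides K to 256, and with N == 16 it stays at 128.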
def generate_sm120_grouped_gemm_operations(is_arch_enabled):
if not is_arch_enabled:
return []
arch = 120
supported_dtypes = [e2m1]
quant_ops = [TrtLlm_QuantOp.none]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
cta_shapes_mnk = [
[128, 128, 128],
[128, 128, 256],
[256, 128, 128],
[128, 256, 128],
]
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]
cga_shapes = [[1, 1, 1]]
partial_args = product(
supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mnk, cga_shapes
)
operations = list()
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mnk, cga_shape in partial_args:
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
otypes = [dtype]
if dtype in [DataType.e4m3, e2m1]:
otypes = [DataType.f16, DataType.bf16]
for otype in otypes:
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped,
arch,
dtype,
dtype,
dtype,
dtype,
otype,
quant_op,
epi_tag,
cga_tile_shape_mnk,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
epi_fusion,
)
operations.append(moe_gemm_operation)
return operations
def generate_sm120_operations(is_arch_enabled):
operations = generate_sm120_grouped_gemm_operations(is_arch_enabled)
return operations
def generate_sm100_grouped_gemm_operations(is_arch_enabled):
if not is_arch_enabled:
return []
arch = 100
supported_dtypes = [
DataType.f16,
DataType.bf16,
DataType.f32,
DataType.e4m3,
e2m1,
(DataType.e4m3, e2m1),
]
quant_ops = [TrtLlm_QuantOp.none]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_default]
cta_shapes_m = [64, 128]
cta_shapes_n = [8, 16, 32, 64, 128, 256]
cta_shapes_mn = product(cta_shapes_m, cta_shapes_n)
warp_shape = [0, 0, 0] # ignored except for naming
stages = 0 # auto
epi_fusions = [
TrtLlm_EpilogueFusion.epilogue_fusion_none,
# TrtLlm_EpilogueFusion.epilogue_fusion_finalize
]
cga_shapes = list(product([1, 2], [1, 2], [1]))
partial_args = product(
supported_dtypes, quant_ops, epi_tags, epi_fusions, cta_shapes_mn, cga_shapes
)
operations = list()
for dtype, quant_op, epi_tag, epi_fusion, cta_shape_mn, cga_shape in partial_args:
if isinstance(dtype, tuple):
dtype, weight_type = dtype
else:
weight_type = dtype
cta_shape_mnk = calc_shape_mnk_sm100_grouped_gemm(cta_shape_mn, dtype)
cga_tile_shape_mnk = elementwise(cta_shape_mnk, cga_shape, mul)
# Ignored
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative
epi_schedule = EpilogueScheduleType.NoSmemWarpSpecialized
otypes = [dtype]
if dtype in [DataType.e4m3, e2m1]:
otypes = [DataType.f16, DataType.bf16]
for otype in otypes:
moe_gemm_operation = TrtLlm_GemmLauncher(
GemmKind.Grouped,
arch,
dtype,
weight_type,
otype,
otype,
otype,
quant_op,
epi_tag,
cga_tile_shape_mnk,
warp_shape,
stages,
cga_shape,
mainloop_schedule,
epi_schedule,
epi_fusion,
is_mx_fpx=(dtype == DataType.e4m3 and weight_type == e2m1),
)
if is_op_valid(moe_gemm_operation):
operations.append(moe_gemm_operation)
return operations
def generate_sm100_operations(is_arch_enabled):
operations = generate_sm100_grouped_gemm_operations(is_arch_enabled)
return operations
class GemmSm80LauncherConfig:
def __init__(self, gemm_kind, arch, dtype, epi_tag, cta_shape, stage):
self.gemm_kind = gemm_kind
self.arch = arch
self.dtype = dtype
self.epi_tag = epi_tag
self.cta_shape = cta_shape
self.stage = stage
def generate_sm80_fused_grouped_gemm_operations():
arch = 80
supported_dtypes = [DataType.f16, DataType.bf16]
epi_tags = [
TrtLlm_EpilogueTag.epilogue_op_silu,
TrtLlm_EpilogueTag.epilogue_op_gelu,
]
cta_shapes_mnk = [
(16, 128, 64),
(16, 256, 64),
(32, 128, 64),
(64, 128, 64),
(128, 128, 64),
]
stages = [2, 3, 4]
partial_args = product(supported_dtypes, epi_tags, cta_shapes_mnk, stages)
operations = list()
for dtype, epi_tag, cta_shape_mnk, stage in partial_args:
item = GemmSm80LauncherConfig(
GemmKind.Grouped, arch, dtype, epi_tag, cta_shape_mnk, stage
)
operations.append(item)
return operations
def generate_sm80_operations(is_arch_enabled):
operations = generate_sm80_fused_grouped_gemm_operations()
return operations
def generate_gemm_operations(output_dir, architectures):
arches = architectures.split(";")
# Get the absolute path of the provided directory
output_dir = os.path.abspath(output_dir)
fpA_intB_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
# moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl"
moe_mixed_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
# moe_mixed_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl"
sm80_moe_gemm_inl = "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
# sm80_moe_gemm_inl = "tensorrt_llm/kernels/internal_cutlass_kernels/src/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl"
inl_map = {
(GemmKind.Gemm, 90): [fpA_intB_inl],
(GemmKind.Grouped, 90): [moe_gemm_inl],
(GemmKind.Grouped, 100): [moe_gemm_inl],
(GemmKind.Grouped, 120): [moe_gemm_inl],
(GemmKind.Grouped, 80): [sm80_moe_gemm_inl],
}
def has_arch(sm):
return f"{sm}" in arches or f"{sm}-real" in arches
# The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
operations = []
operations += generate_sm120_operations(has_arch(120) or has_arch(121))
operations += generate_sm100_operations(has_arch(100))
operations += generate_sm90_operations(has_arch(90))
operations += generate_sm80_operations(has_arch(80) or has_arch(89))
def should_skip(op):
return False # All kernels have a public implementation
# The mixed dtype grouped gemm for w4afp8 has a different launcher
def is_mixed_dtype_grouped(op):
if isinstance(op, GemmSm80LauncherConfig):
return False
# Only w4afp8, not wfp4afp8
return (
(op.act_type != op.weight_type)
and (op.gemm_kind == GemmKind.Grouped)
and (op.act_type != DataType.e4m3 or op.weight_type != e2m1)
)
# Work around OOM errors in CI: if a group accumulates more than GROUP_SIZE operations, it is split into multiple sub-groups (and thus multiple generated files).
GROUP_SIZE = 8
op_groups = dict()
for op in operations:
if should_skip(op):
continue
# This dict key is used to group kernels with common instantiations together.
# Similar implementations should live in the same file so the compiler can share the CUTLASS state.
# Without this we see significant memory consumption, and separating them further also does not
# reduce the compilation time, because most of the time is spent parsing the same CUTLASS files.
# We separate by: GEMM kind, architecture, leading dimension of the CTA shape,
# FP4 (i.e. block-scaled MMA), and mixed input.
# TODO: Do a more scientific analysis of this
dict_key = (
op.gemm_kind,
op.arch,
op.cta_shape[0],
op.arch >= 100 and (op.weight_type == e2m1 or op.is_mx_fpx),
is_mixed_dtype_grouped(op),
)
op_group = op_groups.get(dict_key, [])
if len(op_group) == 0 or len(op_group[-1]) >= GROUP_SIZE:
op_group.append([op])
else:
op_group[-1].append(op)
op_groups[dict_key] = op_group
file_list = []
for key, value in op_groups.items():
gemm_kind, arch, m, block_scale, is_mixed = key
for i, op_sub_group in enumerate(value):
out_file = os.path.join(
output_dir,
GemmKindNames[gemm_kind],
str(arch),
f"cutlass_kernel_file_{GemmKindNames[gemm_kind]}_sm{arch}_M{m}{'_BS' if block_scale else ''}{'_Mixed' if is_mixed else ''}_group{i}.generated.cu",
)
inl_file = [moe_mixed_gemm_inl] if is_mixed else inl_map[key[:2]]
write_file(inl_file, op_sub_group, out_file)
file_list.append(out_file)
# Clean up any leftover files from previous runs
clean_leftover_files(output_dir, set(file_list))
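if __name__ == "__main__":
    # Minimal, illustrative sketch only (an assumption, not part of the original build flow):
    # the JIT build normally imports this module and calls generate_gemm_operations() itself.
    # Running `python -m flashinfer.jit.cutlass_gemm.generate_kernels` with this guard prints
    # how many SM90 launchers would be instantiated, without writing any files.
    ops = generate_sm90_operations(True)
    print(f"SM90 launchers that would be generated: {len(ops)}")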