################################################################################################# # # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# """ Data types and tags used for emitting CUTLASS C++ kernels """ import enum import re # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. # # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility try: from enum import auto as enum_auto except ImportError: __cutlass_library_auto_enum = 0 def enum_auto() -> int: # type: ignore[no-redef] global __cutlass_library_auto_enum i = __cutlass_library_auto_enum __cutlass_library_auto_enum += 1 return i ################################################################################################### # class GeneratorTarget(enum.Enum): Library = enum_auto() # GeneratorTargetNames = {GeneratorTarget.Library: "library"} # ################################################################################################### # class DataType(enum.Enum): void = enum_auto() # primarily used to disable C tensor for epilogues b1 = enum_auto() u2 = enum_auto() u4 = enum_auto() u8 = enum_auto() u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() s2 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() s32 = enum_auto() s64 = enum_auto() e4m3 = enum_auto() e5m2 = enum_auto() f8 = enum_auto() f6 = enum_auto() f4 = enum_auto() e3m2 = enum_auto() e2m3 = enum_auto() e2m1 = enum_auto() ue8m0 = enum_auto() ue4m3 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() tf32 = enum_auto() f64 = enum_auto() cf16 = enum_auto() cbf16 = enum_auto() cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() cs2 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() cu2 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() cu32 = enum_auto() cu64 = enum_auto() invalid = enum_auto() # ShortDataTypeNames = { DataType.s32: "i", DataType.e4m3: "e4m3", DataType.e5m2: "e5m2", DataType.f16: "h", DataType.f32: "s", DataType.f64: "d", DataType.cf32: "c", DataType.cf64: "z", DataType.f8: "f8", DataType.f6: "f6", DataType.f4: "f4", } # DataTypeNames = { DataType.void: "void", DataType.b1: "b1", DataType.u2: "u2", DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", DataType.s2: "s2", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", DataType.e4m3: "e4m3", DataType.e5m2: "e5m2", DataType.f8: "f8", DataType.f6: "f6", DataType.f4: "f4", DataType.e2m3: "e2m3", DataType.e3m2: "e3m2", DataType.e2m1: "e2m1", DataType.ue8m0: "ue8m0", DataType.ue4m3: "ue4m3", DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", DataType.tf32: "tf32", DataType.f64: "f64", DataType.cf16: "cf16", DataType.cbf16: "cbf16", DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", DataType.cu2: "cu2", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", DataType.cs2: "cs2", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", DataType.cs32: "cs32", DataType.cs64: "cs64", } DataTypeTag = { DataType.void: "void", DataType.b1: "cutlass::uint1b_t", DataType.u2: "cutlass::uint2b_t", DataType.u4: "cutlass::uint4b_t", DataType.u8: "uint8_t", DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", DataType.s2: "cutlass::int2b_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", DataType.s32: "int32_t", DataType.s64: "int64_t", DataType.e4m3: "cutlass::float_e4m3_t", DataType.e5m2: "cutlass::float_e5m2_t", DataType.f8: "cutlass::type_erased_dynamic_float8_t", DataType.f6: "cutlass::type_erased_dynamic_float6_t", DataType.f4: "cutlass::type_erased_dynamic_float4_t", DataType.e2m3: "cutlass::float_e2m3_t", DataType.e3m2: "cutlass::float_e3m2_t", DataType.e2m1: "cutlass::float_e2m1_t", DataType.ue8m0: "cutlass::float_ue8m0_t", DataType.ue4m3: "cutlass::float_ue4m3_t", DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", DataType.tf32: "cutlass::tfloat32_t", DataType.f64: "double", DataType.cf16: "cutlass::complex", DataType.cbf16: "cutlass::complex", DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", DataType.cu2: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", DataType.cs2: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", DataType.cs32: "cutlass::complex", DataType.cs64: "cutlass::complex", } DataTypeSize = { DataType.void: 0, DataType.b1: 1, DataType.u2: 2, DataType.u4: 4, DataType.u8: 8, DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, DataType.s2: 2, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, DataType.s32: 32, DataType.s64: 64, DataType.e4m3: 8, DataType.e5m2: 8, DataType.f8: 8, DataType.f6: 6, DataType.f4: 4, DataType.e2m3: 6, DataType.e3m2: 6, DataType.e2m1: 4, DataType.ue8m0: 8, DataType.ue4m3: 8, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, DataType.tf32: 32, DataType.f64: 64, DataType.cf16: 32, DataType.cbf16: 32, DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, DataType.cu2: 4, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, DataType.cs2: 4, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, DataType.cs32: 64, DataType.cs64: 128, } ################################################################################################### # class BlasMode(enum.Enum): symmetric = enum_auto() hermitian = enum_auto() # BlasModeTag = { BlasMode.symmetric: "cutlass::BlasMode::kSymmetric", BlasMode.hermitian: "cutlass::BlasMode::kHermitian", } # class ComplexTransform(enum.Enum): none = enum_auto() conj = enum_auto() # ComplexTransformTag = { ComplexTransform.none: "cutlass::ComplexTransform::kNone", ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate", } # Used for cutlass3x complex kernel collective mainloop builder instantiation ComplexTransformTag3x = { ComplexTransform.none: "cute::identity", ComplexTransform.conj: "cute::conjugate", } # RealComplexBijection = [ (DataType.f16, DataType.cf16), (DataType.f32, DataType.cf32), (DataType.f64, DataType.cf64), ] # def is_complex(data_type): return any(data_type == c for _r, c in RealComplexBijection) def is_block_scaled(gemm_kind): return gemm_kind in ( GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x, ) def is_blockwise(gemm_kind): return gemm_kind in ( GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x, ) def is_grouped(gemm_kind): return gemm_kind in ( GemmKind.GroupedUniversal3x, GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x, ) # def get_complex_from_real(real_type): for r, c in RealComplexBijection: if real_type == r: return c return DataType.invalid # def get_real_from_complex(complex_type): for r, c in RealComplexBijection: if complex_type == c: return r return DataType.invalid # TMA requires an alignment of 128 bits for all data types def get_tma_alignment(data_type): if data_type == DataType.void: return 0 elif DataTypeSize[data_type] == 6: return 128 # 96B alignment for 16U6 format else: return 128 // DataTypeSize[data_type] # class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() gaussian = enum_auto() ################################################################################################### # class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() multiply_add_mixed_input_upcast = enum_auto() xor_popc = enum_auto() and_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() multiply_add_fast_f32 = enum_auto() multiply_add_complex_fast_f32 = enum_auto() multiply_add_complex = enum_auto() multiply_add_complex_gaussian = enum_auto() multiply_add_fast_accum = enum_auto() # MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", MathOperation.multiply_add_mixed_input_upcast: "cutlass::arch::OpMultiplyAddMixedInputUpcast", MathOperation.xor_popc: "cutlass::arch::OpXorPopc", MathOperation.and_popc: "cutlass::arch::OpAndPopc", MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16", MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16", MathOperation.multiply_add_fast_f32: "cutlass::arch::OpMultiplyAddFastF32", MathOperation.multiply_add_complex_fast_f32: "cutlass::arch::OpMultiplyAddComplexFastF32", MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex", MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex", MathOperation.multiply_add_fast_accum: "cutlass::arch::OpMultiplyAddFastAccum", } ################################################################################################### # class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() ColumnMajorInterleaved2 = enum_auto() RowMajorInterleaved2 = enum_auto() ColumnMajorInterleaved32 = enum_auto() RowMajorInterleaved32 = enum_auto() ColumnMajorInterleaved64 = enum_auto() RowMajorInterleaved64 = enum_auto() TensorNWC = enum_auto() TensorNHWC = enum_auto() TensorNDHWC = enum_auto() TensorNCHW = enum_auto() TensorNGHWC = enum_auto() TensorNC32HW32 = enum_auto() TensorNC64HW64 = enum_auto() TensorC32RSK32 = enum_auto() TensorC64RSK64 = enum_auto() TensorKCS = enum_auto() TensorKCSR = enum_auto() TensorKCSRT = enum_auto() # LayoutTag = { LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", LayoutType.RowMajor: "cutlass::layout::RowMajor", LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>", LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>", LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>", LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>", LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>", LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>", LayoutType.TensorNWC: "cutlass::layout::TensorNWC", LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC", LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC", LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW", LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC", LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>", LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>", LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>", LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>", LayoutType.TensorKCS: "cutlass::layout::TensorKCS", LayoutType.TensorKCSR: "cutlass::layout::TensorKCSR", LayoutType.TensorKCSRT: "cutlass::layout::TensorKCSRT", } # TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, LayoutType.TensorNHWC: LayoutType.TensorNHWC, } # ShortLayoutTypeNames = { LayoutType.ColumnMajor: "n", LayoutType.ColumnMajorInterleaved2: "n2", LayoutType.ColumnMajorInterleaved32: "n32", LayoutType.ColumnMajorInterleaved64: "n64", LayoutType.RowMajor: "t", LayoutType.RowMajorInterleaved2: "t2", LayoutType.RowMajorInterleaved32: "t32", LayoutType.RowMajorInterleaved64: "t64", LayoutType.TensorNWC: "nwc", LayoutType.TensorNHWC: "nhwc", LayoutType.TensorNDHWC: "ndhwc", LayoutType.TensorNCHW: "nchw", LayoutType.TensorNGHWC: "nghwc", LayoutType.TensorNC32HW32: "nc32hw32", LayoutType.TensorNC64HW64: "nc64hw64", LayoutType.TensorC32RSK32: "c32rsk32", LayoutType.TensorC64RSK64: "c64rsk64", LayoutType.TensorKCS: "kcs", LayoutType.TensorKCSR: "kcsr", LayoutType.TensorKCSRT: "kcsrt", } # ShortComplexLayoutNames = { (LayoutType.ColumnMajor, ComplexTransform.none): "n", (LayoutType.ColumnMajor, ComplexTransform.conj): "c", (LayoutType.RowMajor, ComplexTransform.none): "t", (LayoutType.RowMajor, ComplexTransform.conj): "h", } ################################################################################################### class KernelScheduleType(enum.Enum): ScheduleAuto = enum_auto() Multistage = enum_auto() CpAsyncWarpSpecialized = enum_auto() CpAsyncWarpSpecializedPingpong = enum_auto() CpAsyncWarpSpecializedCooperative = enum_auto() Tma = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedPingpong = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecializedFP8FastAccum = enum_auto() TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() ImplicitTmaWarpSpecializedSm90 = enum_auto() PtrArrayTmaWarpSpecializedCooperative = enum_auto() PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() PtrArrayTmaWarpSpecializedPingpong = enum_auto() PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() BlockwiseTmaWarpSpecializedCooperative = enum_auto() PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1SmSm100 = enum_auto() TmaWarpSpecialized2SmSm100 = enum_auto() ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() SparseTmaWarpSpecialized1SmSm100 = enum_auto() SparseTmaWarpSpecialized2SmSm100 = enum_auto() BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() KernelScheduleTag = { KernelScheduleType.ScheduleAuto: "cutlass::gemm::collective::KernelScheduleAuto", KernelScheduleType.Multistage: "cutlass::gemm::KernelMultistage", KernelScheduleType.CpAsyncWarpSpecialized: "cutlass::gemm::KernelCpAsyncWarpSpecialized", KernelScheduleType.CpAsyncWarpSpecializedPingpong: "cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong", KernelScheduleType.CpAsyncWarpSpecializedCooperative: "cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative", KernelScheduleType.Tma: "cutlass::gemm::KernelTma", KernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized", KernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong", KernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative", KernelScheduleType.TmaWarpSpecializedFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum", KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum", KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: "cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum", KernelScheduleType.ImplicitTmaWarpSpecializedSm90: "cutlass::conv::KernelImplicitTmaWarpSpecializedSm90", KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum", KernelScheduleType.TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmSm100", KernelScheduleType.TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmSm100", KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: "cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100", KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: "cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100", KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100", KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100", KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100", KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100", KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100", KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100", KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100", KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100", KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100", KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100", KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100", KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100", KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative", KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum", KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong", KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: "cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum", KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120", KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120", KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120", KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120", KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120", KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120", KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelScheduleSparseF8f6f4Sm120", KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120", KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: "cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120", } # KernelScheduleSuffixes = { KernelScheduleType.ScheduleAuto: "", KernelScheduleType.Multistage: "_cpasync", KernelScheduleType.CpAsyncWarpSpecialized: "_cpasync_warpspecialized", KernelScheduleType.CpAsyncWarpSpecializedPingpong: "_cpasync_warpspecialized_pingpong", KernelScheduleType.CpAsyncWarpSpecializedCooperative: "_cpasync_warpspecialized_cooperative", KernelScheduleType.Tma: "_unspecialized", KernelScheduleType.TmaWarpSpecialized: "_warpspecialized", KernelScheduleType.TmaWarpSpecializedPingpong: "_warpspecialized_pingpong", KernelScheduleType.TmaWarpSpecializedCooperative: "_warpspecialized_cooperative", KernelScheduleType.TmaWarpSpecializedFP8FastAccum: "_warpspecialized_fp8_fastaccum", KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: "_warpspecialized_cooperative_fp8_fastaccum", KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: "_warpspecialized_pingpong_fp8_fastaccum", KernelScheduleType.ImplicitTmaWarpSpecializedSm90: "_warpspecialized", KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: "_warpspecialized_cooperative", KernelScheduleType.TmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.TmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: "_q_1sm", KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: "_q_2sm", KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: "_1sm", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: "_2sm", KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm", KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm", KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: "_o_vs16_1sm", KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: "_o_vs16_2sm", KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: "_warpspecialized_cooperative", KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: "_warpspecialized_cooperative_fp8_fastaccum", KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: "_warpspecialized_pingpong", KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: "_warpspecialized_pingpong_fp8_fastaccum", KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: "_warpspecialized_cooperative", KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "_1sm", KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "_2sm", KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "_o_vs16_1sm", KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "_o_vs16_2sm", KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm", KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "_o_vs32_1sm", KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "_o_vs32_2sm", KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: "_cooperative_q", KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: "_pingpong_q", KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: "_cooperative_o_vs16", KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: "_pingpong_o_vs16", KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: "_cooperative_o_vs32", KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: "_pingpong_o_vs32", KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: "_q", KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: "_cooperative_q", KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: "_pingpong_q", } class EpilogueScheduleType(enum.Enum): ScheduleAuto = enum_auto() EpilogueTransposed = enum_auto() NoSmemWarpSpecialized = enum_auto() PtrArrayNoSmemWarpSpecialized = enum_auto() NoSmemWarpSpecialized1Sm = enum_auto() NoSmemWarpSpecialized2Sm = enum_auto() PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecialized1Sm = enum_auto() TmaWarpSpecialized2Sm = enum_auto() PtrArrayTmaWarpSpecialized1Sm = enum_auto() PtrArrayTmaWarpSpecialized2Sm = enum_auto() PtrArrayTmaWarpSpecializedPingpong = enum_auto() PtrArrayTmaWarpSpecializedCooperative = enum_auto() # EpilogueScheduleTag = { EpilogueScheduleType.ScheduleAuto: "cutlass::epilogue::collective::EpilogueScheduleAuto", EpilogueScheduleType.EpilogueTransposed: "cutlass::gemm::EpilogueTransposed", EpilogueScheduleType.NoSmemWarpSpecialized: "cutlass::epilogue::NoSmemWarpSpecialized", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized", EpilogueScheduleType.NoSmemWarpSpecialized1Sm: "cutlass::epilogue::NoSmemWarpSpecialized1Sm", EpilogueScheduleType.NoSmemWarpSpecialized2Sm: "cutlass::epilogue::NoSmemWarpSpecialized2Sm", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: "cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm", EpilogueScheduleType.TmaWarpSpecialized: "cutlass::epilogue::TmaWarpSpecialized", EpilogueScheduleType.TmaWarpSpecializedCooperative: "cutlass::epilogue::TmaWarpSpecializedCooperative", EpilogueScheduleType.TmaWarpSpecialized1Sm: "cutlass::epilogue::TmaWarpSpecialized1Sm", EpilogueScheduleType.TmaWarpSpecialized2Sm: "cutlass::epilogue::TmaWarpSpecialized2Sm", EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: "cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm", EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: "cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm", EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: "cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative", EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: "cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong", } # EpilogueScheduleSuffixes = { EpilogueScheduleType.ScheduleAuto: "", EpilogueScheduleType.EpilogueTransposed: "", EpilogueScheduleType.NoSmemWarpSpecialized: "_epi_nosmem", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: "_epi_nosmem", EpilogueScheduleType.NoSmemWarpSpecialized1Sm: "_epi_nosmem", EpilogueScheduleType.NoSmemWarpSpecialized2Sm: "_epi_nosmem", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: "_epi_nosmem", EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: "_epi_nosmem", EpilogueScheduleType.TmaWarpSpecialized: "_epi_tma", EpilogueScheduleType.TmaWarpSpecializedCooperative: "_epi_tma", EpilogueScheduleType.TmaWarpSpecialized1Sm: "", EpilogueScheduleType.TmaWarpSpecialized2Sm: "_epi_tma", EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: "_tma_1sm", EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: "_tma_2sm", EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: "_epi_tma", EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: "_epi_tma", } class EpilogueFunctor3x(enum.Enum): LinearCombination = enum_auto() LinearCombinationBlockScaleFactor = enum_auto() # EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombination: "cutlass::epilogue::fusion::LinearCombination", EpilogueFunctor3x.LinearCombinationBlockScaleFactor: "cutlass::epilogue::fusion::LinCombBlockScaleFactor", } # TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) def is_tma_epilogue(epilogue_schedule_type): return epilogue_schedule_type in [ EpilogueScheduleType.ScheduleAuto, EpilogueScheduleType.TmaWarpSpecialized, EpilogueScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecialized1Sm, EpilogueScheduleType.TmaWarpSpecialized2Sm, EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, ] def to_grouped_schedule(schedule, grouped): if not grouped: return schedule group_schedule_map = { # SM90 KernelScheduleType.TmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, KernelScheduleType.TmaWarpSpecializedPingpong: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized: EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecializedCooperative: EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, # SM100 KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100, KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, } return group_schedule_map[schedule] class TileSchedulerType(enum.Enum): Default = enum_auto() Persistent = enum_auto() StreamK = enum_auto() # TileSchedulerTag = { TileSchedulerType.Default: "void", TileSchedulerType.Persistent: "cutlass::gemm::PersistentScheduler", TileSchedulerType.StreamK: "cutlass::gemm::StreamKScheduler", } # TileSchedulerSuffixes = { TileSchedulerType.Default: "", TileSchedulerType.Persistent: "", TileSchedulerType.StreamK: "_stream_k", } ################################################################################################### # class SideMode(enum.Enum): Left = enum_auto() Right = enum_auto() # SideModeTag = { SideMode.Left: "cutlass::SideMode::kLeft", SideMode.Right: "cutlass::SideMode::kRight", } # ShortSideModeNames = {SideMode.Left: "ls", SideMode.Right: "rs"} ################################################################################################### # class FillMode(enum.Enum): Lower = enum_auto() Upper = enum_auto() # FillModeTag = { FillMode.Lower: "cutlass::FillMode::kLower", FillMode.Upper: "cutlass::FillMode::kUpper", } # ShortFillModeNames = {FillMode.Lower: "l", FillMode.Upper: "u"} ################################################################################################### # class DiagType(enum.Enum): NonUnit = enum_auto() Unit = enum_auto() # DiagTypeTag = { DiagType.NonUnit: "cutlass::DiagType::kNonUnit", DiagType.Unit: "cutlass::DiagType::kUnit", } # ShortDiagTypeNames = {DiagType.NonUnit: "nu", DiagType.Unit: "un"} ################################################################################################### # class OpcodeClass(enum.Enum): Simt = enum_auto() TensorOp = enum_auto() WmmaTensorOp = enum_auto() SparseTensorOp = enum_auto() BlockScaledTensorOp = enum_auto() OpcodeClassNames = { OpcodeClass.Simt: "simt", OpcodeClass.TensorOp: "tensorop", OpcodeClass.WmmaTensorOp: "wmma_tensorop", OpcodeClass.SparseTensorOp: "sptensorop", OpcodeClass.BlockScaledTensorOp: "bstensorop", } OpcodeClassTag = { OpcodeClass.Simt: "cutlass::arch::OpClassSimt", OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", OpcodeClass.SparseTensorOp: "cutlass::arch::OpClassSparseTensorOp", OpcodeClass.BlockScaledTensorOp: "cutlass::arch::OpClassBlockScaledTensorOp", } ################################################################################################### # class OperationKind(enum.Enum): Gemm = enum_auto() RankK = enum_auto() Rank2K = enum_auto() Trmm = enum_auto() Symm = enum_auto() Conv2d = enum_auto() Conv3d = enum_auto() # OperationKindNames = { OperationKind.Gemm: "gemm", OperationKind.RankK: "rank_k", OperationKind.Rank2K: "rank_2k", OperationKind.Trmm: "trmm", OperationKind.Symm: "symm", OperationKind.Conv2d: "conv2d", OperationKind.Conv3d: "conv3d", } # class Target(enum.Enum): library = enum_auto() # ArchitectureNames = { 50: "maxwell", 60: "pascal", 61: "pascal", 70: "volta", 75: "turing", 80: "ampere", 89: "ada", 90: "hopper", } # SharedMemPerCC = { 70: 96, # 96KB of SMEM 72: 96, # 96KB of SMEM 75: 64, # 64KB of SMEM 80: 163, # 163KB of SMEM - 1KB reserved for the driver 86: 99, # 99KB of SMEM - 1KB reserved for the driver 87: 163, # 163KB of SMEM - 1KB reserved for the driver 89: 99, # 99KB of SMEM - 1KB reserved for the driver 90: 227, # 227KB of SMEM - 1KB reserved for the driver } ################################################################################################### # def SubstituteTemplate(template, values): text = template changed = True while changed: changed = False for key, value in values.items(): regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: changed = True text = newtext return text ################################################################################################### # class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() Universal = enum_auto() Universal3x = enum_auto() SparseUniversal3x = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() Grouped = enum_auto() BlockScaledUniversal3x = enum_auto() GroupedUniversal3x = enum_auto() GroupedBlockScaledUniversal3x = enum_auto() BlockwiseUniversal3x = enum_auto() GroupedBlockwiseUniversal3x = enum_auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.Universal3x: "gemm", GemmKind.SparseUniversal3x: "spgemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped", GemmKind.BlockScaledUniversal3x: "gemm", GemmKind.GroupedUniversal3x: "gemm_grouped", GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped", GemmKind.BlockwiseUniversal3x: "gemm", GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped", } # class RankKKind(enum.Enum): Universal = enum_auto() # RankKKindNames = {RankKKind.Universal: "rank_k"} # class TrmmKind(enum.Enum): Universal = enum_auto() # TrmmKindNames = {TrmmKind.Universal: "trmm"} # class SymmKind(enum.Enum): Universal = enum_auto() # SymmKindNames = {SymmKind.Universal: "symm"} # class EpilogueFunctor(enum.Enum): LinearCombination = enum_auto() LinearCombinationClamp = enum_auto() # EpilogueFunctorTag = { EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp", } # class MixedInputMode(enum.Enum): ConvertOnly = enum_auto() ScaleOnly = enum_auto() ScaleWithZeroPoint = enum_auto() # class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() Horizontal = enum_auto() StridedDgradIdentity1 = enum_auto() StridedDgradIdentity4 = enum_auto() StridedDgradHorizontal = enum_auto() StreamK = enum_auto() # SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", SwizzlingFunctor.Horizontal: "cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle", SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>", SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>", SwizzlingFunctor.StridedDgradHorizontal: "cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle", SwizzlingFunctor.StreamK: "cutlass::gemm::threadblock::ThreadblockSwizzleStreamK", } # class GroupScheduleMode(enum.Enum): Device = (enum_auto(),) Host = enum_auto() # GroupScheduleModeTag = { GroupScheduleMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly", GroupScheduleMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute", } # ShortGroupScheduleModeNames = { GroupScheduleMode.Device: "Device", GroupScheduleMode.Host: "Host", } ################################################################################################### # class ConvKind(enum.IntEnum): Fprop = 0 Dgrad = 1 Wgrad = 2 # ConvKindTag = { ConvKind.Fprop: "cutlass::conv::Operator::kFprop", ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad", ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad", } ConvKindNames = { ConvKind.Fprop: "fprop", ConvKind.Dgrad: "dgrad", ConvKind.Wgrad: "wgrad", } class ConvMode(enum.IntEnum): CrossCorrelation = 0 Convolution = 1 # class IteratorAlgorithm(enum.Enum): Analytic = 0 Optimized = 1 FixedChannels = 2 FewChannels = 3 FixedStrideDilation = 4 # IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", IteratorAlgorithm.FixedChannels: "cutlass::conv::IteratorAlgorithm::kFixedChannels", IteratorAlgorithm.FewChannels: "cutlass::conv::IteratorAlgorithm::kFewChannels", IteratorAlgorithm.FixedStrideDilation: "cutlass::conv::IteratorAlgorithm::kFixedStrideDilation", } IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: "analytic", IteratorAlgorithm.Optimized: "optimized", IteratorAlgorithm.FixedChannels: "fixed_channels", IteratorAlgorithm.FewChannels: "few_channels", IteratorAlgorithm.FixedStrideDilation: "fixed_stride_dilation", } # class StrideSupport(enum.Enum): Strided = 0 Unity = 1 Fixed = 2 # StrideSupportTag = { StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", StrideSupport.Fixed: "cutlass::conv::StrideSupport::kFixed", } StrideSupportNames = { StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride", StrideSupport.Fixed: "fixed_stride", } # class GroupMode(enum.Enum): NoneGroup = enum_auto() # dense conv (G=1) SingleGroup = enum_auto() # grouped convolution (single group per CTA) MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) # GroupModeTag = { GroupMode.NoneGroup: "cutlass::conv::GroupMode::kNone", GroupMode.SingleGroup: "cutlass::conv::GroupMode::kSingleGroup", GroupMode.MultipleGroup: "cutlass::conv::GroupMode::kMultipleGroup", GroupMode.Depthwise: "cutlass::conv::GroupMode::kDepthwise", } GroupModeNames = { GroupMode.NoneGroup: "", GroupMode.SingleGroup: "single_group", GroupMode.MultipleGroup: "multiple_group", GroupMode.Depthwise: "depthwise", } DynamicClusterShape = [0, 0, 1] ################################################################################################### # class MathInstruction: def __init__( self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation=MathOperation.multiply_add, element_scale_factor=None, ): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation self.element_scale_factor = element_scale_factor # class TileDescription: def __init__( self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape=(1, 1, 1), explicit_vector_sizes=None, ): self.threadblock_shape = threadblock_shape self.tile_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute self.cluster_shape = cluster_shape self.explicit_vector_sizes = explicit_vector_sizes def procedural_name(self): if self.minimum_compute_capability >= 90: return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( tbm=self.threadblock_shape[0], tbn=self.threadblock_shape[1], tbk=self.threadblock_shape[2], cm=self.cluster_shape[0], cn=self.cluster_shape[1], ck=self.cluster_shape[2], s=self.stages, ) else: return "%dx%d_%dx%d" % ( self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages, ) class Direct2dConvFixedStrideDilationTileDescription: def __init__( self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute, ): self.threadblock_shape = [ threadblock_output_shape[0] * threadblock_output_shape[1] * threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0] * filter_shape[1], ] self.threadblock_output_shape = threadblock_output_shape self.filter_shape = filter_shape self.stages = stages self.warp_count = warp_count self.stride = stride self.dilation = dilation self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % ( self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.threadblock_output_shape[0], self.threadblock_output_shape[1], self.threadblock_output_shape[2], self.threadblock_output_shape[3], self.stages, self.filter_shape[0], self.filter_shape[1], ) # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % ( self.stride[0], self.stride[1], self.dilation[0], self.dilation[1], ) return str_name # class TensorDescription: def __init__( self, element, layout, alignment=1, complex_transform=ComplexTransform.none ): self.element = element self.layout = layout self.alignment = alignment self.complex_transform = complex_transform # class SymmetricTensorDescription: def __init__( self, element, layout, fill_mode, alignment=1, complex_transform=ComplexTransform.none, side_mode=SideMode.Left, ): self.element = element self.layout = layout self.fill_mode = fill_mode self.alignment = alignment self.complex_transform = complex_transform self.side_mode = side_mode # class TriangularTensorDescription: def __init__( self, element, layout, side_mode, fill_mode, diag_type, alignment=1, complex_transform=ComplexTransform.none, ): self.element = element self.layout = layout self.side_mode = side_mode self.fill_mode = fill_mode self.diag_type = diag_type self.alignment = alignment self.complex_transform = complex_transform # def CalculateSmemUsage(operation): cta_shape = operation.tile_description.threadblock_shape stages = operation.tile_description.stages if ( operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse ): # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) if DataTypeSize[operation.A.element] == 32: elements_per_8b_md = 2 elif DataTypeSize[operation.A.element] == 4: elements_per_8b_md = 8 else: elements_per_8b_md = 4 smem_per_stage = ( DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md ) else: # Few BLAS3 operations only have A tensor data_type_size_a = DataTypeSize[operation.A.element] data_type_size_b = DataTypeSize[operation.A.element] if operation.is_mixed_input(): data_type_size_b = DataTypeSize[operation.B.element] smem_per_stage = ( data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 ) smem_usage = smem_per_stage * stages return smem_usage >> 10 class GemmUniversalMode(enum.IntEnum): """ Types corresponding to GemmUniversalMode """ Gemm = 0 GemmSplitKParallel = 1 Batched = 2 Array = 3 class SplitKMode(enum.IntEnum): """ Types corresponding to SplitKMode """ NoneSplitK = 0 Serial = 1 Parallel = 2