# sglang_v0.5.2/flashinfer_0.3.1/3rdparty/cutlass/python/cutlass_library/heuristics.py
#################################################################################################
#
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for selecting CUTLASS library kernels based on problem description
"""
import json
import csv
import logging
try:
    if CUTLASS_IGNORE_PACKAGE:
        raise ImportError("Disabling attempt to import cutlass_library")
    from cutlass_library.library import *
    from cutlass_library.generator import *
    from cutlass_library.heuristics_provider import *
except ImportError:
    from library import *
    from generator import *
    from heuristics_provider import *
try:
    from .sm90_utils import (
        get_valid_schedules,
        generate_data_types_from_math_instruction,
        fix_alignments,
    )
except ImportError:
    from sm90_utils import (
        get_valid_schedules,
        generate_data_types_from_math_instruction,
        fix_alignments,
    )
_LOGGER = logging.getLogger(__name__)
dtype_map = {v: k for k, v in DataTypeNames.items()}  # short dtype name (e.g. 'f16') -> DataType enum

def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
    """
    Utility function to write heuristics results to a JSON file for debugging
    args:
        problems_with_configs: List of problems provided to the heuristic, with a list of suggested configs added to each problem dict
        outfile_path: Output file path
    returns:
        None
    """
    pc_copy = problems_with_configs.copy()
    for p in pc_copy:
        # Replace enum values with their short string names so the dicts are JSON-serializable.
        # Note: list.copy() is shallow, so this conversion also mutates the caller's dicts.
        for k, v in p.items():
            if isinstance(v, DataType):
                p[k] = DataTypeNames[v]
            elif isinstance(v, LayoutType):
                p[k] = ShortLayoutTypeNames[v]
        configs = p['configs']
        for c in configs:
            for k, v in c.items():
                if isinstance(v, DataType):
                    c[k] = DataTypeNames[v]
                elif isinstance(v, LayoutType):
                    c[k] = ShortLayoutTypeNames[v]
    with open(outfile_path, 'w') as f:
        json.dump(pc_copy, f, indent=2)

def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
    """
    Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.
    args:
        m, n, k: GEMM dimensions
        batch_count: batch count
        layouts: tuple of layouts of type LayoutType
        dtypes: tuple of data types of type DataType, ordered (a, b, acc, c, d)
        alignment_a, alignment_b: input tensor alignments, in count of elements
        voidC: True if the C operand is unused (e.g. beta == 0)
        use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
        count: Number of configs to return
        provider: Heuristics provider to use
    returns:
        A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
        - 'cta_tile_m', 'cta_tile_n', 'cta_tile_k': CTA tile size
        - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
        - 'stages': kernel pipeline stage count
        - 'cluster_m', 'cluster_n', 'cluster_k': cluster size
        - 'layout_a', 'layout_b': input tensor layouts of type LayoutType
        - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
        - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
        - 'swizzle_size': suggested threadblock swizzle
        - 'split_k_slices': number of partitions of the k dimension for split-K
        - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
    """
    if provider is None:
        provider = MatmulHeuristics()
    return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
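
# Example usage (a minimal sketch; the dimensions, layouts, and dtypes below are
# illustrative, and MatmulHeuristics comes from heuristics_provider):
#
#     configs = get_single_gemm_config(
#         4096, 4096, 4096, batch_count=1,
#         layouts=(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # 'tnt'
#         dtypes=(DataType.f16, DataType.f16, DataType.f32, DataType.f16, DataType.f16),
#         alignment_a=8, alignment_b=8,  # 128 bits / 16 bits per f16 element
#         count=3)
#     best = configs[0]  # e.g. best['cta_tile_m'], best['cluster_m'], best['split_k_slices']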

def get_gemm_configs(problems, provider=None, count=1):
    """
    Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.
    args:
        problems: List of dictionaries describing GEMM problems with the following keys:
            - 'm', 'n', 'k': Matrix dimensions (required)
            - 'dtype_a': Data type of matrix A (required)
            - 'dtype_b': Data type of matrix B (required)
            - 'dtype_c': Data type of matrix C (default: None)
            - 'dtype_d': Data type of matrix D (required)
            - 'dtype_acc': Compute data type (default: 'f32')
            - 'layout': Operation layout, a 3-character string of 'n'/'t' (e.g. 'tnt')
            - 'alignment_a': Memory access granularity of A, in units of elements (default: the number of elements in 16 bytes)
            - 'alignment_b': Memory access granularity of B, in units of elements (default: the number of elements in 16 bytes)
            - 'alpha': Scalar multiplier for A*B (default: 1.0)
            - 'beta': Scalar multiplier for C (default: 0.0)
            - 'batch_count': Number of GEMM operations in batch (default: 1)
            - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
        provider: Heuristics provider to use
        count: Number of configurations to return per problem (default: 1)
    returns:
        A copy of the input list, with key `configs` added to each problem dict containing the selected GEMM configs
    """
    ret = []
    for problem in problems:
        problem = problem.copy()
        try:
            m = problem['m']
            n = problem['n']
            k = problem['k']
            dtype_a = problem['dtype_a']
            dtype_b = problem['dtype_b']
            dtype_d = problem['dtype_d']
            layout = problem['layout']
        except KeyError as e:
            _LOGGER.error(f"Missing required parameter {e} for problem {problem}")
            raise
        operation = problem.get('operation', 'gemm')
        batch_count = problem.get('batch_count', 1)
        dtype_acc = problem.get('dtype_acc', 'f32')
        dtype_c = problem.get('dtype_c', None)
        alpha = problem.get('alpha', 1.0)
        beta = problem.get('beta', 0.0)
        use_fast_acc = problem.get('use_fast_acc', True)
        if operation != OperationKindNames[OperationKind.Gemm]:
            raise ValueError(f"Unsupported operation {operation}")
        if not (len(layout) == 3 and all(c in "nt" for c in layout)):
            raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
        # 't' (transposed) maps to row-major, 'n' (non-transposed) to column-major
        layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
        try:
            # dtype tuple order expected by the provider: (a, b, acc, c, d); a missing C type defaults to D's type
            dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
            dtypes = tuple(dtype_map[dt] for dt in dtype_list)
        except KeyError as dt:
            _LOGGER.error(f"Unsupported data type: {dt}")
            raise
        # Default alignment: the widest 128-bit (16-byte) access, in elements
        alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
        alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
        configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta == 0.0, use_fast_acc, count, provider)
        problem['configs'] = configs
        ret.append(problem)
    return ret
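
# Example usage (a minimal sketch with illustrative values; dtype strings must be
# short names from DataTypeNames, e.g. 'f16', 'bf16', 'e4m3'):
#
#     problems = [{
#         'm': 8192, 'n': 1024, 'k': 4096,
#         'dtype_a': 'f16', 'dtype_b': 'f16', 'dtype_d': 'f16',
#         'layout': 'tnt',
#     }]
#     problems_with_configs = get_gemm_configs(problems, count=2)
#     serialize_heuristics_results_to_json(problems_with_configs, 'heuristics_debug.json')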

def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristics provider
    args:
        manifest: manifest to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
        cuda_version: CUDA compiler version for generating CUTLASS operations
        kernel_configs: list of configs generated by the heuristic
    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
    """
    min_cc = 100
    max_cc = 101
    if manifest is None:
        # Use a dummy manifest so we can use the existing CreateGemmOperator functions
        manifest = Manifest()
    configs = []
    operations = []
    for config in kernel_configs:
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
        # nvMMH assumes a 2-SM instruction whenever cluster_m is even
        is_2sm = config['cluster_m'] % 2 == 0
        instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
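        # Worked example (illustrative values): a CTA tile of (128, 256, 64) with an even
        # cluster_m yields a 2-SM MMA instruction shape of (256, 256, 16), since the
        # instruction's K extent is cta_tile_k // 4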
        math_instruction = MathInstruction(
            instruction_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )
        data_types = [
            {
                "a_type": math_instruction.element_a,
                "b_type": math_instruction.element_b,
                "c_type": DataType.void if config['voidC'] else math_instruction.element_accumulator,
                "d_type": element_d,
                "acc_type": math_instruction.element_accumulator,
                "epi_type": math_instruction.element_accumulator,
            }
        ]
        tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
        tile_description = TileDescription(
            [instruction_shape[0] * tile_multiplier[0],
             instruction_shape[1] * tile_multiplier[1],
             instruction_shape[2] * 4 * tile_multiplier[2]],
            0,
            [4, 1, 1],
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )
        schedules = []
        if is_2sm:
            schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
        else:
            schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
        for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
            configs.append(config)
            operations.append(o)
    return configs, operations
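
# Example usage (a minimal sketch; assumes `configs` came from get_gemm_configs /
# get_single_gemm_config and that each config dict carries the keys consumed above,
# including 'layout_d', 'dtype_c', and 'voidC'):
#
#     manifest = Manifest()
#     cfgs, ops = generate_sm100_from_heuristics_configs(manifest, cuda_version, configs)
#     for cfg, op in zip(cfgs, ops):
#         print(op.procedural_name(), cfg['cta_tile_m'], cfg['cta_tile_n'], cfg['cta_tile_k'])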

def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
    """
    Generate CUTLASS operations based on the list of configs provided by the heuristics provider
    args:
        manifest: manifest to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
        cuda_version: CUDA compiler version for generating CUTLASS operations
        kernel_configs: list of configs generated by the heuristic
    returns:
        (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
    """
    min_cc, max_cc = 90, 90
    if manifest is None:
        # Use a dummy manifest so we can use the existing CreateGemmOperator functions
        manifest = Manifest()
    configs = []
    operations = []
    for config in kernel_configs:
        is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
        layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
        element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
        # instruction shape and warp count are unused when emitting 3.x collective builder code
        dummy_instr_shape = [0, 0, 0]
        math_instruction = MathInstruction(
            dummy_instr_shape,
            element_a, element_b, element_accumulator,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add
        )
        data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
        if is_aligned:
            layout = fix_alignments(data_types, layout, alignment_bits=128)
        dummy_warp_count = [0, 0, 0]
        tile_description = TileDescription(
            [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
            0,
            dummy_warp_count,
            math_instruction,
            min_cc,
            max_cc,
            cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
        )
        schedules, stream_k_schedules = get_valid_schedules(
            tile_description=tile_description,
            cuda_version=cuda_version,
            is_aligned=is_aligned,
            data_types=data_types,
            instantiation_level=9000,  # don't prune schedules: the heuristic gives no schedule suggestion
            layout=layout,
            gemm_kind=GemmKind.Universal3x,
            enable_fp8_fast_acc=config['use_fast_acc']
        )
        if len(schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
                configs.append(config)
                operations.append(o)
        if len(stream_k_schedules):
            for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
                                                   stream_k_schedules,
                                                   tile_schedulers=[TileSchedulerType.StreamK]):
                configs.append(config)
                operations.append(o)
    return configs, operations
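
# Usage mirrors generate_sm100_from_heuristics_configs above; for example (a sketch):
#
#     cfgs, ops = generate_sm90_from_heuristics_configs(None, cuda_version, configs)
#
# Passing None builds the operations against a throwaway manifest, which is what
# filter_manifest_and_write_heuristics_file below does when
# args.heuristics_restrict_kernels is set.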

def filter_manifest_and_write_heuristics_file(manifest, args):
    """
    Prune a manifest according to heuristics suggestions from the problems file
    args:
        manifest: CUTLASS manifest to prune
        args: generator.py args, requires:
            - args.heuristics_problems_file
            - args.heuristics_gpu
            - args.heuristics_testlist_file
    returns:
        A list of dictionaries, each of which has information about an operation and a problem from the input problems
    """
    heuristics_problems = []
    with open(args.heuristics_problems_file, 'r') as f:
        heuristics_problems = json.load(f)
    gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
    mmh = MatmulHeuristics(gpu=gpu)
    if any(('100' in arch) for arch in args.architectures.split(';')):
        mmh.set_cta_div_n(64)
    problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
    all_configs_and_operations = []
    operations = []
    for problem in problems_with_configs:
        # Default to empty so a problem matching no architecture doesn't leave these unbound
        problem_configs, problem_operations = [], []
        if any('90' in arch for arch in args.architectures.split(';')):
            problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
        if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')):
            problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
        operations += problem_operations
        problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
        with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
        all_configs_and_operations += with_problem_size
    for operation in operations:
        manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
    if not all_configs_and_operations:
        raise Exception("No valid configurations generated")
    write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
    return all_configs_and_operations
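
# Example invocation (a sketch; in practice generator.py supplies `args`, and the
# attribute names below are exactly the ones this module reads from it):
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(
#         heuristics_problems_file='problems.json',
#         heuristics_gpu='auto',
#         heuristics_testlist_file='testlist.csv',
#         heuristics_configs_per_problem=3,
#         heuristics_restrict_kernels=False,
#         architectures='90a',
#         cuda_version='12.4')
#     results = filter_manifest_and_write_heuristics_file(manifest, args)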

def write_profiler_testlist_to_csv(configs_list, outfile_path):
    """
    Write a list of configs to a testlist to be consumed by cutlass_profiler
    args:
        configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
        outfile_path: Output file path
    returns:
        None
    """
    profiler_testlist = configs_list.copy()
    for c in profiler_testlist:
        # Replace enum values with their short string names so the rows are CSV-friendly
        for k, v in c.items():
            if isinstance(v, DataType):
                c[k] = DataTypeNames[v]
            elif isinstance(v, LayoutType):
                c[k] = ShortLayoutTypeNames[v]
    with open(outfile_path, mode='w', newline='') as ofile:
        k_names = profiler_testlist[0].keys()
        writer = csv.DictWriter(ofile, fieldnames=k_names)
        writer.writeheader()
        writer.writerows(profiler_testlist)
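
# Example (a sketch): each dict becomes one CSV row, with columns taken from the
# first entry's keys, so all dicts should share the same key set:
#
#     write_profiler_testlist_to_csv(
#         [{'operation_name': 'cutlass3x_sm90_example', 'm': 8192, 'n': 1024, 'k': 4096}],
#         'testlist.csv')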