# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/deep_gemm.py
"""
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
# Imported and adapted from DeepGEMM.
import ctypes
import enum
import functools
import hashlib
import json
from typing import Any, Dict, Optional, Tuple
try:
import cuda.bindings.driver as cbd
except ImportError as e:
raise ImportError(
"Could not import the 'cuda' module. "
"Please install cuda-python that matches your CUDA version."
) from e
import torch
from .artifacts import ArtifactPath, MetaInfoHash
from .cuda_utils import checkCudaErrors
from .jit.cubin_loader import get_cubin
from .jit.env import FLASHINFER_CUBIN_DIR
from .utils import ceil_div, round_up
class GemmType(enum.Enum):
Normal = 0
GroupedContiguous = 1
GroupedMasked = 2
def __str__(self) -> str:
return {
0: "GemmType::Normal",
1: "GemmType::GroupedContiguous",
2: "GemmType::GroupedMasked",
}[self.value]
class MajorTypeAB(enum.Enum):
KMajor = 0
MNMajor = 1
def shape_direction(self):
return 1 if self.value == 0 else -1
def non_contiguous_dim(self):
return -2 if self.value == 0 else -1
def __str__(self) -> str:
return {0: "cute::UMMA::Major::K", 1: "cute::UMMA::Major::MN"}[self.value]
class MajorTypeCD(enum.Enum):
NMajor = 0
MMajor = 1
def non_contiguous_dim(self):
return -2 if self.value == 0 else -1
def major_check(t: torch.Tensor):
assert t.dim() in (2, 3)
if t.dim() == 3:
assert t.stride(0) == t.size(-2) * t.size(-1), (
"Grouped dimension cannot have abnormal stride"
)
assert t.stride(-2) == 1 or t.stride(-1) == 1
def get_major_type_ab(t: torch.Tensor):
major_check(t)
return MajorTypeAB.KMajor if t.stride(-1) == 1 else MajorTypeAB.MNMajor
def get_major_type_cd(t: torch.Tensor):
major_check(t)
return MajorTypeCD.NMajor if t.stride(-1) == 1 else MajorTypeCD.MMajor
def get_element_size(dtype: torch.dtype):
return {
torch.float8_e4m3fn: 1,
torch.bfloat16: 2,
torch.float: 4,
}[dtype]
def get_m_alignment_for_contiguous_layout():
return 128
def get_tma_aligned_size(x: int, element_size: int) -> int:
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return round_up(x, alignment)
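# Illustrative example: the TMA box stride must be 16-byte aligned, so the MN extent is
# rounded up to 16 / element_size elements:
#   get_tma_aligned_size(130, 4) == 132   # FP32 scale factors: align to 4 elements
#   get_tma_aligned_size(100, 1) == 112   # 1-byte elements: align to 16 elements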
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
# NOTES: for best performance, this function may be rewritten or fused into a CUDA kernel
assert x.dtype == torch.float and x.dim() in (2, 3)
# First, convert into UE8M0 `uint8_t`
ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
# Second, make padded packed tensors
mn, k = x.shape[-2], x.shape[-1]
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = round_up(k, 4)
padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8)
padded[:, :mn, :k] = ue8m0_tensor
padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4)
# Finally, transpose
transposed = torch.transpose(
torch.empty((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int),
1,
2,
)
transposed[:, :, :] = padded
aligned_x = transposed[:, :mn, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
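# Illustrative note (assuming the scale factors are positive powers of two; the shift below
# discards sign and mantissa, which is exactly the UE8M0 encoding):
#   torch.tensor([2.0]).view(torch.int32) >> 23  ->  tensor([128])   # bias 127 + exponent 1
# Four consecutive UE8M0 bytes along K are then packed into a single int32, and the result
# is transposed so that MN becomes the contiguous (TMA-aligned) dimension.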
def check_sf_layout(
sf: torch.Tensor,
mn: int,
k: int,
gran: Tuple[int, int],
num_groups: Optional[int],
tma_stride_check: bool = False,
type_check: Optional[torch.dtype] = None,
) -> torch.Tensor:
# Type check
if type_check is not None:
assert sf.dtype == type_check
# Always do shape checks
assert sf.dtype in (torch.float, torch.int)
assert sf.dim() == int(num_groups is not None) + 2
if num_groups is not None:
assert sf.size(-3) == num_groups
assert sf.size(-2) == ceil_div(mn, gran[0])
assert sf.size(-1) == ceil_div(k, gran[1] * (1 if sf.dtype == torch.float else 4))
# TMA stride checks: TMA aligned and MN-major
if tma_stride_check:
if num_groups is not None:
assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
assert sf.stride(-2) == 1
assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())
return sf
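# Illustrative shapes for a per-token (1, 128) granularity with mn == 4096 and k == 7168:
# an FP32 scale-factor tensor has shape (..., 4096, 56), while the packed UE8M0 (int32)
# form has shape (..., 4096, 14) because four K-blocks are packed per int32. With
# `tma_stride_check`, the MN dimension must additionally be contiguous (stride(-2) == 1)
# and the K-block stride must equal get_tma_aligned_size(4096, 4) == 4096.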
def transform_sf_into_required_layout(
sf: torch.Tensor,
mn: int,
k: int,
recipe: Tuple[int, int, int],
num_groups: Optional[int] = None,
is_sfa: bool = False,
):
gran = (recipe[0 if is_sfa else 1], recipe[2])
should_skip_transform = (
sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == "100a"
) or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == "100a")
if not should_skip_transform:
# Pre-transform checks
check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
# (FP32, 1, 128) on Hopper: transform to TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "90a":
raise NotImplementedError
# (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "100a":
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(
sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int,
)
# (FP32, 128, 128) on Hopper: no need to transform, check shape and whatever-major
if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "90a":
raise NotImplementedError
# (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "100a":
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(
sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int,
)
if should_skip_transform:
# TODO: add transpose kernel if SF layout is not satisfied
return check_sf_layout(
sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int,
)
AssertionError(f"Unknown cases: {sf.dtype=}, {gran=}, arch={get_device_arch()}")
@functools.lru_cache(maxsize=None)
def get_device_arch():
major, minor = torch.cuda.get_device_capability()
suffix = "a" if major >= 9 else ""
return f"{major * 10 + minor}{suffix}"
def hash_to_hex(s: str) -> str:
md5 = hashlib.md5()
md5.update(s.encode("utf-8"))
return md5.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def must_be_k_major() -> bool:
return {
"90a": True,
"100a": False,
}[get_device_arch()]
@functools.lru_cache(maxsize=None)
def get_default_recipe(
sfa_dtype: torch.dtype, sfb_dtype: torch.dtype
) -> Tuple[int, int, int]:
assert sfa_dtype in (torch.float, torch.int)
return {
("90a", torch.float): (1, 128, 128),
("100a", torch.float): (1, 128, 128),
("100a", torch.int): (1, 1, 128),
}[(get_device_arch(), sfb_dtype)]
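# Illustrative mapping: on SM100, FP32 scale factors default to the (1, 128, 128) recipe
# (1x128 per-token scales for A, 128x128 block scales for B), while already-packed UE8M0
# (int32) scale factors default to (1, 1, 128), i.e. 1x128 scales for both operands.
# Hopper ("90a") currently only has an FP32 entry.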
class MulticastConfig:
def __init__(self, num_multicast: int, is_multicast_on_a: bool):
self.num_multicast = num_multicast
self.is_multicast_on_a = is_multicast_on_a
def get_ab_load_block_m(self, block_m: int):
# NOTES: this is for SM100 (and newer) only
assert get_device_arch() != "90a"
return block_m // (self.num_multicast if self.is_multicast_on_a else 1)
def get_ab_load_block_n(self, block_n: int):
# NOTES: this is for SM100 (and newer) only
assert get_device_arch() != "90a"
return block_n // (1 if self.is_multicast_on_a else self.num_multicast)
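# Illustrative example: for a 2-CTA cluster that multicasts the B tile
# (MulticastConfig(2, is_multicast_on_a=False)), the per-CTA TMA load extent is halved
# along the multicast side only:
#   cfg = MulticastConfig(2, is_multicast_on_a=False)
#   cfg.get_ab_load_block_m(128)  # -> 128 (A side unchanged)
#   cfg.get_ab_load_block_n(256)  # -> 128 (B side split across the cluster)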
class SharedMemoryConfig:
def __init__(
self,
smem_size: int,
swizzle_a_mode: int,
swizzle_b_mode: int,
swizzle_cd_mode: int,
):
self.smem_size = smem_size
self.swizzle_a_mode = swizzle_a_mode
self.swizzle_b_mode = swizzle_b_mode
# NOTES: sometimes the default swizzling pattern may not be compatible (e.g., FP32 output)
self.swizzle_cd_mode = swizzle_cd_mode
# TODO: swizzle SF as well
self.swizzle_sf_mode = 0
assert self.swizzle_a_mode != 0
assert self.swizzle_b_mode != 0
assert self.swizzle_cd_mode > 16
assert self.swizzle_sf_mode == 0
def is_multicast_legal(
shape_dim: int,
block_dim: int,
num_multicast: int,
num_sms: int,
require_divisible: bool = False,
) -> bool:
divisible = (
ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible
)
return divisible and num_sms % num_multicast == 0
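# Worked example: with m == 4096, block_m == 128, a 2-CTA cluster and 148 SMs,
# ceil_div(4096, 128) == 32 tiles is divisible by 2 and 148 % 2 == 0, so the configuration
# is legal even with require_divisible=True.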
def get_swizzle_mode(block_size: int, elem_size: int) -> int:
# A return value `> 0` means interleaving is used
# 16B effectively means no swizzling (interleaving only)
for mode_bytes in (128, 64, 32, 16):
if (block_size * elem_size) % mode_bytes == 0:
return mode_bytes
AssertionError("Invalid mode")
return 0
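# Worked example: a K-major FP8 tile with block_k == 128 occupies 128 * 1 == 128 bytes per
# row, so the 128B swizzle is chosen; a BLOCK_N == 96 BF16 output row occupies 192 bytes,
# which is divisible by 64 but not by 128, so the 64B swizzle is chosen.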
def get_sf_aligned_block_sizes(block_m: int, block_n: int, ab_dtype: torch.dtype):
num_utccp_aligned_elems = 128
assert block_m % num_utccp_aligned_elems == 0
return {
torch.bfloat16: (0, 0),
torch.float8_e4m3fn: (
round_up(block_m, num_utccp_aligned_elems),
round_up(block_n, num_utccp_aligned_elems),
),
}[ab_dtype]
def is_tmem_size_legal(block_m: int, block_n: int, ab_dtype: torch.dtype):
# Tensor memory budget: 2 * BLOCK_N columns for the accumulators (M waves or epilogue stages), plus SFA and SFB
sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype)
return ((2 * block_n) + (sf_block_m // 32) + (sf_block_n // 32)) <= 512
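# Worked example (FP8 inputs, so SF blocks are padded to multiples of 128): block_m == 128
# with block_n == 256 needs 2*256 + 128/32 + 256/32 == 524 > 512 tensor-memory columns
# (illegal), whereas block_n == 224 needs 2*224 + 4 + 8 == 460 columns (legal).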
def get_smem_config(
block_m: int,
block_n: int,
block_k: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
major_d: MajorTypeCD,
ab_dtype: torch.dtype,
cd_dtype: torch.dtype,
num_stages: int,
multicast_config: MulticastConfig,
) -> SharedMemoryConfig:
assert major_d == MajorTypeCD.NMajor
ab_elem_size = get_element_size(ab_dtype)
cd_elem_size = get_element_size(cd_dtype)
load_block_m = multicast_config.get_ab_load_block_m(block_m)
load_block_n = multicast_config.get_ab_load_block_n(block_n)
swizzle_a_mode = get_swizzle_mode(
block_k if major_a == MajorTypeAB.KMajor else load_block_m, ab_elem_size
)
swizzle_b_mode = get_swizzle_mode(
block_k if major_b == MajorTypeAB.KMajor else load_block_n, ab_elem_size
)
swizzle_cd_mode = get_swizzle_mode(
block_n if major_d == MajorTypeCD.NMajor else block_m, cd_elem_size
)
# 2 stages of STSM and TMA store
# TODO: consider other layouts
layout_ad_m = 128
smem_d = min(block_m, layout_ad_m) * swizzle_cd_mode * 2
# A/B shared memory
smem_a_per_stage = load_block_m * block_k * ab_elem_size
smem_b_per_stage = load_block_n * block_k * ab_elem_size
# SF shared memory must be aligned to UTCCP
# Each stage must prefetch the next 4 stages' SF (including the current one)
sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype)
smem_scales_a_per_stage = sf_block_m * 4
smem_scales_b_per_stage = sf_block_n * 4
# TODO: remove SF barriers for BF16 GEMMs
# TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers, accumulation full barrier
# NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
# NOTES: cases without accumulation will not use the accumulation full barrier
smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8
smem_tmem_ptr = 4
# Sum them up
smem_size = 0
smem_size += smem_d
smem_size += num_stages * smem_a_per_stage
smem_size += num_stages * smem_b_per_stage
smem_size += num_stages * smem_scales_a_per_stage
smem_size += num_stages * smem_scales_b_per_stage
smem_size += smem_barrier
smem_size += smem_tmem_ptr
return SharedMemoryConfig(
smem_size, swizzle_a_mode, swizzle_b_mode, swizzle_cd_mode
)
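# Worked example (illustrative: FP8 inputs, BF16 output, K-major A/B, no multicast,
# block_m == block_n == block_k == 128, num_stages == 4):
#   swizzle_a == swizzle_b == 128B, swizzle_cd == 128B (128 * 2 output bytes per row)
#   smem_d                    = 128 * 128 * 2         = 32768
#   A + B per stage           = 2 * 128 * 128 * 1     = 32768   (x4 stages = 131072)
#   SFA + SFB per stage       = 2 * 128 * 4           = 1024    (x4 stages = 4096)
#   barriers + tmem pointer   = 4*8*3 + 2*8*2 + 8 + 4 = 140
# for a total of 168076 bytes, comfortably below the 232448-byte SM100 budget used in
# `get_best_configs`.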
@functools.lru_cache(maxsize=None)
def get_best_configs(
gemm_type: GemmType,
m: int,
n: int,
k: int,
num_groups: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
major_d: MajorTypeCD,
ab_dtype: torch.dtype,
cd_dtype: torch.dtype,
num_sms: int,
) -> Tuple[int, int, int, int, int, MulticastConfig, SharedMemoryConfig]:
assert ab_dtype == torch.float8_e4m3fn
assert cd_dtype in (torch.bfloat16, torch.float)
# `BLOCK_M` and `BLOCK_N` are selected according to MMA instructions
block_ms: Optional[Tuple[int, ...]] = None
if gemm_type == GemmType.GroupedContiguous:
block_ms = (get_m_alignment_for_contiguous_layout(),)
else:
block_ms = (128,) if major_b == MajorTypeAB.KMajor else (128, 256)
# NOTES: some `% 32 == 16` cases are not compatible with 2-CTA TMA swizzling
block_ns = (
tuple(range(16, 257, 16))
if major_b == MajorTypeAB.KMajor
else tuple(range(32, 257, 32))
)
# `BLOCK_K` is selected in a fixed manner
block_k = 128 // get_element_size(ab_dtype)
fix_wave_saturate = lambda x: num_sms if x == 0 else x
get_num_waves = lambda bm, bn: (
ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms)
if bm
else None
)
get_last_wave_util = lambda bm, bn: fix_wave_saturate(
(ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms
)
# Decide block sizes by waves
# TODO: move block size search into `common.py`
best_block_m, best_block_n = None, None
for block_m in block_ms:
for block_n in block_ns:
success = False
num_waves, best_num_waves = (
get_num_waves(block_m, block_n),
get_num_waves(best_block_m, best_block_n),
)
if (
best_block_m is None
or best_block_n is None
or num_waves < best_num_waves
):
success = True
elif num_waves == best_num_waves:
# Check last wave utilization
util = get_last_wave_util(block_m, block_n)
best_util = get_last_wave_util(best_block_m, best_block_n)
success = util > best_util
if util == best_util:
# Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n
# Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m
# Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
success |= block_m != best_block_m and block_n > best_block_n
success &= is_tmem_size_legal(block_m, block_n, ab_dtype)
best_block_m, best_block_n = (
(block_m, block_n) if success else (best_block_m, best_block_n)
)
assert best_block_m is not None and best_block_n is not None
# Decide the number of TMA multicasts and whether broadcast on A
best_multicast_config = MulticastConfig(1, True)
# Try to multicast on the larger block side first
is_legal = {
# TODO: support other `tcgen05` layouts
"A": False,
"B": is_multicast_legal(m, best_block_m, 2, num_sms, True)
and gemm_type == GemmType.Normal,
}
for i in ("A", "B") if best_block_m > best_block_n else ("B", "A"):
if m >= 512 and is_legal[i]:
best_multicast_config = MulticastConfig(2, i == "A")
break
# Always pick the largest stage count that fits in shared memory
# NOTES: for double B scales, the best number of stages may be reduced
# TODO: move stage search into `common.py`
best_num_stages, best_smem_config, sm100_capacity = None, None, 232448
stage_candidates = tuple(
filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))
)
for num_stages in stage_candidates:
best_smem_config = get_smem_config(
best_block_m,
best_block_n,
block_k,
major_a,
major_b,
major_d,
ab_dtype,
cd_dtype,
num_stages,
best_multicast_config,
)
if best_smem_config.smem_size <= sm100_capacity:
best_num_stages = num_stages
break
assert best_smem_config is not None
assert best_num_stages is not None
# Recompute the minimal number of SMs required
# NOTES: using fewer SMs means less L2 cache usage and a smaller GPU frequency drop
# TODO: move min SM fix into `common.py`
num_waves = get_num_waves(best_block_m, best_block_n)
num_min_sms = ceil_div(
ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves
)
num_min_sms = (
ceil_div(num_min_sms, best_multicast_config.num_multicast)
* best_multicast_config.num_multicast
)
assert num_min_sms <= num_sms
return (
num_min_sms,
best_block_m,
best_block_n,
block_k,
best_num_stages,
best_multicast_config,
best_smem_config,
)
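# Worked example of the final min-SM recomputation (assuming the search settled on
# block_m == 128, block_n == 224): for m == 4096, n == 4096, one group and 148 SMs there
# are 32 * 19 == 608 output tiles and ceil_div(608, 148) == 5 waves, so
# num_min_sms == ceil_div(608, 5) == 122 SMs already achieve the same wave count
# (then rounded up to a multiple of the cluster size).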
tmap_type_map: Dict[Any, str] = {
torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
swizzle_type_map = {
0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
16: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
}
def make_tma_xd_desc(
t: torch.Tensor,
gmem_dims: Tuple[cbd.cuuint64_t, ...],
gmem_strides: Tuple[cbd.cuuint64_t, ...],
smem_dims: Tuple[cbd.cuuint32_t, ...],
swizzle_type: cbd.CUtensorMapSwizzle,
) -> cbd.CUtensorMap:
num_dims = len(gmem_dims)
assert len(gmem_strides) == num_dims - 1
assert len(smem_dims) == num_dims
tensor_dtype = tmap_type_map[t.dtype]
tensor_map = checkCudaErrors(
cbd.cuTensorMapEncodeTiled(
tensor_dtype,
num_dims,
t.data_ptr(),
gmem_dims,
gmem_strides,
smem_dims,
(cbd.cuuint32_t(1),) * num_dims,
cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle_type,
cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
)
)
return tensor_map
def make_tma_2d_desc(
t: torch.Tensor,
gmem_inner_dim: int,
gmem_outer_dim: int,
smem_inner_dim: int,
smem_outer_dim: int,
gmem_outer_stride: int,
swizzle_mode: int,
) -> cbd.CUtensorMap:
# With swizzling enabled, the inner box is limited to the swizzle width, so multiple TMA copies are required per block
if swizzle_mode != 0:
assert swizzle_mode % t.element_size() == 0
smem_inner_dim = swizzle_mode // t.element_size()
gmem_dims = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
gmem_strides = (cbd.cuuint64_t(gmem_outer_stride * t.element_size()),)
smem_dims = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
return make_tma_xd_desc(
t, gmem_dims, gmem_strides, smem_dims, swizzle_type_map[swizzle_mode]
)
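# Illustrative note: with a non-zero swizzle mode the inner SMEM box is clamped to
# swizzle_mode / element_size elements, so e.g. a 128B swizzle over FP8 data gives a
# 128-element inner box; a BLOCK_K == 128 tile then needs a single TMA box, while wider
# tiles are covered by several boxes issued by the kernel.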
def make_tma_a_desc(
major_type: MajorTypeAB,
t: torch.Tensor,
shape_m: int,
shape_k: int,
block_m: int,
block_k: int,
outer_stride: int,
num_groups: int,
swizzle_mode: int,
) -> cbd.CUtensorMap:
if num_groups > 1:
assert major_type == MajorTypeAB.KMajor
gmem_inner_dim, gmem_outer_dim = (shape_k, shape_m * num_groups)[
:: major_type.shape_direction()
]
smem_inner_dim, smem_outer_dim = (block_k, block_m)[:: major_type.shape_direction()]
return make_tma_2d_desc(
t,
gmem_inner_dim,
gmem_outer_dim,
smem_inner_dim,
smem_outer_dim,
outer_stride,
swizzle_mode,
)
def make_tma_b_desc(
major_type: MajorTypeAB,
t: torch.Tensor,
shape_n: int,
shape_k: int,
block_n: int,
block_k: int,
outer_stride: int,
num_groups: int,
swizzle_mode: int,
) -> cbd.CUtensorMap:
# `num_groups` is always folded into the outer dimension
io_shapes = (shape_k, shape_n)[:: major_type.shape_direction()]
gmem_inner_dim, gmem_outer_dim = (io_shapes[0], io_shapes[1] * num_groups)
smem_inner_dim, smem_outer_dim = (block_k, block_n)[:: major_type.shape_direction()]
return make_tma_2d_desc(
t,
gmem_inner_dim,
gmem_outer_dim,
smem_inner_dim,
smem_outer_dim,
outer_stride,
swizzle_mode,
)
def make_tma_cd_desc(
major_type: MajorTypeCD,
t: torch.Tensor,
shape_m: int,
shape_n: int,
block_m: int,
block_n: int,
outer_stride: int,
num_groups: int,
swizzle_mode: int,
) -> cbd.CUtensorMap:
assert major_type == MajorTypeCD.NMajor
# Swizzling requires the inner box dim to be less than or equal to `kSwizzleCDMode`
# bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
layout_ad_m = 128
return make_tma_2d_desc(
t,
shape_n,
shape_m * num_groups,
block_n,
min(block_m, layout_ad_m),
outer_stride,
swizzle_mode,
)
def make_tma_sf_desc(
major_type: MajorTypeAB,
t: torch.Tensor,
shape_mn: int,
shape_k: int,
block_mn: int,
block_k: int,
num_groups: int,
swizzle_mode: int,
) -> cbd.CUtensorMap:
assert major_type == MajorTypeAB.MNMajor
# TODO: maybe swizzle SF as well
assert swizzle_mode == 0
# Make TMA aligned to 16 bytes
shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
return make_tma_2d_desc(
t,
shape_mn,
ceil_div(shape_k, block_k * 4) * num_groups,
block_mn,
1,
shape_mn,
swizzle_mode,
)
# Map some common Python values and dtypes to the C++ literals used for kernel instantiation
pytypes_to_ctypes = {
True: "true",
False: "false",
torch.bfloat16: "cutlass::bfloat16_t",
torch.float: "float",
}
RUNTIME_CACHE = {}
class SM100FP8GemmRuntime:
def __init__(self, path: str, symbol: str) -> None:
self.path = path
self.lib = None
self.kernel = None
self.symbol = symbol
# Store a reference to the cleanup function to avoid import issues during shutdown
self._cleanup_func = cbd.cuLibraryUnload
def __call__(self, **kwargs) -> cbd.CUresult:
# Lazily load the CUBIN and resolve the kernel symbol on first call
if self.kernel is None:
path = bytes(self.path, encoding="utf-8")
self.lib = checkCudaErrors(
cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
)
self.kernel = checkCudaErrors(
cbd.cuLibraryGetKernel(self.lib, bytes(self.symbol, encoding="utf-8"))
)
# noinspection PyArgumentList
return self.launch(self.kernel, kwargs)
def __del__(self) -> None:
if self.lib is not None:
try:
checkCudaErrors(self._cleanup_func(self.lib))
except Exception as e:
# Report but otherwise ignore errors during interpreter shutdown
print(f"Failed to delete SM100FP8GemmRuntime: {e}")
@staticmethod
def generate(kwargs: Dict[str, Any]) -> str:
assert kwargs["CD_DTYPE_T"] in (torch.bfloat16, torch.float)
code = f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{kwargs["MAJOR_A"]},
{kwargs["MAJOR_B"]},
{kwargs["M"] if "m" in kwargs["COMPILED_DIMS"] else 0},
{kwargs["N"] if "n" in kwargs["COMPILED_DIMS"] else 0},
{kwargs["K"] if "k" in kwargs["COMPILED_DIMS"] else 0},
{kwargs["BLOCK_M"]},
{kwargs["BLOCK_N"]},
{kwargs["BLOCK_K"]},
{kwargs["NUM_GROUPS"]},
{kwargs["SWIZZLE_A_MODE"]},
{kwargs["SWIZZLE_B_MODE"]},
{kwargs["SWIZZLE_CD_MODE"]},
{kwargs["NUM_STAGES"]},
{kwargs["NUM_LAST_STAGES"]},
{kwargs["NUM_NON_EPILOGUE_THREADS"]},
{kwargs["NUM_EPILOGUE_THREADS"]},
{kwargs["NUM_MULTICAST"]},
{pytypes_to_ctypes[kwargs["IS_MULTICAST_ON_A"]]},
{kwargs["GEMM_TYPE"]},
{pytypes_to_ctypes[kwargs["WITH_ACCUMULATION"]]},
{pytypes_to_ctypes[kwargs["CD_DTYPE_T"]]}
>);
}};
"""
return code
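# Illustrative note: the generated source is not compiled at runtime in this path; it only
# spells out the `sm100_fp8_gemm_1d1d_impl` template arguments and serves as the hash input
# that `load()` uses to pick the matching prebuilt cubin.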
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
checkCudaErrors(
cbd.cuKernelSetAttribute(
cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
kwargs["SMEM_SIZE"],
kernel,
cbd.CUdevice(kwargs["DEVICE_INDEX"]),
)
)
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = kwargs["NUM_MULTICAST"]
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = kwargs["NUM_SMS"]
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = (
kwargs["NUM_NON_EPILOGUE_THREADS"] + kwargs["NUM_EPILOGUE_THREADS"]
)
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = kwargs["SMEM_SIZE"]
config.hStream = kwargs["STREAM"]
arg_values = (
kwargs["GROUPED_LAYOUT"].data_ptr(),
kwargs["M"],
kwargs["N"],
kwargs["K"],
kwargs["TENSOR_MAP_A"],
kwargs["TENSOR_MAP_B"],
kwargs["TENSOR_MAP_SFA"],
kwargs["TENSOR_MAP_SFB"],
kwargs["TENSOR_MAP_C"],
kwargs["TENSOR_MAP_D"],
)
arg_types = (
ctypes.c_void_p,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
None,
None,
None,
None,
None,
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
def load_all():
for cubin_name in KERNEL_MAP:
if cubin_name in RUNTIME_CACHE:
continue
symbol, sha256 = KERNEL_MAP[cubin_name]
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
assert path.exists()
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
def load(name: str, code: str) -> SM100FP8GemmRuntime:
signature = f"{name}$${code}"
cubin_name = f"kernel.{name}.{hash_to_hex(signature)}"
if cubin_name not in KERNEL_MAP:
raise ValueError("cubin not registered")
if cubin_name in RUNTIME_CACHE:
return RUNTIME_CACHE[cubin_name]
symbol, sha256 = KERNEL_MAP[cubin_name]
get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
assert path.exists()
RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
return RUNTIME_CACHE[cubin_name]
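# Illustrative note: kernels are resolved by a deterministic cache key rather than JIT
# compilation. `load("fp8_m_grouped_gemm", code)` looks up
# f"kernel.fp8_m_grouped_gemm.{hash_to_hex('fp8_m_grouped_gemm$$' + code)}", which must
# already be registered in KERNEL_MAP so that the matching prebuilt cubin can be fetched
# and cached.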
def m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen(
m: int,
n: int,
k: int,
aligned_k: int,
num_groups: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
major_d: MajorTypeCD,
compiled_dims: str,
output_dtype: torch.dtype,
):
num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = (
get_best_configs(
GemmType.GroupedContiguous,
m,
n,
k,
num_groups,
major_a,
major_b,
major_d,
torch.float8_e4m3fn,
output_dtype,
num_sms,
)
)
kwargs = {
# Templated or runtime arguments according to the `COMPILED_DIMS`
"COMPILED_DIMS": compiled_dims,
"M": m,
"N": n,
"K": aligned_k,
# Templated arguments
"GEMM_TYPE": GemmType.GroupedContiguous,
"NUM_NON_EPILOGUE_THREADS": 128,
"NUM_EPILOGUE_THREADS": 128,
"MAJOR_A": major_a,
"MAJOR_B": major_b,
"NUM_GROUPS": num_groups,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"NUM_STAGES": num_stages,
"NUM_LAST_STAGES": ceil_div(k, block_k) % num_stages,
"SWIZZLE_A_MODE": smem_config.swizzle_a_mode,
"SWIZZLE_B_MODE": smem_config.swizzle_b_mode,
"SWIZZLE_CD_MODE": smem_config.swizzle_cd_mode,
"NUM_MULTICAST": multicast_config.num_multicast,
"IS_MULTICAST_ON_A": multicast_config.is_multicast_on_a,
"WITH_ACCUMULATION": False,
"CD_DTYPE_T": output_dtype,
}
return (
num_sms,
block_m,
block_n,
block_k,
num_stages,
multicast_config,
smem_config,
), kwargs
def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen(
a: torch.Tensor,
sfa: torch.Tensor,
b: torch.Tensor,
sfb: torch.Tensor,
d: torch.Tensor,
m_indices: torch.Tensor,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
compiled_dims: str,
):
m, k = a.shape
num_groups, n, _ = b.shape
major_d = MajorTypeCD.NMajor
# K must be aligned to 128
aligned_k = round_up(k, 128)
(
(
num_sms,
block_m,
block_n,
block_k,
num_stages,
multicast_config,
smem_config,
),
static_kwargs,
) = m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen(
m,
n,
k,
aligned_k,
num_groups,
major_a,
major_b,
major_d,
compiled_dims,
d.dtype,
)
# NOTES: groups cannot be distinguished for A, SFA, and D, which share one contiguous layout along M
tensor_map_a = make_tma_a_desc(
major_a,
a,
m,
k,
multicast_config.get_ab_load_block_m(block_m),
block_k,
a.stride(major_a.non_contiguous_dim()),
num_groups=1,
swizzle_mode=smem_config.swizzle_a_mode,
)
tensor_map_b = make_tma_b_desc(
major_b,
b,
n,
k,
multicast_config.get_ab_load_block_n(block_n),
block_k,
b.stride(major_b.non_contiguous_dim()),
num_groups=num_groups,
swizzle_mode=smem_config.swizzle_b_mode,
)
tensor_map_d = make_tma_cd_desc(
major_d,
d,
m,
n,
block_m,
block_n,
d.stride(major_d.non_contiguous_dim()),
num_groups=1,
swizzle_mode=smem_config.swizzle_cd_mode,
)
tensor_map_sfa = make_tma_sf_desc(
MajorTypeAB.MNMajor,
sfa,
m,
k,
block_m,
block_k,
num_groups=1,
swizzle_mode=smem_config.swizzle_sf_mode,
)
tensor_map_sfb = make_tma_sf_desc(
MajorTypeAB.MNMajor,
sfb,
n,
k,
block_n,
block_k,
num_groups=num_groups,
swizzle_mode=smem_config.swizzle_sf_mode,
)
all_kwargs = {
**static_kwargs,
# Runtime arguments
"GROUPED_LAYOUT": m_indices,
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config.smem_size,
"TENSOR_MAP_A": tensor_map_a,
"TENSOR_MAP_B": tensor_map_b,
"TENSOR_MAP_SFA": tensor_map_sfa,
"TENSOR_MAP_SFB": tensor_map_sfb,
"TENSOR_MAP_C": tensor_map_d,
"TENSOR_MAP_D": tensor_map_d,
"STREAM": torch.cuda.current_stream().cuda_stream,
"DEVICE_INDEX": d.device.index,
}
return static_kwargs, all_kwargs
def m_grouped_fp8_gemm_nt_contiguous_sm100(
a: torch.Tensor,
sfa: torch.Tensor,
b: torch.Tensor,
sfb: torch.Tensor,
d: torch.Tensor,
m_indices: torch.Tensor,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
compiled_dims: str,
) -> None:
static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_contiguous_kwargs_gen(
a, sfa, b, sfb, d, m_indices, major_a, major_b, compiled_dims
)
# Generate, build and run the kernel
code = SM100FP8GemmRuntime.generate(static_kwargs)
runtime = load("fp8_m_grouped_gemm", code)
runtime(**all_kwargs)
def m_grouped_fp8_gemm_nt_masked_static_kwargs_gen(
m: int,
n: int,
k: int,
expected_m: int,
aligned_k: int,
num_groups: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
major_d: MajorTypeCD,
compiled_dims: str,
output_dtype: torch.dtype,
):
num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = (
get_best_configs(
GemmType.GroupedMasked,
expected_m,
n,
k,
num_groups,
major_a,
major_b,
major_d,
torch.float8_e4m3fn,
output_dtype,
num_sms,
)
)
if num_groups > 1:
assert m % block_m == 0
kwargs = {
# Templated or runtime arguments according to the `COMPILED_DIMS`
"COMPILED_DIMS": compiled_dims,
"M": m,
"N": n,
"K": aligned_k,
# Templated arguments
"GEMM_TYPE": GemmType.GroupedMasked,
"NUM_NON_EPILOGUE_THREADS": 128,
"NUM_EPILOGUE_THREADS": 128,
"MAJOR_A": major_a,
"MAJOR_B": major_b,
"NUM_GROUPS": num_groups,
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"NUM_STAGES": num_stages,
"NUM_LAST_STAGES": ceil_div(k, block_k) % num_stages,
"SWIZZLE_A_MODE": smem_config.swizzle_a_mode,
"SWIZZLE_B_MODE": smem_config.swizzle_b_mode,
"SWIZZLE_CD_MODE": smem_config.swizzle_cd_mode,
"NUM_MULTICAST": multicast_config.num_multicast,
"IS_MULTICAST_ON_A": multicast_config.is_multicast_on_a,
"WITH_ACCUMULATION": False,
"CD_DTYPE_T": output_dtype,
}
return (
num_sms,
block_m,
block_n,
block_k,
num_stages,
multicast_config,
smem_config,
), kwargs
def m_grouped_fp8_gemm_nt_masked_kwargs_gen(
a: torch.Tensor,
sfa: torch.Tensor,
b: torch.Tensor,
sfb: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
compiled_dims: str,
):
num_groups, m, k = a.shape
_, n, _ = b.shape
major_d = MajorTypeCD.NMajor
# K must be aligned to 128
aligned_k = round_up(k, 128)
(
(
num_sms,
block_m,
block_n,
block_k,
num_stages,
multicast_config,
smem_config,
),
static_kwargs,
) = m_grouped_fp8_gemm_nt_masked_static_kwargs_gen(
m,
n,
k,
expected_m,
aligned_k,
num_groups,
major_a,
major_b,
major_d,
compiled_dims,
d.dtype,
)
tensor_map_a = make_tma_a_desc(
major_a,
a,
m,
k,
multicast_config.get_ab_load_block_m(block_m),
block_k,
a.stride(major_a.non_contiguous_dim()),
num_groups,
smem_config.swizzle_a_mode,
)
tensor_map_b = make_tma_b_desc(
major_b,
b,
n,
k,
multicast_config.get_ab_load_block_n(block_n),
block_k,
b.stride(major_b.non_contiguous_dim()),
num_groups,
smem_config.swizzle_b_mode,
)
tensor_map_d = make_tma_cd_desc(
major_d,
d,
m,
n,
block_m,
block_n,
d.stride(major_d.non_contiguous_dim()),
num_groups,
smem_config.swizzle_cd_mode,
)
tensor_map_sfa = make_tma_sf_desc(
MajorTypeAB.MNMajor,
sfa,
m,
k,
block_m,
block_k,
num_groups,
smem_config.swizzle_sf_mode,
)
tensor_map_sfb = make_tma_sf_desc(
MajorTypeAB.MNMajor,
sfb,
n,
k,
block_n,
block_k,
num_groups,
smem_config.swizzle_sf_mode,
)
all_kwargs = {
**static_kwargs,
# Runtime arguments
"GROUPED_LAYOUT": masked_m,
"NUM_SMS": num_sms,
"SMEM_SIZE": smem_config.smem_size,
"TENSOR_MAP_A": tensor_map_a,
"TENSOR_MAP_B": tensor_map_b,
"TENSOR_MAP_SFA": tensor_map_sfa,
"TENSOR_MAP_SFB": tensor_map_sfb,
"TENSOR_MAP_C": tensor_map_d,
"TENSOR_MAP_D": tensor_map_d,
"STREAM": torch.cuda.current_stream().cuda_stream,
"DEVICE_INDEX": d.device.index,
}
return static_kwargs, all_kwargs
def m_grouped_fp8_gemm_nt_masked_sm100(
a: torch.Tensor,
sfa: torch.Tensor,
b: torch.Tensor,
sfb: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
major_a: MajorTypeAB,
major_b: MajorTypeAB,
compiled_dims: str,
) -> None:
static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_masked_kwargs_gen(
a, sfa, b, sfb, d, masked_m, expected_m, major_a, major_b, compiled_dims
)
# Generate, build and run the kernel
code = SM100FP8GemmRuntime.generate(static_kwargs)
runtime = load("fp8_m_grouped_gemm", code)
runtime(**all_kwargs)
def m_grouped_fp8_gemm_nt_contiguous(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
d: torch.Tensor,
m_indices: torch.Tensor,
recipe: Optional[Tuple[int, int, int]] = None,
compiled_dims: str = "nk",
) -> None:
# Compiled dims may be given in uppercase
compiled_dims = compiled_dims.lower()
# NOTES: shape must be `[M, K] @ [G, N, K].mT`
major_a = get_major_type_ab(a_fp8[0])
major_b = get_major_type_ab(b_fp8[0])
assert major_a == MajorTypeAB.KMajor
if must_be_k_major():
assert major_b == MajorTypeAB.KMajor
assert m_indices.is_contiguous()
a, sfa = a_fp8
b, sfb = b_fp8
m, k = a.shape
num_groups, n, k_ = b.shape
m_, n_ = d.shape
m__ = m_indices.numel()
# Type and shape checks
assert m == m_ == m__ and n == n_ and k == k_
assert n > 0 and k > 0 and num_groups > 0
assert a.dtype == torch.float8_e4m3fn
assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert m_indices.dtype == torch.int32
# D must be N-major
assert get_major_type_cd(d) == MajorTypeCD.NMajor
# Do nothing if the problem is empty
if m == 0:
return
# Transform SFA and SFB into compute-required layout
recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
sfa = transform_sf_into_required_layout(sfa, mn=m, k=k, recipe=recipe, is_sfa=True)
sfb = transform_sf_into_required_layout(
sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False
)
impl = {
"100a": functools.partial(
m_grouped_fp8_gemm_nt_contiguous_sm100,
major_a=major_a,
major_b=major_b,
compiled_dims=compiled_dims,
)
}[get_device_arch()]
impl(a, sfa, b, sfb, d, m_indices)
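# Usage sketch (illustrative, assuming an SM100-class GPU and access to the prebuilt
# DeepGEMM cubins). Shapes follow the checks above, i.e. `[M, K] @ [G, N, K].mT` with
# per-token FP32 scales for A and 128x128-block FP32 scales for B:
#   m, n, k, num_groups = 4096, 4096, 7168, 4
#   a   = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
#   sfa = torch.ones(m, k // 128, device="cuda", dtype=torch.float)
#   b   = torch.randn(num_groups, n, k, device="cuda").to(torch.float8_e4m3fn)
#   sfb = torch.ones(num_groups, n // 128, k // 128, device="cuda", dtype=torch.float)
#   d   = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)
#   m_indices = torch.zeros(m, device="cuda", dtype=torch.int32)  # group id per row (all 0 here)
#   m_grouped_fp8_gemm_nt_contiguous((a, sfa), (b, sfb), d, m_indices)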
def m_grouped_fp8_gemm_nt_masked(
a_fp8: Tuple[torch.Tensor, torch.Tensor],
b_fp8: Tuple[torch.Tensor, torch.Tensor],
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
recipe: Optional[Tuple[int, int, int]] = None,
compiled_dims: str = "nk",
) -> None:
# Compiled dims may be given in uppercase
compiled_dims = compiled_dims.lower()
# NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
major_a = get_major_type_ab(a_fp8[0])
major_b = get_major_type_ab(b_fp8[0])
assert major_a == major_b == MajorTypeAB.KMajor
assert masked_m.is_contiguous()
a, sfa = a_fp8
b, sfb = b_fp8
num_groups, m, k = a.shape
num_groups_, n, k_ = b.shape
num_groups__, m_, n_ = d.shape
num_groups___ = masked_m.numel()
# Type and shape checks
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert a.dtype == torch.float8_e4m3fn
assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert masked_m.dtype == torch.int32
# D must be N-major
assert get_major_type_cd(d) == MajorTypeCD.NMajor
# Transform SFA and SFB into compute-required layout
recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
sfa = transform_sf_into_required_layout(
sfa, mn=m, k=k, recipe=recipe, num_groups=num_groups, is_sfa=True
)
sfb = transform_sf_into_required_layout(
sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False
)
impl = {
"100a": functools.partial(
m_grouped_fp8_gemm_nt_masked_sm100,
major_a=major_a,
major_b=major_b,
compiled_dims=compiled_dims,
)
}[get_device_arch()]
impl(a, sfa, b, sfb, d, masked_m, expected_m)
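# Usage sketch (illustrative; same device and cubin requirements as above): every tensor
# carries the group dimension and `masked_m[g]` gives the number of valid rows in group `g`:
#   a   : [G, M, K]  float8_e4m3fn      sfa : [G, M, K // 128]        float32
#   b   : [G, N, K]  float8_e4m3fn      sfb : [G, N // 128, K // 128]  float32
#   d   : [G, M, N]  bfloat16           masked_m : [G]                 int32
#   m_grouped_fp8_gemm_nt_masked((a, sfa), (b, sfb), d, masked_m, expected_m)
# where `expected_m` is a hint for the typical number of valid rows per group, used only
# for kernel config selection.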
class KernelMap:
def __init__(self, sha256: str):
self.sha256 = sha256
self.indice = None
def init_indices(self):
indice_path = ArtifactPath.DEEPGEMM + "kernel_map"
assert get_cubin(indice_path, self.sha256, file_extension=".json"), (
"cubin kernel map file not found, nor downloaded with matched sha256"
)
path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json"
assert path.exists()
with open(path, "r") as f:
self.indice = json.load(f)
def __iter__(self):
if self.indice is None:
self.init_indices()
for name in self.indice:
yield name
def __getitem__(self, key):
if self.indice is None:
self.init_indices()
return self.indice[key]
KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM)