""" MIT License Copyright (c) 2025 DeepSeek Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ # Imported and adapted from DeepGEMM. import ctypes import enum import functools import hashlib import json from typing import Any, Dict, Optional, Tuple try: import cuda.bindings.driver as cbd except ImportError as e: raise ImportError( "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." ) from e import torch from .artifacts import ArtifactPath, MetaInfoHash from .cuda_utils import checkCudaErrors from .jit.cubin_loader import get_cubin from .jit.env import FLASHINFER_CUBIN_DIR from .utils import ceil_div, round_up class GemmType(enum.Enum): Normal = 0 GroupedContiguous = 1 GroupedMasked = 2 def __str__(self) -> str: return { 0: "GemmType::Normal", 1: "GemmType::GroupedContiguous", 2: "GemmType::GroupedMasked", }[self.value] class MajorTypeAB(enum.Enum): KMajor = 0 MNMajor = 1 def shape_direction(self): return 1 if self.value == 0 else -1 def non_contiguous_dim(self): return -2 if self.value == 0 else -1 def __str__(self) -> str: return {0: "cute::UMMA::Major::K", 1: "cute::UMMA::Major::MN"}[self.value] class MajorTypeCD(enum.Enum): NMajor = 0 MMajor = 1 def non_contiguous_dim(self): return -2 if self.value == 0 else -1 def major_check(t: torch.Tensor): assert t.dim() in (2, 3) if t.dim() == 3: assert t.stride(0) == t.size(-2) * t.size(-1), ( "Grouped dimension cannot have abnormal stride" ) assert t.stride(-2) == 1 or t.stride(-1) == 1 def get_major_type_ab(t: torch.Tensor): major_check(t) return MajorTypeAB.KMajor if t.stride(-1) == 1 else MajorTypeAB.MNMajor def get_major_type_cd(t: torch.Tensor): major_check(t) return MajorTypeCD.NMajor if t.stride(-1) == 1 else MajorTypeCD.MMajor def get_element_size(dtype: torch.dtype): return { torch.float8_e4m3fn: 1, torch.bfloat16: 2, torch.float: 4, }[dtype] def get_m_alignment_for_contiguous_layout(): return 128 def get_tma_aligned_size(x: int, element_size: int) -> int: tma_alignment_bytes = 16 assert tma_alignment_bytes % element_size == 0 alignment = tma_alignment_bytes // element_size return round_up(x, alignment) def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor: # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA assert x.dtype == torch.float and x.dim() in (2, 3) # First, convert into UE8M0 `uint8_t` ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) # Second, make padded packed tensors mn, k = x.shape[-2], x.shape[-1] remove_dim = False if x.dim() == 2: x, remove_dim = 
    b = x.shape[0]

    aligned_mn = get_tma_aligned_size(mn, 4)
    aligned_k = round_up(k, 4)
    padded = torch.zeros(
        (b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8
    )
    padded[:, :mn, :k] = ue8m0_tensor
    padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4)

    # Finally, transpose
    transposed = torch.transpose(
        torch.empty((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int),
        1,
        2,
    )
    transposed[:, :, :] = padded
    aligned_x = transposed[:, :mn, :]
    return aligned_x.squeeze(0) if remove_dim else aligned_x


def check_sf_layout(
    sf: torch.Tensor,
    mn: int,
    k: int,
    gran: Tuple[int, int],
    num_groups: Optional[int],
    tma_stride_check: bool = False,
    type_check: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # Type check
    if type_check is not None:
        assert sf.dtype == type_check

    # Always do shape checks
    assert sf.dtype in (torch.float, torch.int)
    assert sf.dim() == int(num_groups is not None) + 2
    if num_groups is not None:
        assert sf.size(-3) == num_groups
    assert sf.size(-2) == ceil_div(mn, gran[0])
    assert sf.size(-1) == ceil_div(k, gran[1] * (1 if sf.dtype == torch.float else 4))

    # TMA stride checks: TMA aligned and MN-major
    if tma_stride_check:
        if num_groups is not None:
            assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
        assert sf.stride(-2) == 1
        assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())

    return sf


def transform_sf_into_required_layout(
    sf: torch.Tensor,
    mn: int,
    k: int,
    recipe: Tuple[int, int, int],
    num_groups: Optional[int] = None,
    is_sfa: bool = False,
):
    gran = (recipe[0 if is_sfa else 1], recipe[2])

    should_skip_transform = (
        sf.dtype == torch.int and gran == (1, 128) and get_device_arch() == "100a"
    ) or (sf.dtype == torch.int and gran == (128, 128) and get_device_arch() == "100a")

    if not should_skip_transform:
        # Pre-transform checks
        check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)

    # (FP32, 1, 128) on Hopper: transform to TMA-aligned and MN-major
    if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "90a":
        raise NotImplementedError

    # (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
    if sf.dtype == torch.float and gran == (1, 128) and get_device_arch() == "100a":
        sf = get_col_major_tma_aligned_packed_tensor(sf)
        return check_sf_layout(
            sf,
            mn=mn,
            k=k,
            gran=(1, 128),
            num_groups=num_groups,
            tma_stride_check=True,
            type_check=torch.int,
        )

    # (FP32, 128, 128) on Hopper: no need to transform, check shape and whatever-major
    if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "90a":
        raise NotImplementedError

    # (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
    if sf.dtype == torch.float and gran == (128, 128) and get_device_arch() == "100a":
        sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
        sf = get_col_major_tma_aligned_packed_tensor(sf)
        return check_sf_layout(
            sf,
            mn=mn,
            k=k,
            gran=(1, 128),
            num_groups=num_groups,
            tma_stride_check=True,
            type_check=torch.int,
        )

    if should_skip_transform:
        # TODO: add transpose kernel if SF layout is not satisfied
        return check_sf_layout(
            sf,
            mn=mn,
            k=k,
            gran=(1, 128),
            num_groups=num_groups,
            tma_stride_check=True,
            type_check=torch.int,
        )

    raise AssertionError(
        f"Unknown cases: {sf.dtype=}, {gran=}, arch={get_device_arch()}"
    )


@functools.lru_cache(maxsize=None)
def get_device_arch():
    major, minor = torch.cuda.get_device_capability()
    suffix = "a" if major >= 9 else ""
    return f"{major * 10 + minor}{suffix}"
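
# Example return values: "90a" on Hopper (SM90, e.g. H100/H200) and "100a" on
# SM100-class Blackwell parts (e.g. B200); only "100a" is dispatched to the
# SM100 FP8 kernels defined below.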


def hash_to_hex(s: str) -> str:
    md5 = hashlib.md5()
    md5.update(s.encode("utf-8"))
    return md5.hexdigest()[0:12]


@functools.lru_cache(maxsize=None)
def must_be_k_major() -> bool:
    return {
        "90a": True,
        "100a": False,
    }[get_device_arch()]


@functools.lru_cache(maxsize=None)
def get_default_recipe(
    sfa_dtype: torch.dtype, sfb_dtype: torch.dtype
) -> Tuple[int, int, int]:
    assert sfa_dtype in (torch.float, torch.int)
    return {
        ("90a", torch.float): (1, 128, 128),
        ("100a", torch.float): (1, 128, 128),
        ("100a", torch.int): (1, 1, 128),
    }[(get_device_arch(), sfb_dtype)]


class MulticastConfig:
    def __init__(self, num_multicast: int, is_multicast_on_a: bool):
        self.num_multicast = num_multicast
        self.is_multicast_on_a = is_multicast_on_a

    def get_ab_load_block_m(self, block_m: int):
        # NOTES: this is for >= SM100 only
        assert get_device_arch() != "90a"
        return block_m // (self.num_multicast if self.is_multicast_on_a else 1)

    def get_ab_load_block_n(self, block_n: int):
        # NOTES: this is for >= SM100 only
        assert get_device_arch() != "90a"
        return block_n // (1 if self.is_multicast_on_a else self.num_multicast)


class SharedMemoryConfig:
    def __init__(
        self,
        smem_size: int,
        swizzle_a_mode: int,
        swizzle_b_mode: int,
        swizzle_cd_mode: int,
    ):
        self.smem_size = smem_size
        self.swizzle_a_mode = swizzle_a_mode
        self.swizzle_b_mode = swizzle_b_mode
        # NOTES: sometimes the default swizzling pattern may not be compatible (e.g., FP32 output)
        self.swizzle_cd_mode = swizzle_cd_mode
        # TODO: swizzle SF as well
        self.swizzle_sf_mode = 0

        assert self.swizzle_a_mode != 0
        assert self.swizzle_b_mode != 0
        assert self.swizzle_cd_mode > 16
        assert self.swizzle_sf_mode == 0


def is_multicast_legal(
    shape_dim: int,
    block_dim: int,
    num_multicast: int,
    num_sms: int,
    require_divisible: bool = False,
) -> bool:
    divisible = (
        ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible
    )
    return divisible and num_sms % num_multicast == 0


def get_swizzle_mode(block_size: int, elem_size: int) -> int:
    # `> 0` means interleaving
    # 16B actually means non-swizzling (but interleaving)
    for mode_bytes in (128, 64, 32, 16):
        if (block_size * elem_size) % mode_bytes == 0:
            return mode_bytes
    raise AssertionError("Invalid mode")


def get_sf_aligned_block_sizes(block_m: int, block_n: int, ab_dtype: torch.dtype):
    num_utccp_aligned_elems = 128
    assert block_m % num_utccp_aligned_elems == 0
    return {
        torch.bfloat16: (0, 0),
        torch.float8_e4m3fn: (
            round_up(block_m, num_utccp_aligned_elems),
            round_up(block_n, num_utccp_aligned_elems),
        ),
    }[ab_dtype]


def is_tmem_size_legal(block_m: int, block_n: int, ab_dtype: torch.dtype):
    # M waves or epilogue stages (* 2), SFA and SFB
    sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype)
    return ((2 * block_n) + (sf_block_m // 32) + (sf_block_n // 32)) <= 512


def get_smem_config(
    block_m: int,
    block_n: int,
    block_k: int,
    major_a: MajorTypeAB,
    major_b: MajorTypeAB,
    major_d: MajorTypeCD,
    ab_dtype: torch.dtype,
    cd_dtype: torch.dtype,
    num_stages: int,
    multicast_config: MulticastConfig,
) -> SharedMemoryConfig:
    assert major_d == MajorTypeCD.NMajor

    ab_elem_size = get_element_size(ab_dtype)
    cd_elem_size = get_element_size(cd_dtype)

    load_block_m = multicast_config.get_ab_load_block_m(block_m)
    load_block_n = multicast_config.get_ab_load_block_n(block_n)
    swizzle_a_mode = get_swizzle_mode(
        block_k if major_a == MajorTypeAB.KMajor else load_block_m, ab_elem_size
    )
    swizzle_b_mode = get_swizzle_mode(
        block_k if major_b == MajorTypeAB.KMajor else load_block_n, ab_elem_size
    )
    swizzle_cd_mode = get_swizzle_mode(
        block_n if major_d == MajorTypeCD.NMajor else block_m, cd_elem_size
    )
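
    # Everything below adds up the per-CTA shared-memory budget: the D output
    # tile, `num_stages` pipelined A/B tiles, the per-stage packed scale
    # factors, and the barrier / tensor-memory-pointer bookkeeping. The total
    # must fit within the SM100 capacity checked in `get_best_configs`
    # (232448 bytes).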
    # 2 stages of STSM and TMA store
    # TODO: consider other layouts
    layout_ad_m = 128
    smem_d = min(block_m, layout_ad_m) * swizzle_cd_mode * 2

    # A/B shared memory
    smem_a_per_stage = load_block_m * block_k * ab_elem_size
    smem_b_per_stage = load_block_n * block_k * ab_elem_size

    # SF shared memory must be aligned to UTCCP
    # Each stage must prefetch next 4 stages' SF (including the current)
    sf_block_m, sf_block_n = get_sf_aligned_block_sizes(block_m, block_n, ab_dtype)
    smem_scales_a_per_stage = sf_block_m * 4
    smem_scales_b_per_stage = sf_block_n * 4

    # TODO: remove SF barriers for BF16 GEMMs
    # TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers, accumulation full barrier
    # NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
    # NOTES: cases without accumulation will not use the accumulation full barrier
    smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8
    smem_tmem_ptr = 4

    # Sum them up
    smem_size = 0
    smem_size += smem_d
    smem_size += num_stages * smem_a_per_stage
    smem_size += num_stages * smem_b_per_stage
    smem_size += num_stages * smem_scales_a_per_stage
    smem_size += num_stages * smem_scales_b_per_stage
    smem_size += smem_barrier
    smem_size += smem_tmem_ptr

    return SharedMemoryConfig(
        smem_size, swizzle_a_mode, swizzle_b_mode, swizzle_cd_mode
    )


@functools.lru_cache(maxsize=None)
def get_best_configs(
    gemm_type: GemmType,
    m: int,
    n: int,
    k: int,
    num_groups: int,
    major_a: MajorTypeAB,
    major_b: MajorTypeAB,
    major_d: MajorTypeCD,
    ab_dtype: torch.dtype,
    cd_dtype: torch.dtype,
    num_sms: int,
) -> Tuple[int, int, int, int, int, MulticastConfig, SharedMemoryConfig]:
    assert ab_dtype == torch.float8_e4m3fn
    assert cd_dtype in (torch.bfloat16, torch.float)

    # `BLOCK_M` and `BLOCK_N` are selected according to MMA instructions
    block_ms: Optional[Tuple[int, ...]] = None
    if gemm_type == GemmType.GroupedContiguous:
        block_ms = (get_m_alignment_for_contiguous_layout(),)
    else:
        block_ms = (128,) if major_b == MajorTypeAB.KMajor else (128, 256)
    # NOTES: some `% 32 == 16` cases are not compatible with 2-CTA TMA swizzling
    block_ns = (
        tuple(range(16, 257, 16))
        if major_b == MajorTypeAB.KMajor
        else tuple(range(32, 257, 32))
    )

    # `BLOCK_K` is selected in a fixed manner
    block_k = 128 // get_element_size(ab_dtype)

    fix_wave_saturate = lambda x: num_sms if x == 0 else x
    get_num_waves = lambda bm, bn: (
        ceil_div(ceil_div(m, bm) * ceil_div(n, bn) * num_groups, num_sms)
        if bm
        else None
    )
    get_last_wave_util = lambda bm, bn: fix_wave_saturate(
        (ceil_div(m, bm) * ceil_div(n, bn) * num_groups) % num_sms
    )

    # Decide block sizes by waves
    # TODO: move block size search into `common.py`
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        for block_n in block_ns:
            success = False
            num_waves, best_num_waves = (
                get_num_waves(block_m, block_n),
                get_num_waves(best_block_m, best_block_n),
            )
            if (
                best_block_m is None
                or best_block_n is None
                or num_waves < best_num_waves
            ):
                success = True
            elif num_waves == best_num_waves:
                # Check last wave utilization
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                success = util > best_util
                if util == best_util:
                    # Case 1: same `block_m`, smaller `block_n` (wasted)
                    success |= block_m == best_block_m and block_n < best_block_n
                    # Case 2: same `block_n`, smaller `block_m` (wasted)
                    success |= block_n == best_block_n and block_m < best_block_m
                    # Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
                    success |= block_m != best_block_m and block_n > best_block_n
            success &= is_tmem_size_legal(block_m, block_n, ab_dtype)
            best_block_m, best_block_n = (
                (block_m, block_n) if success else (best_block_m, best_block_n)
            )
    assert best_block_m is not None and best_block_n is not None

    # Decide the number of TMA multicasts and whether to broadcast on A
    best_multicast_config = MulticastConfig(1, True)

    # Try to multicast on the larger block side first
    is_legal = {
        # TODO: support other `tcgen05` layouts
        "A": False,
        "B": is_multicast_legal(m, best_block_m, 2, num_sms, True)
        and gemm_type == GemmType.Normal,
    }
    for i in ("A", "B") if best_block_m > best_block_n else ("B", "A"):
        if m >= 512 and is_legal[i]:
            best_multicast_config = MulticastConfig(2, i == "A")
            break

    # Always pick the longest one
    # NOTES: for double B scales, the best number of stages may be reduced
    # TODO: move stage search into `common.py`
    best_num_stages, best_smem_config, sm100_capacity = None, None, 232448
    stage_candidates = tuple(
        filter(lambda s: s <= max(k // 128, 1), (8, 7, 6, 5, 4, 3, 2, 1))
    )
    for num_stages in stage_candidates:
        best_smem_config = get_smem_config(
            best_block_m,
            best_block_n,
            block_k,
            major_a,
            major_b,
            major_d,
            ab_dtype,
            cd_dtype,
            num_stages,
            best_multicast_config,
        )
        if best_smem_config.smem_size <= sm100_capacity:
            best_num_stages = num_stages
            break
    assert best_smem_config is not None
    assert best_num_stages is not None

    # Recompute the minimal number of SMs required
    # NOTES: less L2 cache usage and less GPU frequency drop
    # TODO: move min SM fix into `common.py`
    num_waves = get_num_waves(best_block_m, best_block_n)
    num_min_sms = ceil_div(
        ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves
    )
    num_min_sms = (
        ceil_div(num_min_sms, best_multicast_config.num_multicast)
        * best_multicast_config.num_multicast
    )
    assert num_min_sms <= num_sms

    return (
        num_min_sms,
        best_block_m,
        best_block_n,
        block_k,
        best_num_stages,
        best_multicast_config,
        best_smem_config,
    )
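
# Note on the wave heuristic in `get_best_configs`: with m = n = 4096,
# num_groups = 1 and num_sms = 148, a 128x256 tile yields
# ceil_div(32 * 16, 148) = 4 waves, so the search prefers candidates that keep
# the wave count low and the last wave well utilized.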

tmap_type_map: Dict[Any, Any] = {
    torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
    torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
    torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
    torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
    torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
    torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
    torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}

swizzle_type_map = {
    0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
    16: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
    32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
    64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
    128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
}


def make_tma_xd_desc(
    t: torch.Tensor,
    gmem_dims: Tuple[cbd.cuuint64_t, ...],
    gmem_strides: Tuple[cbd.cuuint64_t, ...],
    smem_dims: Tuple[cbd.cuuint32_t, ...],
    swizzle_type: cbd.CUtensorMapSwizzle,
) -> cbd.CUtensorMap:
    num_dims = len(gmem_dims)
    assert len(gmem_strides) == num_dims - 1
    assert len(smem_dims) == num_dims

    tensor_dtype = tmap_type_map[t.dtype]
    tensor_map = checkCudaErrors(
        cbd.cuTensorMapEncodeTiled(
            tensor_dtype,
            num_dims,
            t.data_ptr(),
            gmem_dims,
            gmem_strides,
            smem_dims,
            (cbd.cuuint32_t(1),) * num_dims,
            cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
            swizzle_type,
            cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
            cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
        )
    )
    return tensor_map


def make_tma_2d_desc(
    t: torch.Tensor,
    gmem_inner_dim: int,
    gmem_outer_dim: int,
    smem_inner_dim: int,
    smem_outer_dim: int,
    gmem_outer_stride: int,
    swizzle_mode: int,
) -> cbd.CUtensorMap:
    # For swizzling pattern, multiple TMAs are required
    if swizzle_mode != 0:
        assert swizzle_mode % t.element_size() == 0
        smem_inner_dim = swizzle_mode // t.element_size()

    gmem_dims = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
    gmem_strides = (cbd.cuuint64_t(gmem_outer_stride * t.element_size()),)
    smem_dims = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
    return make_tma_xd_desc(
        t, gmem_dims, gmem_strides, smem_dims, swizzle_type_map[swizzle_mode]
    )


def make_tma_a_desc(
    major_type: MajorTypeAB,
    t: torch.Tensor,
    shape_m: int,
    shape_k: int,
    block_m: int,
    block_k: int,
    outer_stride: int,
    num_groups: int,
    swizzle_mode: int,
) -> cbd.CUtensorMap:
    if num_groups > 1:
        assert major_type == MajorTypeAB.KMajor
    gmem_inner_dim, gmem_outer_dim = (shape_k, shape_m * num_groups)[
        :: major_type.shape_direction()
    ]
    smem_inner_dim, smem_outer_dim = (block_k, block_m)[:: major_type.shape_direction()]
    return make_tma_2d_desc(
        t,
        gmem_inner_dim,
        gmem_outer_dim,
        smem_inner_dim,
        smem_outer_dim,
        outer_stride,
        swizzle_mode,
    )


def make_tma_b_desc(
    major_type: MajorTypeAB,
    t: torch.Tensor,
    shape_n: int,
    shape_k: int,
    block_n: int,
    block_k: int,
    outer_stride: int,
    num_groups: int,
    swizzle_mode: int,
) -> cbd.CUtensorMap:
    # `num_groups` is always applied into the outer dimensions
    io_shapes = (shape_k, shape_n)[:: major_type.shape_direction()]
    gmem_inner_dim, gmem_outer_dim = (io_shapes[0], io_shapes[1] * num_groups)

    smem_inner_dim, smem_outer_dim = (block_k, block_n)[:: major_type.shape_direction()]
    return make_tma_2d_desc(
        t,
        gmem_inner_dim,
        gmem_outer_dim,
        smem_inner_dim,
        smem_outer_dim,
        outer_stride,
        swizzle_mode,
    )


def make_tma_cd_desc(
    major_type: MajorTypeCD,
    t: torch.Tensor,
    shape_m: int,
    shape_n: int,
    block_m: int,
    block_n: int,
    outer_stride: int,
    num_groups: int,
    swizzle_mode: int,
) -> cbd.CUtensorMap:
    assert major_type == MajorTypeCD.NMajor

    # Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
    # bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
    layout_ad_m = 128
    return make_tma_2d_desc(
        t,
        shape_n,
        shape_m * num_groups,
        block_n,
        min(block_m, layout_ad_m),
        outer_stride,
        swizzle_mode,
    )


def make_tma_sf_desc(
    major_type: MajorTypeAB,
    t: torch.Tensor,
    shape_mn: int,
    shape_k: int,
    block_mn: int,
    block_k: int,
    num_groups: int,
    swizzle_mode: int,
) -> cbd.CUtensorMap:
    assert major_type == MajorTypeAB.MNMajor

    # TODO: maybe swizzle SF as well
    assert swizzle_mode == 0

    # Make TMA aligned to 16 bytes
    shape_mn = get_tma_aligned_size(shape_mn, t.element_size())
    return make_tma_2d_desc(
        t,
        shape_mn,
        ceil_div(shape_k, block_k * 4) * num_groups,
        block_mn,
        1,
        shape_mn,
        swizzle_mode,
    )


# Map some common Python types into C types
pytypes_to_ctypes = {
    True: "true",
    False: "false",
    torch.bfloat16: "cutlass::bfloat16_t",
    torch.float: "float",
}


RUNTIME_CACHE = {}


class SM100FP8GemmRuntime:
    def __init__(self, path: str, symbol: str) -> None:
        self.path = path
        self.lib = None
        self.kernel = None
        self.symbol = symbol
        # Store a reference to the cleanup function to avoid import issues during shutdown
        self._cleanup_func = cbd.cuLibraryUnload

    def __call__(self, **kwargs) -> cbd.CUresult:
        # Load CUBIN
        if self.kernel is None:
            # Load CUBIN
            path = bytes(self.path, encoding="utf-8")
            self.lib = checkCudaErrors(
                cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
            )
            self.kernel = checkCudaErrors(
                cbd.cuLibraryGetKernel(self.lib, bytes(self.symbol, encoding="utf-8"))
            )

        # noinspection PyArgumentList
        return self.launch(self.kernel, kwargs)

    def __del__(self) -> None:
        if self.lib is not None:
            try:
                checkCudaErrors(self._cleanup_func(self.lib))
            except Exception as e:
                # Ignore any errors during shutdown
                print(f"Failed to delete SM100FP8GemmRuntime: {e}")

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        assert kwargs["CD_DTYPE_T"] in (torch.bfloat16, torch.float)
        code = f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
        {kwargs["MAJOR_A"]},
        {kwargs["MAJOR_B"]},
        {kwargs["M"] if "m" in kwargs["COMPILED_DIMS"] else 0},
        {kwargs["N"] if "n" in kwargs["COMPILED_DIMS"] else 0},
        {kwargs["K"] if "k" in kwargs["COMPILED_DIMS"] else 0},
        {kwargs["BLOCK_M"]},
        {kwargs["BLOCK_N"]},
        {kwargs["BLOCK_K"]},
        {kwargs["NUM_GROUPS"]},
        {kwargs["SWIZZLE_A_MODE"]},
        {kwargs["SWIZZLE_B_MODE"]},
        {kwargs["SWIZZLE_CD_MODE"]},
        {kwargs["NUM_STAGES"]},
        {kwargs["NUM_LAST_STAGES"]},
        {kwargs["NUM_NON_EPILOGUE_THREADS"]},
        {kwargs["NUM_EPILOGUE_THREADS"]},
        {kwargs["NUM_MULTICAST"]},
        {pytypes_to_ctypes[kwargs["IS_MULTICAST_ON_A"]]},
        {kwargs["GEMM_TYPE"]},
        {pytypes_to_ctypes[kwargs["WITH_ACCUMULATION"]]},
        {pytypes_to_ctypes[kwargs["CD_DTYPE_T"]]}
    >);
}};
"""
        return code

    # noinspection PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
        checkCudaErrors(
            cbd.cuKernelSetAttribute(
                cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                kwargs["SMEM_SIZE"],
                kernel,
                cbd.CUdevice(kwargs["DEVICE_INDEX"]),
            )
        )

        attr_val = cbd.CUlaunchAttributeValue()
        attr_val.clusterDim.x = kwargs["NUM_MULTICAST"]
        attr_val.clusterDim.y = 1
        attr_val.clusterDim.z = 1
        attr = cbd.CUlaunchAttribute()
        attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
        attr.value = attr_val

        config = cbd.CUlaunchConfig()
        config.numAttrs = 1
        config.attrs = [attr]
        config.gridDimX = kwargs["NUM_SMS"]
        config.gridDimY = 1
        config.gridDimZ = 1
        config.blockDimX = (
            kwargs["NUM_NON_EPILOGUE_THREADS"] + kwargs["NUM_EPILOGUE_THREADS"]
        )
        config.blockDimY = 1
        config.blockDimZ = 1
        config.sharedMemBytes = kwargs["SMEM_SIZE"]
        config.hStream = kwargs["STREAM"]

        arg_values = (
            kwargs["GROUPED_LAYOUT"].data_ptr(),
            kwargs["M"],
            kwargs["N"],
            kwargs["K"],
            kwargs["TENSOR_MAP_A"],
            kwargs["TENSOR_MAP_B"],
            kwargs["TENSOR_MAP_SFA"],
            kwargs["TENSOR_MAP_SFB"],
            kwargs["TENSOR_MAP_C"],
            kwargs["TENSOR_MAP_D"],
        )
        arg_types = (
            ctypes.c_void_p,
            ctypes.c_uint32,
            ctypes.c_uint32,
            ctypes.c_uint32,
            None,
            None,
            None,
            None,
            None,
            None,
        )
        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)


def load_all():
    for cubin_name in KERNEL_MAP:
        if cubin_name in RUNTIME_CACHE:
            continue
        symbol, sha256 = KERNEL_MAP[cubin_name]
        get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
        path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
        assert path.exists()
        RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)


def load(name: str, code: str) -> SM100FP8GemmRuntime:
    signature = f"{name}$${code}"
    cubin_name = f"kernel.{name}.{hash_to_hex(signature)}"
    if cubin_name not in KERNEL_MAP:
        raise ValueError("cubin not registered")
    if cubin_name in RUNTIME_CACHE:
        return RUNTIME_CACHE[cubin_name]
    symbol, sha256 = KERNEL_MAP[cubin_name]
    get_cubin(ArtifactPath.DEEPGEMM + cubin_name, sha256)
    path = FLASHINFER_CUBIN_DIR / f"{ArtifactPath.DEEPGEMM + cubin_name}.cubin"
    assert path.exists()
    RUNTIME_CACHE[cubin_name] = SM100FP8GemmRuntime(str(path), symbol)
    return RUNTIME_CACHE[cubin_name]


def m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen(
    m: int,
    n: int,
    k: int,
    aligned_k: int,
    num_groups: int,
    major_a: MajorTypeAB,
    major_b: MajorTypeAB,
    major_d: MajorTypeCD,
    compiled_dims: str,
    output_dtype: torch.dtype,
):
    num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count
    num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = (
        get_best_configs(
            GemmType.GroupedContiguous,
            m,
            n,
            k,
            num_groups,
            major_a,
            major_b,
            major_d,
            torch.float8_e4m3fn,
            output_dtype,
            num_sms,
        )
    )
    kwargs = {
        # Templated or runtime arguments according to the `COMPILED_DIMS`
        "COMPILED_DIMS": compiled_dims,
        "M": m,
        "N": n,
        "K": aligned_k,
        # Templated arguments
        "GEMM_TYPE": GemmType.GroupedContiguous,
        "NUM_NON_EPILOGUE_THREADS": 128,
        "NUM_EPILOGUE_THREADS": 128,
        "MAJOR_A": major_a,
        "MAJOR_B": major_b,
        "NUM_GROUPS": num_groups,
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "BLOCK_K": block_k,
        "NUM_STAGES": num_stages,
        "NUM_LAST_STAGES": ceil_div(k, block_k) % num_stages,
        "SWIZZLE_A_MODE": smem_config.swizzle_a_mode,
        "SWIZZLE_B_MODE": smem_config.swizzle_b_mode,
        "SWIZZLE_CD_MODE": smem_config.swizzle_cd_mode,
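        # Cluster (TMA multicast) configuration chosen by `get_best_configs`;
        # `NUM_MULTICAST` becomes the CUDA cluster size at launch time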
"NUM_MULTICAST": multicast_config.num_multicast, "IS_MULTICAST_ON_A": multicast_config.is_multicast_on_a, "WITH_ACCUMULATION": False, "CD_DTYPE_T": output_dtype, } return ( num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config, ), kwargs def m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( a: torch.Tensor, sfa: torch.Tensor, b: torch.Tensor, sfb: torch.Tensor, d: torch.Tensor, m_indices: torch.Tensor, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ): m, k = a.shape num_groups, n, _ = b.shape major_d = MajorTypeCD.NMajor # K must be aligned to 128 aligned_k = round_up(k, 128) ( ( num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config, ), static_kwargs, ) = m_grouped_fp8_gemm_nt_contiguous_static_kwargs_gen( m, n, k, aligned_k, num_groups, major_a, major_b, major_d, compiled_dims, d.dtype, ) # NOTES: you cannot distinguish groups for A, SFA, and D tensor_map_a = make_tma_a_desc( major_a, a, m, k, multicast_config.get_ab_load_block_m(block_m), block_k, a.stride(major_a.non_contiguous_dim()), num_groups=1, swizzle_mode=smem_config.swizzle_a_mode, ) tensor_map_b = make_tma_b_desc( major_b, b, n, k, multicast_config.get_ab_load_block_n(block_n), block_k, b.stride(major_b.non_contiguous_dim()), num_groups=num_groups, swizzle_mode=smem_config.swizzle_b_mode, ) tensor_map_d = make_tma_cd_desc( major_d, d, m, n, block_m, block_n, d.stride(major_d.non_contiguous_dim()), num_groups=1, swizzle_mode=smem_config.swizzle_cd_mode, ) tensor_map_sfa = make_tma_sf_desc( MajorTypeAB.MNMajor, sfa, m, k, block_m, block_k, num_groups=1, swizzle_mode=smem_config.swizzle_sf_mode, ) tensor_map_sfb = make_tma_sf_desc( MajorTypeAB.MNMajor, sfb, n, k, block_n, block_k, num_groups=num_groups, swizzle_mode=smem_config.swizzle_sf_mode, ) all_kwargs = { **static_kwargs, # Runtime arguments "GROUPED_LAYOUT": m_indices, "NUM_SMS": num_sms, "SMEM_SIZE": smem_config.smem_size, "TENSOR_MAP_A": tensor_map_a, "TENSOR_MAP_B": tensor_map_b, "TENSOR_MAP_SFA": tensor_map_sfa, "TENSOR_MAP_SFB": tensor_map_sfb, "TENSOR_MAP_C": tensor_map_d, "TENSOR_MAP_D": tensor_map_d, "STREAM": torch.cuda.current_stream().cuda_stream, "DEVICE_INDEX": d.device.index, } return static_kwargs, all_kwargs def m_grouped_fp8_gemm_nt_contiguous_sm100( a: torch.Tensor, sfa: torch.Tensor, b: torch.Tensor, sfb: torch.Tensor, d: torch.Tensor, m_indices: torch.Tensor, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ) -> None: static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_contiguous_kwargs_gen( a, sfa, b, sfb, d, m_indices, major_a, major_b, compiled_dims ) # Generate, build and run the kernel code = SM100FP8GemmRuntime.generate(static_kwargs) runtime = load("fp8_m_grouped_gemm", code) runtime(**all_kwargs) def m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( m: int, n: int, k: int, expected_m: int, aligned_k: int, num_groups: int, major_a: MajorTypeAB, major_b: MajorTypeAB, major_d: MajorTypeCD, compiled_dims: str, output_dtype: torch.dtype, ): num_sms = torch.cuda.get_device_properties(device="cuda").multi_processor_count num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config = ( get_best_configs( GemmType.GroupedMasked, expected_m, n, k, num_groups, major_a, major_b, major_d, torch.float8_e4m3fn, output_dtype, num_sms, ) ) if num_groups > 1: assert m % block_m == 0 kwargs = { # Templated or runtime arguments according to the `COMPILED_DIMS` "COMPILED_DIMS": compiled_dims, "M": m, "N": n, "K": aligned_k, # Templated arguments "GEMM_TYPE": GemmType.GroupedMasked, 
"NUM_NON_EPILOGUE_THREADS": 128, "NUM_EPILOGUE_THREADS": 128, "MAJOR_A": major_a, "MAJOR_B": major_b, "NUM_GROUPS": num_groups, "BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "NUM_STAGES": num_stages, "NUM_LAST_STAGES": ceil_div(k, block_k) % num_stages, "SWIZZLE_A_MODE": smem_config.swizzle_a_mode, "SWIZZLE_B_MODE": smem_config.swizzle_b_mode, "SWIZZLE_CD_MODE": smem_config.swizzle_cd_mode, "NUM_MULTICAST": multicast_config.num_multicast, "IS_MULTICAST_ON_A": multicast_config.is_multicast_on_a, "WITH_ACCUMULATION": False, "CD_DTYPE_T": output_dtype, } return ( num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config, ), kwargs def m_grouped_fp8_gemm_nt_masked_kwargs_gen( a: torch.Tensor, sfa: torch.Tensor, b: torch.Tensor, sfb: torch.Tensor, d: torch.Tensor, masked_m: torch.Tensor, expected_m: int, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ): num_groups, m, k = a.shape _, n, _ = b.shape major_d = MajorTypeCD.NMajor # K must be aligned to 128 aligned_k = round_up(k, 128) ( ( num_sms, block_m, block_n, block_k, num_stages, multicast_config, smem_config, ), static_kwargs, ) = m_grouped_fp8_gemm_nt_masked_static_kwargs_gen( m, n, k, expected_m, aligned_k, num_groups, major_a, major_b, major_d, compiled_dims, d.dtype, ) tensor_map_a = make_tma_a_desc( major_a, a, m, k, multicast_config.get_ab_load_block_m(block_m), block_k, a.stride(major_a.non_contiguous_dim()), num_groups, smem_config.swizzle_a_mode, ) tensor_map_b = make_tma_b_desc( major_b, b, n, k, multicast_config.get_ab_load_block_n(block_n), block_k, b.stride(major_b.non_contiguous_dim()), num_groups, smem_config.swizzle_b_mode, ) tensor_map_d = make_tma_cd_desc( major_d, d, m, n, block_m, block_n, d.stride(major_d.non_contiguous_dim()), num_groups, smem_config.swizzle_cd_mode, ) tensor_map_sfa = make_tma_sf_desc( MajorTypeAB.MNMajor, sfa, m, k, block_m, block_k, num_groups, smem_config.swizzle_sf_mode, ) tensor_map_sfb = make_tma_sf_desc( MajorTypeAB.MNMajor, sfb, n, k, block_n, block_k, num_groups, smem_config.swizzle_sf_mode, ) all_kwargs = { **static_kwargs, # Runtime arguments "GROUPED_LAYOUT": masked_m, "NUM_SMS": num_sms, "SMEM_SIZE": smem_config.smem_size, "TENSOR_MAP_A": tensor_map_a, "TENSOR_MAP_B": tensor_map_b, "TENSOR_MAP_SFA": tensor_map_sfa, "TENSOR_MAP_SFB": tensor_map_sfb, "TENSOR_MAP_C": tensor_map_d, "TENSOR_MAP_D": tensor_map_d, "STREAM": torch.cuda.current_stream().cuda_stream, "DEVICE_INDEX": d.device.index, } return static_kwargs, all_kwargs def m_grouped_fp8_gemm_nt_masked_sm100( a: torch.Tensor, sfa: torch.Tensor, b: torch.Tensor, sfb: torch.Tensor, d: torch.Tensor, masked_m: torch.Tensor, expected_m: int, major_a: MajorTypeAB, major_b: MajorTypeAB, compiled_dims: str, ) -> None: static_kwargs, all_kwargs = m_grouped_fp8_gemm_nt_masked_kwargs_gen( a, sfa, b, sfb, d, masked_m, expected_m, major_a, major_b, compiled_dims ) # Generate, build and run the kernel code = SM100FP8GemmRuntime.generate(static_kwargs) runtime = load("fp8_m_grouped_gemm", code) runtime(**all_kwargs) def m_grouped_fp8_gemm_nt_contiguous( a_fp8: Tuple[torch.Tensor, torch.Tensor], b_fp8: Tuple[torch.Tensor, torch.Tensor], d: torch.Tensor, m_indices: torch.Tensor, recipe: Optional[Tuple[int, int, int]] = None, compiled_dims: str = "nk", ) -> None: # Compiled dims can be upper cases compiled_dims = compiled_dims.lower() # NOTES: shape must be `[M, K] @ [G, N, K].mT` major_a = get_major_type_ab(a_fp8[0]) major_b = get_major_type_ab(b_fp8[0]) assert major_a == MajorTypeAB.KMajor if 
    if must_be_k_major():
        assert major_b == MajorTypeAB.KMajor

    assert m_indices.is_contiguous()

    a, sfa = a_fp8
    b, sfb = b_fp8
    m, k = a.shape
    num_groups, n, k_ = b.shape
    m_, n_ = d.shape
    m__ = m_indices.numel()

    # Type and shape checks
    assert m == m_ == m__ and n == n_ and k == k_
    assert n > 0 and k > 0 and num_groups > 0
    assert a.dtype == torch.float8_e4m3fn
    assert b.dtype == torch.float8_e4m3fn
    assert d.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32

    # D must be N-major
    assert get_major_type_cd(d) == MajorTypeCD.NMajor

    # Do nothing if the problem is empty
    if m == 0:
        return

    # Transform SFA and SFB into compute-required layout
    recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
    sfa = transform_sf_into_required_layout(sfa, mn=m, k=k, recipe=recipe, is_sfa=True)
    sfb = transform_sf_into_required_layout(
        sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False
    )

    impl = {
        "100a": functools.partial(
            m_grouped_fp8_gemm_nt_contiguous_sm100,
            major_a=major_a,
            major_b=major_b,
            compiled_dims=compiled_dims,
        )
    }[get_device_arch()]
    impl(a, sfa, b, sfb, d, m_indices)


def m_grouped_fp8_gemm_nt_masked(
    a_fp8: Tuple[torch.Tensor, torch.Tensor],
    b_fp8: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
) -> None:
    # Compiled dims can be uppercase
    compiled_dims = compiled_dims.lower()

    # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
    major_a = get_major_type_ab(a_fp8[0])
    major_b = get_major_type_ab(b_fp8[0])
    assert major_a == major_b == MajorTypeAB.KMajor
    assert masked_m.is_contiguous()

    a, sfa = a_fp8
    b, sfb = b_fp8
    num_groups, m, k = a.shape
    num_groups_, n, k_ = b.shape
    num_groups__, m_, n_ = d.shape
    num_groups___ = masked_m.numel()

    # Type and shape checks
    assert num_groups == num_groups_ == num_groups__ == num_groups___
    assert m == m_ and n == n_ and k == k_
    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
    assert a.dtype == torch.float8_e4m3fn
    assert b.dtype == torch.float8_e4m3fn
    assert d.dtype == torch.bfloat16
    assert masked_m.dtype == torch.int32

    # D must be N-major
    assert get_major_type_cd(d) == MajorTypeCD.NMajor

    # Transform SFA and SFB into compute-required layout
    recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
    sfa = transform_sf_into_required_layout(
        sfa, mn=m, k=k, recipe=recipe, num_groups=num_groups, is_sfa=True
    )
    sfb = transform_sf_into_required_layout(
        sfb, mn=n, k=k, recipe=recipe, num_groups=num_groups, is_sfa=False
    )

    impl = {
        "100a": functools.partial(
            m_grouped_fp8_gemm_nt_masked_sm100,
            major_a=major_a,
            major_b=major_b,
            compiled_dims=compiled_dims,
        )
    }[get_device_arch()]
    impl(a, sfa, b, sfb, d, masked_m, expected_m)


class KernelMap:
    def __init__(self, sha256: str):
        self.sha256 = sha256
        self.indice = None

    def init_indices(self):
        indice_path = ArtifactPath.DEEPGEMM + "kernel_map"
        assert get_cubin(indice_path, self.sha256, file_extension=".json"), (
            "cubin kernel map file not found, nor downloaded with a matching sha256"
        )
        path = FLASHINFER_CUBIN_DIR / f"{indice_path}.json"
        assert path.exists()
        with open(path, "r") as f:
            self.indice = json.load(f)

    def __iter__(self):
        if self.indice is None:
            self.init_indices()
        for name in self.indice:
            yield name

    def __getitem__(self, key):
        if self.indice is None:
            self.init_indices()
        return self.indice[key]


KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM)
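

# The sketch below is a minimal usage example, not part of the library API: it
# assumes an SM100-class GPU ("100a"), downloadable DeepGEMM cubins, and the
# default FP32 (1, 128, 128) scale-factor recipe; the shapes and dummy scale
# values are placeholders chosen only to satisfy the shape/alignment checks
# above.
if __name__ == "__main__":
    num_groups, m_per_group, n, k = 4, 256, 4096, 7168
    m = num_groups * m_per_group
    device = "cuda"

    # FP8 operands: A is [M, K] with all groups' rows stacked, B is [G, N, K]
    a = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)
    b = torch.randn(num_groups, n, k, device=device).to(torch.float8_e4m3fn)

    # FP32 scale factors: one per (1, 128) tile of A and per (128, 128) tile of B
    sfa = torch.ones(m, ceil_div(k, 128), device=device, dtype=torch.float)
    sfb = torch.ones(
        num_groups,
        ceil_div(n, 128),
        ceil_div(k, 128),
        device=device,
        dtype=torch.float,
    )

    # BF16 output and per-row group indices; each group's rows are contiguous
    # and group sizes are multiples of the 128-row contiguous-layout alignment
    d = torch.empty(m, n, device=device, dtype=torch.bfloat16)
    m_indices = torch.arange(
        num_groups, device=device, dtype=torch.int32
    ).repeat_interleave(m_per_group)

    m_grouped_fp8_gemm_nt_contiguous((a, sfa), (b, sfb), d, m_indices)
    torch.cuda.synchronize()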