# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/sm_constraint_gemm.py

from typing import Optional
import torch
import triton
from .kernels.sm_constraint_gemm import (
    gemm_kernel,
    gemm_kernel_descriptor_persistent,
    gemm_kernel_persistent,
)
from .utils import check_device, check_dim, check_input


def gemm_persistent(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None, num_sms=None):
    """
    GEMM operation with SM constraint by Triton.

    C = alpha * (a @ b) + beta * C

    Args:
        a: The first input matrix. Shape: (M, K)
        b: The second input matrix. Shape: (K, N)
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            except fp8 inputs, which map to bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Defaults to a.dtype, except
            fp8 inputs, which default to bf16.
        num_sms: The number of SMs to use for the computation. Defaults to all
            available SMs; larger values are clamped to the device SM count.
    """
    # Check inputs.
    check_input(a)
    # b can be non-contiguous
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)
    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"
    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype

    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )
    assert c is None or c.dtype == out_dtype, (
        "Incompatible dtypes between c and out_dtype"
    )

    # Allocate the output if the caller did not provide one.
    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    # Default to all available SMs; user-provided values are clamped to the device SM count.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms)

    # 1D launch grid: at most num_sms persistent programs cover all output tiles.
    grid = lambda META: (
        min(
            num_sms,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )
    gemm_kernel_persistent[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        alpha=alpha,
        beta=beta,
        NUM_SMS=num_sms,
    )
    return c
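

def _demo_gemm_persistent():
    # Illustrative usage sketch (not part of the original module, name is
    # hypothetical): computes c = alpha * (a @ b) + beta * c on bf16 inputs
    # while capping the persistent kernel to 64 SMs. Shapes follow the
    # docstring above: a is (M, K), b is (K, N), c is (M, N) and is updated
    # in place.
    a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
    c = torch.zeros(512, 1024, device="cuda", dtype=torch.bfloat16)
    return gemm_persistent(a, b, c=c, alpha=1.0, beta=0.5, num_sms=64)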


def gemm(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None):
    """
    GEMM operation without SM constraint by Triton.

    C = alpha * (a @ b) + beta * C

    Args:
        a: The first input matrix. Shape: (M, K)
        b: The second input matrix. Shape: (K, N)
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            except fp8 inputs, which map to bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Defaults to a.dtype, except
            fp8 inputs, which default to bf16.
    """
    # Check inputs.
    check_input(a)
    # b can be non-contiguous
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)
    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"
    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype

    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )
    assert c is None or c.dtype == out_dtype, (
        "Incompatible dtypes between c and out_dtype"
    )

    # Allocate the output if the caller did not provide one.
    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    # 1D launch grid where each output tile gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    gemm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        alpha=alpha,
        beta=beta,
    )
    return c
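

def _demo_gemm():
    # Illustrative usage sketch (not part of the original module, name is
    # hypothetical): a plain launch without an SM cap, so every output tile
    # gets its own program. With c=None and beta=0.0, a fresh bf16 output
    # buffer is allocated and returned.
    a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
    return gemm(a, b, alpha=1.0, beta=0.0)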


def gemm_descriptor_persistent(
    a,
    b,
    c=None,
    alpha=1.0,
    beta=0.0,
    out_dtype=None,
    num_sms=None,
    EPILOGUE_SUBTILE=False,
):
    """
    GEMM operation with SM constraint by Triton.
    Requires TMA support and descriptor creation.

    C = alpha * (a @ b.T) + beta * C

    Note:
        - The K and N dimensions must each span at least 16 bytes.
        - Supports float16, bfloat16, and float8_e4m3fn.
        - float32 is not supported due to performance issues.

    Args:
        a: The first input matrix. Shape: (M, K)
        b: The second input matrix. Shape: (N, K)
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            except fp8 inputs, which map to bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Defaults to a.dtype, except
            fp8 inputs, which default to bf16.
        num_sms: The number of SMs to use for the computation. Defaults to all
            available SMs; larger values are clamped to the device SM count.
        EPILOGUE_SUBTILE: Whether to use the epilogue subtile optimization.
    """
    # Check inputs.
    check_input(a)
    check_input(b)
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)
    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[1], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"
    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[0] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype

    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )

    # Check the TMA tensor-map swizzling granularity:
    # 16B chunks are swizzled within a span of at least 32B,
    # so the K and N dimensions must each cover at least 16 bytes.
    if dtype == torch.float8_e4m3fn:
        assert K >= 16, "K must span at least 16 bytes"
        assert N >= 16, "N must span at least 16 bytes"
    else:
        assert K >= 8, "K must span at least 16 bytes"
        assert N >= 8, "N must span at least 16 bytes"

    # Allocate the output if the caller did not provide one.
    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    # Default to all available SMs; user-provided values are clamped to the device SM count.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms)

    # TMA descriptors require a global memory allocation
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    # 1D launch grid: at most num_sms persistent programs cover all output tiles.
    grid = lambda META: (
        min(
            num_sms,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )
    gemm_kernel_descriptor_persistent[grid](
        a,
        b,
        c,  #
        M,
        N,
        K,  #
        alpha,
        beta,
        NUM_SMS=num_sms,  #
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128 if dtype != torch.float32 else 64,
        BLOCK_SIZE_K=64,
        GROUP_SIZE_M=8,
        num_stages=3,
        num_warps=8,
        EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
    )
    return c
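

def _demo_gemm_descriptor_persistent():
    # Illustrative usage sketch (not part of the original module, name is
    # hypothetical): unlike gemm/gemm_persistent, the descriptor variant takes
    # b transposed, i.e. with shape (N, K), and it needs a TMA-capable GPU
    # (typically Hopper or newer) plus K and N spanning at least 16 bytes.
    M, N, K = 512, 1024, 256
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
    return gemm_descriptor_persistent(a, b, alpha=1.0, beta=0.0, num_sms=64)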