from typing import Optional

import torch
import triton

from .kernels.sm_constraint_gemm import (
    gemm_kernel,
    gemm_kernel_descriptor_persistent,
    gemm_kernel_persistent,
)
from .utils import check_device, check_dim, check_input


def gemm_persistent(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None, num_sms=None):
    """
    GEMM operation with an SM constraint, implemented in Triton.
    C = alpha * (a @ b) + beta * C

    Args:
        a: The first input matrix. Shape: (M, K).
        b: The second input matrix. Shape: (K, N).
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            but fp8 --> bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Default: same as a.dtype,
            but fp8 --> bf16.
        num_sms: The number of SMs to use for the computation.
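
    Example:
        A minimal usage sketch; the shapes and the num_sms value below are
        only illustrative, and it assumes a CUDA device with contiguous fp16
        inputs:

        >>> import torch
        >>> a = torch.randn(512, 256, device="cuda", dtype=torch.float16)   # (M, K)
        >>> b = torch.randn(256, 1024, device="cuda", dtype=torch.float16)  # (K, N)
        >>> out = gemm_persistent(a, b, num_sms=64)  # cap the launch at 64 SMs
        >>> out.shape
        torch.Size([512, 1024])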
    """

    # Check inputs.
    check_input(a)
    # b can be non-contiguous.
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)

    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"

    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )

    assert c is None or c.dtype == out_dtype, (
        "Incompatible dtypes between c and out_dtype"
    )

    # Allocate the output if the caller did not provide one.
    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    # Default num_sms to all available SMs; never exceed the device's SM count.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms)

    # Persistent 1D launch: at most num_sms programs, each looping over output tiles.
    grid = lambda META: (
        min(
            num_sms,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )

    gemm_kernel_persistent[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        alpha=alpha,
        beta=beta,
        NUM_SMS=num_sms,
    )
    return c


def gemm(a, b, c=None, alpha=1.0, beta=0.0, out_dtype=None):
    """
    GEMM operation without an SM constraint, implemented in Triton.
    C = alpha * (a @ b) + beta * C

    Args:
        a: The first input matrix. Shape: (M, K).
        b: The second input matrix. Shape: (K, N).
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            but fp8 --> bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Default: same as a.dtype,
            but fp8 --> bf16.
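
    Example:
        A minimal usage sketch; the shapes and the alpha value below are only
        illustrative, and it assumes a CUDA device with contiguous fp16 inputs:

        >>> import torch
        >>> a = torch.randn(512, 256, device="cuda", dtype=torch.float16)   # (M, K)
        >>> b = torch.randn(256, 1024, device="cuda", dtype=torch.float16)  # (K, N)
        >>> out = gemm(a, b, alpha=2.0)  # computes 2 * (a @ b)
        >>> out.shape
        torch.Size([512, 1024])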
    """
    # Check inputs.
    check_input(a)
    # b can be non-contiguous.
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)

    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[0], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"

    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[1] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )

    assert c is None or c.dtype == out_dtype, (
        "Incompatible dtypes between c and out_dtype"
    )

    # Allocate the output if the caller did not provide one.
    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    # 1D launch where each output tile gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )

    gemm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        alpha=alpha,
        beta=beta,
    )
    return c


def gemm_descriptor_persistent(
    a,
    b,
    c=None,
    alpha=1.0,
    beta=0.0,
    out_dtype=None,
    num_sms=None,
    EPILOGUE_SUBTILE=False,
):
    """
    GEMM operation with an SM constraint, implemented in Triton.
    Requires TMA support and descriptor creation.
    C = alpha * (a @ b.T) + beta * C

    Note:
        - K and N must each span at least 16 bytes.
        - Supports float16, float8_e4m3fn, and bfloat16.
        - float32 is not supported due to performance issues.

    Args:
        a: The first input matrix. Shape: (M, K).
        b: The second input matrix. Shape: (N, K).
        c: The output matrix. Shape: (M, N). In-place epilogue is supported.
            Expected to be of out_dtype (if not specified, same as a.dtype,
            but fp8 --> bf16).
        alpha: The scaling factor for the product of a and b.
        beta: The scaling factor for the output matrix c.
        out_dtype: The dtype of the output matrix. Default: same as a.dtype,
            but fp8 --> bf16.
        num_sms: The number of SMs to use for the computation.
        EPILOGUE_SUBTILE: Whether to use the epilogue subtile optimization.
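
    Example:
        A minimal usage sketch; the shapes and the num_sms value below are
        only illustrative, and it assumes a GPU with TMA support (e.g. Hopper)
        and contiguous bf16 inputs:

        >>> import torch
        >>> a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)   # (M, K)
        >>> b = torch.randn(1024, 256, device="cuda", dtype=torch.bfloat16)  # (N, K)
        >>> out = gemm_descriptor_persistent(a, b, num_sms=64)
        >>> out.shape
        torch.Size([512, 1024])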
    """
    # Check inputs.
    check_input(a)
    check_input(b)
    check_device([a, b])
    check_dim(2, a)
    check_dim(2, b)

    if c is not None:
        check_input(c)
        check_device([c])
        check_dim(2, c)

    assert a.shape[1] == b.shape[1], "Incompatible dimensions between a and b"
    assert a.dtype == b.dtype, "Incompatible dtypes between a and b"

    if c is not None:
        assert a.shape[0] == c.shape[0], "Incompatible dimensions between a and c"
        assert b.shape[0] == c.shape[1], "Incompatible dimensions between b and c"

    M, K = a.shape
    N, K = b.shape
    dtype = a.dtype
    out_dtype = (
        out_dtype
        if out_dtype
        else dtype
        if dtype != torch.float8_e4m3fn
        else torch.bfloat16
    )

    # Check the TMA tensor-map swizzling granularity: 16B chunks are swizzled
    # within a span of at least 32B, so the innermost dimensions must each
    # cover at least 16 bytes.
    if dtype == torch.float8_e4m3fn:
        assert K >= 16, "Least chunk size must be 16B"
        assert N >= 16, "Least chunk size must be 16B"
    else:
        assert K >= 8, "Least chunk size must be 16B"
        assert N >= 8, "Least chunk size must be 16B"

    c = torch.empty((M, N), device=a.device, dtype=out_dtype) if c is None else c

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_sms = NUM_SMS if num_sms is None else min(NUM_SMS, num_sms)

    # TMA descriptors require a global memory allocation.
    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device="cuda", dtype=torch.int8)

    triton.set_allocator(alloc_fn)

    # Persistent 1D launch: at most num_sms programs, each looping over output tiles.
    grid = lambda META: (
        min(
            num_sms,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        ),
    )

    gemm_kernel_descriptor_persistent[grid](
        a,
        b,
        c,  #
        M,
        N,
        K,  #
        alpha,
        beta,
        NUM_SMS=num_sms,  #
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128 if dtype != torch.float32 else 64,
        BLOCK_SIZE_K=64,
        GROUP_SIZE_M=8,
        num_stages=3,
        num_warps=8,
        EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
    )
    return c