818 lines
24 KiB
Python
818 lines
24 KiB
Python
# Copyright 2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import functools
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from sglang.srt.utils import (
|
|
direct_register_custom_op,
|
|
get_bool_env_var,
|
|
get_device_core_count,
|
|
get_device_name,
|
|
get_device_sm,
|
|
is_cuda,
|
|
is_hip,
|
|
supports_custom_op,
|
|
)
|
|
|
|
# JIT DeepGEMM stays disabled unless the CUDA-only checks below switch it on.
_enable_jit_deepgemm = False

_is_hip = is_hip()
# Default FP8 storage dtype: the "fnuz" e4m3 variant when is_hip() is true,
# the plain e4m3fn variant otherwise.
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

_is_cuda = is_cuda()
if _is_cuda:
    # These packages are imported only when is_cuda() is true, so the names
    # below do not exist at module level on other backends.
    import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

    sm_version = get_device_sm()
    # DeepGEMM is used only when get_device_sm() reports >= 90 and the
    # SGL_ENABLE_JIT_DEEPGEMM environment flag (default "true") is enabled.
    if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
        _enable_jit_deepgemm = True


logger = logging.getLogger(__name__)
|
|
|
|
if supports_custom_op():
    # Register the DeepGEMM call as a custom op. `w8a8_block_fp8_matmul`
    # reaches it through `torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt` when
    # custom ops are supported.

    def deep_gemm_fp8_fp8_bf16_nt(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Call ``deep_gemm.gemm_fp8_fp8_bf16_nt`` on ``(A, As)`` and ``(B, Bs)``.

        The result is written into ``C`` in place (the op is registered with
        ``mutates_args=["C"]``); nothing is returned.
        """
        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)

    def deep_gemm_fp8_fp8_bf16_nt_fake(
        A: torch.Tensor,
        As: torch.Tensor,
        B: torch.Tensor,
        Bs: torch.Tensor,
        C: torch.Tensor,
    ) -> None:
        """Fake implementation registered with the op; performs no computation."""
        return

    direct_register_custom_op(
        op_name="deep_gemm_fp8_fp8_bf16_nt",
        op_func=deep_gemm_fp8_fp8_bf16_nt,
        mutates_args=["C"],
        fake_impl=deep_gemm_fp8_fp8_bf16_nt_fake,
    )
|
|
|
|
|
|
@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Lower bound on the group abs-max, to avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.

    Each program handles one group of `N` elements: it computes the group's
    absolute maximum (floored at `eps`), derives the scale
    `absmax / fp8_max`, and stores the clamped, scaled values to `y_q_ptr`
    and the scale to `y_s_ptr`.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the group of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    # One scale value per program, stored at index g_id.
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    # Lanes beyond N are padding: they load 0.0 and are never stored.
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    # Multiply by the reciprocal scale instead of dividing each element.
    y_s_inv = 1.0 / y_s
    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
|
|
|
|
|
|
@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    # Num columns of y
    y_num_columns,
    # Stride from one column to the next of y_s
    y_s_col_stride,
    # Lower bound on the group abs-max, to avoid dividing by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.

    Same quantization as the row-major variant, but the scale for each group
    is written at `scale_col * y_s_col_stride + scale_row`, i.e. into a scale
    matrix addressed by an explicit column stride.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the group of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * group_size
    y_q_ptr += g_id * group_size

    # Convert g_id the flattened block coordinate to 2D so we can index
    # into the output y_scales matrix
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    # The row offset is added without a stride, so consecutive rows of the
    # scale matrix are treated as adjacent in memory.
    y_s_ptr += scale_col * y_s_col_stride + scale_row

    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    # Lanes beyond group_size are padding: they load 0.0 and are never stored.
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)
|
|
|
|
|
|
def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2. It must be contiguous and its last
            dimension must be a multiple of `group_size`.
        group_size: The group size used for quantization.
        eps: Lower bound applied to each group's absolute maximum, to avoid
            dividing by zero.
        dtype: The dtype of the output tensor.
        column_major_scales: If True, allocate the scale tensor transposed and
            view it back via `permute(-1, -2)`, so scales of one column group
            are adjacent in memory. NOTE(review): this branch permutes with
            exactly two axes and reads `x.shape[1]`, so it appears to assume a
            2-D `x` -- confirm with callers.
        scale_tma_aligned: Only used when `column_major_scales` is True. Pads
            the row count of the scale storage up to a multiple of 4 floats and
            slices the view back to `x.shape[-2]` rows.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_max = finfo.max

    # On HIP the clamp range is fixed at +/-224.0 instead of the dtype maximum.
    if _is_hip:
        fp8_max = 224.0

    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    # One Triton program per quantization group.
    M = x.numel() // group_size
    N = group_size
    if column_major_scales:
        if scale_tma_aligned:
            # aligned to 4 * sizeof(float)
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            x_s = torch.empty(
                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)[: x.shape[-2], :]
        else:
            x_s = torch.empty(
                (x.shape[-1] // group_size,) + x.shape[:-1],
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)
    else:
        x_s = torch.empty(
            x.shape[:-1] + (x.shape[-1] // group_size,),
            device=x.device,
            dtype=torch.float32,
        )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            N,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s
|
|
|
|
|
|
def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token-group FP8 quantization using the `sgl_kernel` implementation.

    Allocates the outputs and delegates the computation to
    `sgl_per_token_group_quant_fp8`, which this module imports only when
    `is_cuda()` is true.

    Args:
        x: The input tensor. It must be contiguous and its last dimension must
            be a multiple of `group_size`.
        group_size: The group size used for quantization.
        eps: Forwarded to the kernel together with the FP8 range.
        dtype: The dtype of the quantized output tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor (same shape as
        `x`) and a float32 scale tensor of shape
        `x.shape[:-1] + (x.shape[-1] // group_size,)`.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    fp8_max = torch.finfo(dtype).max
    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    # The kernel fills x_q and x_s in place.
    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s
|
|
|
|
|
|
def sglang_per_token_quant_fp8(
    x: torch.Tensor,
    dtype: torch.dtype = fp8_type_,
):
    """Per-token FP8 quantization using the `sgl_kernel` implementation.

    Allocates a quantized tensor shaped like `x` and a float32 scale tensor of
    shape `(x.shape[0], 1)`, then hands both to `sgl_per_token_quant_fp8` to
    be filled in place.

    Args:
        x: The input tensor; it must be contiguous.
        dtype: The dtype of the quantized output tensor.

    Returns:
        The pair `(quantized tensor, scale tensor)`.
    """
    assert x.is_contiguous(), "`x` is not contiguous"

    # Output buffers; the kernel call below writes into them.
    scales = torch.empty((x.shape[0], 1), dtype=torch.float32, device=x.device)
    quantized = torch.empty_like(x, device=x.device, dtype=dtype)

    sgl_per_token_quant_fp8(x, quantized, scales)

    return quantized, scales
|
|
|
|
|
|
@triton.jit
def _static_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    y_s_repeat_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
    REPEAT_SCALE: tl.constexpr,
):
    """A Triton-accelerated function to perform quantization using the given scale on a
    tensor

    Each program quantizes one row of `N` elements with the single scale read
    from `y_s_ptr`. When `REPEAT_SCALE` is set, the same scale is also written
    to `y_s_repeat_ptr` at the program's row index.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    # y_s_repeat_ptr is only touched when REPEAT_SCALE is set.
    if REPEAT_SCALE:
        y_s_repeat_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    # Lanes beyond N are padding: they load 0.0 and are never stored.
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # The scale pointer is not offset by g_id: every program reads the same value.
    y_s = tl.load(y_s_ptr).to(tl.float32)
    y_s_inv = 1.0 / y_s
    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    if REPEAT_SCALE:
        tl.store(y_s_repeat_ptr, y_s)
|
|
|
|
|
|
def static_quant_fp8(
    x: torch.Tensor,
    x_s: torch.Tensor,
    repeat_scale: bool = False,
    dtype: torch.dtype = fp8_type_,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize `x` to float8 with a caller-supplied per-tensor scale.

    The tensor is processed row by row (rows are slices along the last
    dimension) by the `_static_quant_fp8` Triton kernel.

    Args:
        x: The input tensor; it must be contiguous.
        x_s: The quantization scale; it must hold exactly one element.
        repeat_scale: If True, also materialize the scale once per row and
            return that `(rows, 1)` float32 tensor instead of `x_s`.
        dtype: The dtype of the quantized output tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scale
        tensor (`x_s` itself, or the per-row copy when `repeat_scale` is True).
    """
    assert x.is_contiguous(), "`x` is not contiguous"
    assert x_s.numel() == 1, "only supports per-tensor scale"

    # The dtype maximum is always queried; on HIP the clamp range is 224.0 instead.
    dtype_max = torch.finfo(dtype).max
    fp8_max = 224.0 if _is_hip else dtype_max
    fp8_min = -fp8_max

    quantized = torch.empty_like(x, device=x.device, dtype=dtype)
    num_rows = x.numel() // x.shape[-1]
    row_width = x.shape[-1]

    per_row_scale = (
        torch.empty((num_rows, 1), device=x.device, dtype=torch.float32)
        if repeat_scale
        else None
    )

    block = triton.next_power_of_2(row_width)
    # Warp-count heuristic: one warp per 256 lanes, clamped to [1, 8].
    warps = min(max(block // 256, 1), 8)

    # One program per row; row_width is passed as both the row stride and N.
    _static_quant_fp8[(num_rows,)](
        x,
        quantized,
        x_s,
        per_row_scale,
        row_width,
        row_width,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        BLOCK=block,
        REPEAT_SCALE=repeat_scale,
        num_warps=warps,
        num_stages=1,
    )

    if repeat_scale:
        return quantized, per_row_scale
    return quantized, x_s
|
|
|
|
|
|
@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.

    Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of `C`. For
    every K step the partial dot product is multiplied by the per-row scale
    from `As` and the per-column-group scale from `Bs` before being added to a
    float32 accumulator.
    """

    # Map the 1-D program id to a (pid_m, pid_n) tile, walking GROUP_SIZE_M
    # row-tiles at a time (presumably for cache reuse -- confirm).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets wrap with % M and % N so the loads below need no
    # M/N mask; out-of-range results are dropped by c_mask at the store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one per row of A, one per column of B (columns in the
    # same group_n-wide group share a scale).
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Only the K tail is masked; masked lanes contribute 0.0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # A single scale index is used for the whole K tile, taken from the
        # tile's starting position.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the float32 accumulator to the element type of C.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store offsets are not wrapped; the mask discards rows/columns past M/N.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
@triton.jit
def _w8a8_block_fp8_matmul_unrolledx4(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Stride for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
    tensor `C`.

    Same computation as `_w8a8_block_fp8_matmul`, with the K loop manually
    unrolled so that each loop iteration processes four consecutive K tiles.
    """

    # Map the 1-D program id to a (pid_m, pid_n) tile, walking GROUP_SIZE_M
    # row-tiles at a time.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row/column offsets wrap with % M and % N so the loads below need no
    # M/N mask; out-of-range results are dropped by c_mask at the store.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: one per row of A, one per column of B (columns in the
    # same group_n-wide group share a scale).
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # manually unroll to 4 iterations
    # NOTE(review): the A/B loads are masked against K, but the scale loads
    # below are not. When cdiv(K, BLOCK_SIZE_K) is not a multiple of
    # UNROLL_FACTOR, the trailing sub-steps compute offs_ks from a k_start
    # that is >= K and read scales at that index -- TODO confirm callers
    # guarantee this is in bounds (or that the values read are finite).
    UNROLL_FACTOR = 4
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
        # 1st iteration
        a = tl.load(
            a_ptrs,
            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
            other=0.0,
        )

        # Scale index for this K tile, taken from the tile's starting position.
        k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

        # 2nd iteration
        a = tl.load(
            a_ptrs,
            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
            other=0.0,
        )

        # Advance to the next K tile and recompute its scale index.
        k_start = k_start + BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

        # 3rd iteration
        a = tl.load(
            a_ptrs,
            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
            other=0.0,
        )

        k_start = k_start + BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

        # 4th iteration
        a = tl.load(
            a_ptrs,
            mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
            other=0.0,
        )

        k_start = k_start + BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the float32 accumulator to the element type of C.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store offsets are not wrapped; the mask discards rows/columns past M/N.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
@functools.lru_cache
def get_w8a8_block_fp8_configs(
    N: int, K: int, block_n: int, block_k: int
) -> Optional[Dict[int, Any]]:
    """
    Look up tuned launch configurations for the w8a8 block fp8 kernel.

    The lookup targets a JSON file in the `configs` directory next to this
    module, named after `N`, `K`, the device name and the block shape. When
    the file exists, its top-level keys are converted to integers and the
    resulting mapping (batch size -> kernel configuration) is returned; the
    caller is expected to pick the entry whose batch size is closest to its
    own. When the file does not exist, a warning is logged and None is
    returned so the caller can fall back to a default configuration.

    Results are memoized per argument tuple via `functools.lru_cache`.
    """
    device_name = get_device_name().replace(" ", "_")
    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"

    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(module_dir, "configs", json_file_name)

    # No tuned file for this shape/device: warn once (cached) and bail out.
    if not os.path.exists(config_file_path):
        logger.warning(
            (
                "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
                "Config file not found at %s"
            ),
            config_file_path,
        )
        return None

    with open(config_file_path) as f:
        logger.info(
            "Using configuration from %s for W8A8 Block FP8 kernel.",
            config_file_path,
        )
        # JSON object keys are strings; callers index by integer batch size.
        tuned = json.load(f)
        return {int(batch_size): cfg for batch_size, cfg in tuned.items()}
|
|
|
|
|
|
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    When the output dtype is bfloat16 and JIT DeepGEMM is enabled, the work is
    handed to DeepGEMM; otherwise one of the Triton kernels in this module is
    launched. The tuned-config lookup and launch sizing are only performed on
    the Triton path, where they are used.

    Args:
        A: The input tensor, e.g., activation. Must be contiguous; its last
            dimension must match the last dimension of `B`.
        B: The input tensor, e.g., weight. Must be a contiguous 2-D tensor of
            shape (N, K).
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul, of shape `A.shape[:-1] + (N,)`.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    # deepgemm only supports bf16
    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
        if supports_custom_op():
            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
        else:
            deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
        return C

    configs = get_w8a8_block_fp8_configs(N, K, block_n, block_k)
    if configs:
        # If an optimal configuration map has been found, look up the
        # config tuned for the batch size closest to M
        config = configs[min(configs, key=lambda x: abs(x - M))]
    else:
        # Default config
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
    # Empirical testing shows the sweet spot lies when it's less than the # of
    # compute units available on the device.
    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
        N, config["BLOCK_SIZE_N"]
    )
    kernel = (
        _w8a8_block_fp8_matmul_unrolledx4
        if _is_hip and num_workgroups <= get_device_core_count()
        else _w8a8_block_fp8_matmul
    )

    # B is stored as (N, K); its strides are passed swapped so the kernel
    # addresses it as a (K, N) operand. The same applies to Bs.
    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C
|