from typing import Optional

import torch
from torch import nn

from sglang.srt.utils import is_cuda, is_hip

# Platform flags are resolved once at import time; dispatch_forward() below
# keys off them.
_is_cuda = is_cuda()
_is_hip = is_hip()


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        # Bind the backend-specific forward once, at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Pure-PyTorch reference implementation; subclasses must override.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # ROCm falls back to the CUDA implementation unless overridden.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        # Pick the forward_* variant that matches the detected platform.
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        else:
            return self.forward_native
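

# Illustrative sketch (hypothetical class, not part of sglang's API): a CustomOp
# subclass only needs to provide the backend-specific forward_* methods it cares
# about, and dispatch_forward() binds the matching one when the module is built.
class _ExampleScaleOp(CustomOp):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path, used on CPU/XPU/HPU.
        return x * self.factor

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would typically call a fused kernel here; reusing the native
        # path keeps this sketch self-contained.
        return self.forward_native(x)


# Usage: _ExampleScaleOp(2.0)(torch.ones(3)) -> tensor([2., 2., 2.]) on any backend,
# since nn.Module.__call__ routes to forward(), which calls the dispatched method.

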
if _is_cuda:
    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

    def scaled_fp8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        use_per_token_if_dynamic: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""
|
|
Quantize input tensor to FP8 (8-bit floating point) format.
|
|
|
|
Args:
|
|
input (torch.Tensor): Input tensor to be quantized
|
|
scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
|
|
If None, scales will be computed dynamically.
|
|
use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
|
|
determines the quantization granularity:
|
|
- True: compute scale per token
|
|
- False: compute single scale per tensor
|
|
|
|
Returns:
|
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
|
- quantized_tensor: The FP8 quantized version of input
|
|
- scale_tensor: The scaling factors used for quantization
|
|
|
|
Raises:
|
|
AssertionError: If input is not 2D or if static scale's numel != 1
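
        Example (illustrative sketch; assumes a CUDA device and an input dtype
        supported by the underlying sgl_kernel kernels):
            >>> x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
            >>> q, s = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
            >>> q.shape, s.shape
            (torch.Size([4, 128]), torch.Size([4, 1]))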
"""
|
|
assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
|
|
shape = input.shape
|
|
out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
|
output = torch.empty(shape, device=input.device, dtype=out_dtype)
|
|
|
|
        if scale is None:
            # Dynamic scaling: compute the scale(s) on the fly.
            if use_per_token_if_dynamic:
                # One scale per token (row).
                scale = torch.empty(
                    (shape[0], 1), device=input.device, dtype=torch.float32
                )
                sgl_per_token_quant_fp8(input, output, scale)
            else:
                # One scale for the whole tensor.
                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
                sgl_per_tensor_quant_fp8(input, output, scale, is_static=False)
        else:
            # Static scaling: reuse the caller-provided scalar scale.
            assert (
                scale.numel() == 1
            ), f"Expected scalar scale, got numel={scale.numel()}"
            sgl_per_tensor_quant_fp8(input, output, scale, is_static=True)

        return output, scale