# Source: sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/kernels/quant.py
# (as listed by the file viewer: 28 lines, 875 B, Python)

import triton # type: ignore[import]
import triton.language as tl # type: ignore[import]
@triton.jit
def scale_and_clamp(x, scale, dtype):
    """Scale ``x``, saturate it to the target dtype's range, and cast it.

    ``x`` is converted to ``tl.float32`` and multiplied by ``scale``; the
    product is clamped to a pair of literal bounds chosen from ``dtype`` and
    finally converted to ``dtype``.

    The bounds are spelled out as constants for each supported type
    (``tl.float8e4nv``, ``tl.float8e5``, ``tl.float16``, ``tl.bfloat16``)
    so that the function remains compatible with both `torch.compile` and
    `triton.jit`. Any other ``dtype`` fails a ``tl.static_assert``.
    """
    # Pick the hard-wired saturation limits for the requested output type.
    if dtype == tl.float8e4nv:
        lo = -448.0
        hi = 448.0
    elif dtype == tl.float8e5:
        lo = -57344.0
        hi = 57344.0
    elif dtype == tl.float16:
        lo = -65504.0
        hi = 65504.0
    elif dtype == tl.bfloat16:
        lo = -3.3895313892515355e38
        hi = 3.3895313892515355e38
    else:
        tl.static_assert(False, f"Unsupported dtype: {dtype}")

    # Do the multiply on a float32 copy of x, then saturate before the cast.
    scaled = x.to(tl.float32) * scale
    saturated = tl.clamp(scaled, lo, hi)
    return saturated.to(dtype)