28 lines
875 B
Python
28 lines
875 B
Python
import triton # type: ignore[import]
|
|
import triton.language as tl # type: ignore[import]
|
|
|
|
|
|
@triton.jit
def scale_and_clamp(x, scale, dtype):
    """Multiply ``x`` by ``scale`` in float32 and saturate to ``dtype``'s range.

    The bounds for each supported dtype are spelled out as literals rather
    than looked up, which keeps the function compatible with both
    `torch.compile` and `triton.jit`.

    Args:
        x: Value to scale; it is cast to ``tl.float32`` before multiplying.
        scale: Multiplier applied to the float32 value.
        dtype: Target dtype. Supported: ``tl.float8e4nv``, ``tl.float8e5``,
            ``tl.float16`` and ``tl.bfloat16``. Any other dtype trips a
            ``tl.static_assert``.

    Returns:
        The scaled value, clamped to the bounds selected for ``dtype`` and
        then cast to ``dtype``.
    """
    # Keep this as a plain if/elif chain over literal constants: the bounds
    # are deliberately hard-wired (see docstring).
    if dtype == tl.float8e4nv:
        upper = 448.0
        lower = -448.0
    elif dtype == tl.float8e5:
        upper = 57344.0
        lower = -57344.0
    elif dtype == tl.float16:
        upper = 65504.0
        lower = -65504.0
    elif dtype == tl.bfloat16:
        upper = 3.3895313892515355e38
        lower = -3.3895313892515355e38
    else:
        tl.static_assert(False, f"Unsupported dtype: {dtype}")

    # Scale in float32, saturate, and only then narrow to the target dtype.
    scaled = x.to(tl.float32) * scale
    saturated = tl.clamp(scaled, lower, upper)
    return saturated.to(dtype)
|