# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/triton/activation.py

from collections.abc import Mapping
from typing import Optional

import torch
import triton  # type: ignore[import]

from flashinfer.triton.kernels.activation import silu_and_mul_kernel


def silu_and_mul(
    x: torch.Tensor,
    x_scale: Optional[torch.Tensor] = None,
    o_scale: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""Sigmoid Linear Unit and Multiplication
Computes `silu(x[:,:d]) * x[:, d:]`, where `d = x.shape[-1] // 2.
If the scale of `x` is `x_scale`, the scale applied to the output
is the square of that, as the sigmoid function ranges in (0, 1).
Args:
x: The input tensor, of shape `(b, 2 * d)`.
x_scale: An optional scale which was applied to `x`.
o_scale: The scale to apply to the output.
dtype: The desired output dtype.
Returns:
The output activation, of shape `(b, d)`.
"""
    b, n = x.shape
    assert n % 2 == 0
    d = n // 2

    o_dtype = dtype or x.dtype
    o = torch.empty((b, d), dtype=o_dtype, device=x.device)

    # One program per row, with each row split into BLOCK_SIZE-wide column tiles.
    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    # The HAS_* flags let the kernel skip the scaling paths when no scales are given.
    silu_and_mul_kernel[grid](
        o_ptr=o,
        o_stride=o.stride(0),
        o_scale_ptr=o_scale,
        x_ptr=x,
        x_stride=x.stride(0),
        x_scale_ptr=x_scale,
        d=d,
        BLOCK_SIZE=1024,
        HAS_X_SCALE=x_scale is not None,
        HAS_O_SCALE=o_scale is not None,
    )

    return o
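

# A minimal usage sketch, not part of the original module: it checks the Triton
# kernel against a pure-PyTorch reference of the expression from the docstring,
# silu(x[:, :d]) * x[:, d:]. Assumes a CUDA device is available and exercises
# only the unscaled path; the shape and tolerances are illustrative.
if __name__ == "__main__":
    import torch.nn.functional as F

    x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
    half = x.shape[-1] // 2
    expected = F.silu(x[:, :half]) * x[:, half:]
    torch.testing.assert_close(silu_and_mul(x), expected, rtol=1e-2, atol=1e-2)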