# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/norm.py
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools
from typing import Optional

import torch

from .jit import JitSpec
from .jit import env as jit_env
from .jit import gen_jit_spec
from .utils import device_support_pdl, register_custom_op, register_fake_op


def gen_norm_module() -> JitSpec:
    return gen_jit_spec(
        "norm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "norm.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_ops.cu",
        ],
    )


@functools.cache
def get_norm_module():
    return gen_norm_module().build_and_load()
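
# Note: the norm kernels are compiled on first use (or loaded from a previously
# built JIT cache) and the loaded module is memoized in-process by
# ``functools.cache``; calling ``get_norm_module()`` once during warm-up pays
# that cost up front.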


def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into this
        tensor in place; otherwise a new tensor is allocated.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size).
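
    Examples
    --------
    A minimal usage sketch (assumes a CUDA device is available; shapes and
    dtypes are illustrative):

    >>> import torch
    >>> from flashinfer.norm import rmsnorm
    >>> x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    >>> w = torch.randn(4096, dtype=torch.float16, device="cuda")
    >>> y = rmsnorm(x, w)            # allocates a fresh output tensor
    >>> y.shape
    torch.Size([8, 4096])
    >>> out = torch.empty_like(x)    # or write into a preallocated buffer
    >>> y = rmsnorm(x, w, out=out)   # result is written into ``out`` and returned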
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
if out is None:
out = torch.empty_like(input)
_rmsnorm(out, input, weight, eps, enable_pdl)
return out
@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
def _rmsnorm(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
eps: float,
enable_pdl: Optional[bool],
) -> None:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl)
@register_fake_op("flashinfer::rmsnorm")
def _rmsnorm_fake(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
eps: float,
enable_pdl: Optional[bool],
) -> None:
pass
@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
def fused_add_rmsnorm(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Fused add root mean square normalization.
Step 1:
``residual[i] += input[i]``
Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: bool
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
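
    Examples
    --------
    A minimal usage sketch (assumes a CUDA device; both ``input`` and
    ``residual`` are updated in place and nothing is returned):

    >>> import torch
    >>> from flashinfer.norm import fused_add_rmsnorm
    >>> hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    >>> residual = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    >>> w = torch.randn(4096, dtype=torch.float16, device="cuda")
    >>> fused_add_rmsnorm(hidden, residual, w)
    >>> # ``residual`` now holds the pre-norm sum; ``hidden`` holds its
    >>> # RMS-normalized, weight-scaled value.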
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)
@register_fake_op("flashinfer::fused_add_rmsnorm")
def _fused_add_rmsnorm_fake(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
pass


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Gemma-style root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into this
        tensor in place; otherwise a new tensor is allocated.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Gemma-normalized tensor, shape (batch_size, hidden_size).
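
    Examples
    --------
    A minimal sketch (assumes a CUDA device); apart from the ``weight + 1``
    scaling, the call pattern matches ``rmsnorm``:

    >>> import torch
    >>> from flashinfer.norm import gemma_rmsnorm
    >>> x = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")
    >>> w = torch.zeros(2048, dtype=torch.bfloat16, device="cuda")  # zero weight => plain RMS norm
    >>> y = gemma_rmsnorm(x, w)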
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
if out is None:
out = torch.empty_like(input)
_gemma_rmsnorm(out, input, weight, eps, enable_pdl)
return out
@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
def _gemma_rmsnorm(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
eps: float,
enable_pdl: Optional[bool],
) -> None:
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl)
@register_fake_op("flashinfer::gemma_rmsnorm")
def _gemma_rmsnorm_fake(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
eps: float,
enable_pdl: Optional[bool],
) -> None:
pass


@register_custom_op(
    "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
)
def gemma_fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Gemma-style fused add root mean square normalization.

    Step 1:
    ``residual[i] += input[i]``

    Step 2:
    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.
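
    Examples
    --------
    A minimal sketch (assumes a CUDA device); as with ``fused_add_rmsnorm``,
    both ``input`` and ``residual`` are updated in place:

    >>> import torch
    >>> from flashinfer.norm import gemma_fused_add_rmsnorm
    >>> hidden = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")
    >>> residual = torch.randn(4, 2048, dtype=torch.bfloat16, device="cuda")
    >>> w = torch.zeros(2048, dtype=torch.bfloat16, device="cuda")
    >>> gemma_fused_add_rmsnorm(hidden, residual, w)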
"""
if enable_pdl is None:
enable_pdl = device_support_pdl(input.device)
get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)
@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
def _gemma_fused_add_rmsnorm_fake(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
pass