"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import functools
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from .jit import JitSpec
|
|
from .jit import env as jit_env
|
|
from .jit import gen_jit_spec
|
|
from .utils import device_support_pdl, register_custom_op, register_fake_op
|
|
|
|
|
|
def gen_norm_module() -> JitSpec:
    return gen_jit_spec(
        "norm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "norm.cu",
            jit_env.FLASHINFER_CSRC_DIR / "flashinfer_norm_ops.cu",
        ],
    )


@functools.cache
def get_norm_module():
    return gen_norm_module().build_and_load()


def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor; if specified, the kernel writes the result into this tensor in place.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, PDL is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size).
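
    Examples
    --------
    A minimal illustrative sketch (not from the original docs); it assumes a
    CUDA device is available and the JIT module can be built:

    >>> import torch
    >>> from flashinfer.norm import rmsnorm
    >>> x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> w = torch.ones(128, dtype=torch.float16, device="cuda")
    >>> y = rmsnorm(x, w)
    >>> y.shape
    torch.Size([4, 128])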
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    if out is None:
        out = torch.empty_like(input)
    _rmsnorm(out, input, weight, eps, enable_pdl)
    return out


@register_custom_op("flashinfer::rmsnorm", mutates_args=("out",))
|
|
def _rmsnorm(
|
|
out: torch.Tensor,
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float,
|
|
enable_pdl: Optional[bool],
|
|
) -> None:
|
|
if enable_pdl is None:
|
|
enable_pdl = device_support_pdl(input.device)
|
|
get_norm_module().rmsnorm(out, input, weight, eps, enable_pdl)
|
|
|
|
|
|
@register_fake_op("flashinfer::rmsnorm")
|
|
def _rmsnorm_fake(
|
|
out: torch.Tensor,
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float,
|
|
enable_pdl: Optional[bool],
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
@register_custom_op("flashinfer::fused_add_rmsnorm", mutates_args=("input", "residual"))
|
|
def fused_add_rmsnorm(
|
|
input: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float = 1e-6,
|
|
enable_pdl: Optional[bool] = None,
|
|
) -> None:
|
|
r"""Fused add root mean square normalization.
|
|
|
|
Step 1:
|
|
``residual[i] += input[i]``
|
|
|
|
Step 2:
|
|
``input[i] = (residual[i] / RMS(residual)) * weight[i]``
|
|
|
|
Parameters
|
|
----------
|
|
input: torch.Tensor
|
|
Input tensor, shape (batch_size, hidden_size).
|
|
residual: torch.Tensor
|
|
Residual tensor, shape (batch_size, hidden_size).
|
|
weight: torch.Tensor
|
|
Weight tensor, shape (hidden_size,).
|
|
eps: float
|
|
Epsilon for numerical stability.
|
|
enable_pdl: bool
|
|
Whether to enable `programmatic dependent launch
|
|
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
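
    Examples
    --------
    A minimal illustrative sketch (not from the original docs); it assumes a
    CUDA device is available. Both ``x`` and ``res`` are modified in place:

    >>> import torch
    >>> from flashinfer.norm import fused_add_rmsnorm
    >>> x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> res = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> w = torch.ones(128, dtype=torch.float16, device="cuda")
    >>> fused_add_rmsnorm(x, res, w)  # returns None; results land in x and res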
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    get_norm_module().fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)


@register_fake_op("flashinfer::fused_add_rmsnorm")
|
|
def _fused_add_rmsnorm_fake(
|
|
input: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float = 1e-6,
|
|
enable_pdl: Optional[bool] = None,
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Gemma-style root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor; if specified, the kernel writes the result into this tensor in place.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, PDL is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Gemma-normalized tensor, shape (batch_size, hidden_size).
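
    Examples
    --------
    A minimal illustrative sketch (not from the original docs); it assumes a
    CUDA device is available:

    >>> import torch
    >>> from flashinfer.norm import gemma_rmsnorm
    >>> x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> w = torch.zeros(128, dtype=torch.float16, device="cuda")  # effective scale is w + 1
    >>> y = gemma_rmsnorm(x, w)
    >>> y.shape
    torch.Size([4, 128])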
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    if out is None:
        out = torch.empty_like(input)
    _gemma_rmsnorm(out, input, weight, eps, enable_pdl)
    return out


@register_custom_op("flashinfer::gemma_rmsnorm", mutates_args=("out",))
|
|
def _gemma_rmsnorm(
|
|
out: torch.Tensor,
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float,
|
|
enable_pdl: Optional[bool],
|
|
) -> None:
|
|
if enable_pdl is None:
|
|
enable_pdl = device_support_pdl(input.device)
|
|
get_norm_module().gemma_rmsnorm(out, input, weight, eps, enable_pdl)
|
|
|
|
|
|
@register_fake_op("flashinfer::gemma_rmsnorm")
|
|
def _gemma_rmsnorm_fake(
|
|
out: torch.Tensor,
|
|
input: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
eps: float,
|
|
enable_pdl: Optional[bool],
|
|
) -> None:
|
|
pass
|
|
|
|
|
|
@register_custom_op(
    "flashinfer::gemma_fused_add_rmsnorm", mutates_args=("input", "residual")
)
def gemma_fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Gemma-style fused add root mean square normalization.

    Step 1:
    ``residual[i] += input[i]``

    Step 2:
    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size). Updated in place with the
        normalized result.
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size). Updated in place with the
        sum ``residual + input``.
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, PDL is enabled automatically when the device supports it.
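
    Examples
    --------
    A minimal illustrative sketch (not from the original docs); it assumes a
    CUDA device is available. Both ``x`` and ``res`` are modified in place:

    >>> import torch
    >>> from flashinfer.norm import gemma_fused_add_rmsnorm
    >>> x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> res = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    >>> w = torch.zeros(128, dtype=torch.float16, device="cuda")  # effective scale is w + 1
    >>> gemma_fused_add_rmsnorm(x, res, w)  # returns None; results land in x and res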
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    get_norm_module().gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)


@register_fake_op("flashinfer::gemma_fused_add_rmsnorm")
def _gemma_fused_add_rmsnorm_fake(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    pass