"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import functools
from types import SimpleNamespace
from typing import Optional

import torch

from .jit import JitSpec
from .jit import gen_act_and_mul_module as gen_act_and_mul_module_impl
from .utils import device_support_pdl, register_custom_op, register_fake_op

silu_def_cu_str = r"""
__device__ __forceinline__ float silu(const float& val) {
  return val / (1.0f + __expf(-val));
}
"""

gelu_def_cu_str = r"""
__device__ __forceinline__ float gelu(const float& val) {
  constexpr float kAlpha = M_SQRT1_2;
  return val * 0.5f * (1.0f + ::erf(val * kAlpha));
}
"""

gelu_def_tanh_cu_str = r"""
__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float cdf =
      0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val))));
  return val * cdf;
}
"""

act_func_def_str = {
    "silu": silu_def_cu_str,
    "gelu": gelu_def_cu_str,
    "gelu_tanh": gelu_def_tanh_cu_str,
}
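
# For reference, the CUDA device functions above compute the same math as the
# PyTorch sketch below (illustrative only; these `_ref` helpers are
# hypothetical and not used by the JIT path). Note that gelu_tanh's constant
# 0.7978845608028654 is sqrt(2/pi), i.e. the standard tanh approximation.
#
#     def _silu_ref(x: torch.Tensor) -> torch.Tensor:
#         return x * torch.sigmoid(x)
#
#     def _gelu_ref(x: torch.Tensor) -> torch.Tensor:
#         return torch.nn.functional.gelu(x)  # erf-based GeLU
#
#     def _gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
#         # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
#         return torch.nn.functional.gelu(x, approximate="tanh")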


def gen_act_and_mul_module(act_func_name: str) -> JitSpec:
    return gen_act_and_mul_module_impl(act_func_name, act_func_def_str[act_func_name])


@functools.cache
def get_act_and_mul_module(act_func_name: str):
    module = gen_act_and_mul_module(act_func_name).build_and_load()

    # Register the JIT-compiled kernel as a torch custom op named
    # flashinfer::<act_func_name>_and_mul.
    fname = f"{act_func_name}_and_mul"
    fn = getattr(module, fname).default

    @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",))
    def _act_and_mul(
        out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None
    ) -> None:
        if enable_pdl is None:
            enable_pdl = device_support_pdl(input.device)
        fn(out, input, enable_pdl)

    @register_fake_op(f"flashinfer::{fname}")
    def _fake_act_and_mul(
        out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None
    ) -> None:
        pass

    # Expose the custom op under its torch-library function name.
    return SimpleNamespace(**{fname: _act_and_mul})
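
# A minimal usage sketch (assumes a CUDA toolchain is available; the first
# call JIT-builds the kernel, subsequent calls hit the functools cache):
#
#     module = get_act_and_mul_module("silu")
#     module.silu_and_mul(out, input)  # writes silu(x1) * x2 into `out`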


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
    # `input` packs [x1 | x2] along the last dimension, so it must be exactly
    # twice as wide as `output`; all leading dimensions must match.
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert input.shape[:-1] == output.shape[:-1], (
        f"{input.shape[:-1]} != {output.shape[:-1]}"
    )
    assert input.shape[-1] == 2 * output.shape[-1], (
        f"{input.shape[-1]} != {2 * output.shape[-1]}"
    )
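
# The shape contract, illustrated with hypothetical sizes:
#
#     x = torch.empty(8, 256)   # (batch, 2 * hidden_size)
#     y = torch.empty(8, 128)   # (batch, hidden_size)
#     _check_shape(x, y)        # passes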


def silu_and_mul(
    input: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Fused SiLU and Mul operation.

    Computes ``silu(input[..., :hidden_size]) * input[..., hidden_size:]``.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (..., 2 * hidden_size).

    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into
        this tensor in place.

    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Output tensor, shape (..., hidden_size).
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError(
            "The byte size of the last input dimension must be a multiple of 16."
        )
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    get_act_and_mul_module("silu").silu_and_mul(out, input, enable_pdl)
    return out
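
# Usage sketch for silu_and_mul (assumes a CUDA device; sizes are arbitrary):
#
#     x = torch.randn(4, 2 * 4096, device="cuda", dtype=torch.float16)
#     y = silu_and_mul(x)  # y.shape == (4, 4096)
#     ref = torch.nn.functional.silu(x[..., :4096]) * x[..., 4096:]
#     torch.testing.assert_close(y, ref, rtol=2e-2, atol=2e-2)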


def gelu_tanh_and_mul(
    input: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Fused GeLU (tanh approximation) and Mul operation.

    Computes ``gelu_tanh(input[..., :hidden_size]) * input[..., hidden_size:]``,
    where ``gelu_tanh`` is the tanh approximation of GeLU.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (..., 2 * hidden_size).

    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into
        this tensor in place.

    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Output tensor, shape (..., hidden_size).
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError(
            "The byte size of the last input dimension must be a multiple of 16."
        )
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    get_act_and_mul_module("gelu_tanh").gelu_tanh_and_mul(out, input, enable_pdl)
    return out
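
# Usage sketch for gelu_tanh_and_mul; PyTorch's `approximate="tanh"` GeLU uses
# the same tanh formula as the CUDA `gelu_tanh` definition above:
#
#     x = torch.randn(2, 2 * 1024, device="cuda", dtype=torch.float16)
#     y = gelu_tanh_and_mul(x)  # y.shape == (2, 1024)
#     ref = torch.nn.functional.gelu(x[..., :1024], approximate="tanh") * x[..., 1024:]
#     torch.testing.assert_close(y, ref, rtol=2e-2, atol=2e-2)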


def gelu_and_mul(
    input: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Fused GeLU and Mul operation.

    Computes ``gelu(input[..., :hidden_size]) * input[..., hidden_size:]``,
    where ``gelu`` is the erf-based (exact) GeLU.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (..., 2 * hidden_size).

    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into
        this tensor in place.

    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If ``None``, it is enabled automatically when the device supports it.

    Returns
    -------
    output: torch.Tensor
        Output tensor, shape (..., hidden_size).
    """
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError(
            "The byte size of the last input dimension must be a multiple of 16."
        )
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    get_act_and_mul_module("gelu").gelu_and_mul(out, input, enable_pdl)
    return out
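
# Usage sketch for gelu_and_mul (erf-based GeLU, matching PyTorch's default):
#
#     x = torch.randn(2, 2 * 1024, device="cuda", dtype=torch.float16)
#     y = gelu_and_mul(x)  # y.shape == (2, 1024)
#     ref = torch.nn.functional.gelu(x[..., :1024]) * x[..., 1024:]
#     torch.testing.assert_close(y, ref, rtol=2e-2, atol=2e-2)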