""" Copyright (c) 2024 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import functools from types import SimpleNamespace from typing import Optional import torch from .jit import JitSpec from .jit import gen_act_and_mul_module as gen_act_and_mul_module_impl from .utils import device_support_pdl, register_custom_op, register_fake_op silu_def_cu_str = r""" __device__ __forceinline__ float silu(const float& val) { return val / (1.0f + __expf(-val)); } """ gelu_def_cu_str = r""" __device__ __forceinline__ float gelu(const float& val) { constexpr float kAlpha = M_SQRT1_2; return val * 0.5f * (1.0f + ::erf(val * kAlpha)); } """ gelu_def_tanh_cu_str = r""" __device__ __forceinline__ float gelu_tanh(const float& val) { const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * (val + 0.044715f * val * val * val)))); return val * cdf; } """ act_func_def_str = { "silu": silu_def_cu_str, "gelu": gelu_def_cu_str, "gelu_tanh": gelu_def_tanh_cu_str, } def gen_act_and_mul_module(act_func_name: str) -> JitSpec: return gen_act_and_mul_module_impl(act_func_name, act_func_def_str[act_func_name]) @functools.cache def get_act_and_mul_module(act_func_name: str): module = gen_act_and_mul_module(act_func_name).build_and_load() # torch library for act_and_mul fname = f"{act_func_name}_and_mul" fn = getattr(module, fname).default @register_custom_op(f"flashinfer::{fname}", mutates_args=("out",)) def _act_and_mul( out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None ) -> None: if enable_pdl is None: enable_pdl = device_support_pdl(input.device) fn(out, input, enable_pdl) @register_fake_op(f"flashinfer::{fname}") def _fake_act_and_mul( out: torch.Tensor, input: torch.Tensor, enable_pdl: Optional[bool] = None ) -> None: pass # Register the module return SimpleNamespace(**{fname: _act_and_mul}) def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" assert input.shape[:-1] == output.shape[:-1], ( f"{input.shape[:-1]} != {output.shape[:-1]}" ) assert input.shape[-1] == 2 * output.shape[-1], ( f"{input.shape[-1]} != {2 * output.shape[-1]}" ) def silu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: r"""Fused SiLU and Mul operation. ``silu(input[..., :hidden_size]) * input[..., hidden_size:]`` Parameters ---------- input: torch.Tensor Input tensor, shape (..., 2 * hidden_size). out: Optional[torch.Tensor] The output tensor, if specified, the kernel will update this tensor inplace. enable_pdl: bool Whether to enable `programmatic dependent launch `_ Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ if enable_pdl is None: enable_pdl = device_support_pdl(input.device) if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: out = torch.empty( input.shape[:-1] + (input.shape[-1] // 2,), device=input.device, dtype=input.dtype, ) get_act_and_mul_module("silu").silu_and_mul( out, input, enable_pdl, ) return out def gelu_tanh_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: r"""Fused GeLU Tanh and Mul operation. ``gelu(tanh(input[..., :hidden_size])) * input[..., hidden_size:]`` Parameters ---------- input: torch.Tensor Input tensor, shape (..., 2 * hidden_size). out: Optional[torch.Tensor] The output tensor, if specified, the kernel will update this tensor inplace. enable_pdl: bool Whether to enable `programmatic dependent launch `_ Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ if enable_pdl is None: enable_pdl = device_support_pdl(input.device) if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: out = torch.empty( input.shape[:-1] + (input.shape[-1] // 2,), device=input.device, dtype=input.dtype, ) get_act_and_mul_module("gelu_tanh").gelu_tanh_and_mul(out, input, enable_pdl) return out def gelu_and_mul( input: torch.Tensor, out: torch.Tensor = None, enable_pdl: Optional[bool] = None ) -> torch.Tensor: r"""Fused GeLU and Mul operation. ``gelu(input[..., :hidden_size]) * input[..., hidden_size:]`` Parameters ---------- input: torch.Tensor Input tensor, shape (..., 2 * hidden_size). out: Optional[torch.Tensor] The output tensor, if specified, the kernel will update this tensor inplace. enable_pdl: bool Whether to enable `programmatic dependent launch `_ Returns ------- output: torch.Tensor Output tensor, shape (..., hidden_size). """ if enable_pdl is None: enable_pdl = device_support_pdl(input.device) if input.shape[-1] * input.dtype.itemsize % 16 != 0: raise ValueError("The pointers must be multiple of 16 bytes.") if out is not None: _check_shape(input, out) else: out = torch.empty( input.shape[:-1] + (input.shape[-1] // 2,), device=input.device, dtype=input.dtype, ) get_act_and_mul_module("gelu").gelu_and_mul(out, input, enable_pdl) return out