/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#ifndef USE_ROCM
#include <flashinfer/activation.cuh>

#include "utils.h"
#else
#include "hip/hip_act_and_mul.cuh"
#endif

// Adapted from flashinfer activation
// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44
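// Note: each *_and_mul launcher in this file fuses a gated activation with an element-wise
// product. The input's last dimension holds [x1, x2] with size 2 * d, and the kernel writes
// out[..., i] = act(x1[..., i]) * x2[..., i], so the output's last dimension has size d
// (see flashinfer::activation::act_and_mul_kernel for the exact indexing).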

namespace detail {

template <typename T>
__device__ __forceinline__ float to_f32(const T& x) {
#if USE_ROCM
  return castToFloat(x);
#else
  return static_cast<float>(x);
#endif
}

template <typename T>
__device__ __forceinline__ T from_f32(float f32) {
#if USE_ROCM
  return castFromFloat<T>(f32);
#else
  return static_cast<T>(f32);
#endif
}

}  // namespace detail

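// silu(x) = x * sigmoid(x) = x / (1 + exp(-x)); computed in fp32, then cast back to T.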
template <typename T>
__device__ __forceinline__ T silu(const T& x) {
  float f32_val = detail::to_f32(x);
  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val)));
}

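// gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))), the exact erf-based form.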
template <typename T>
__device__ __forceinline__ T gelu(const T& x) {
  constexpr float kAlpha = M_SQRT1_2;
  float f32_val = detail::to_f32(x);
  return detail::from_f32<T>(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha))));
}

// gelu_quick(x) = x * torch.sigmoid(1.702 * x)
template <typename T>
__device__ __forceinline__ T gelu_quick_act(const T& x) {
  float f32_val = detail::to_f32(x);
  return detail::from_f32<T>(f32_val / (1.0f + expf(-f32_val * 1.702f)));
}

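// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))), the tanh
// approximation of GELU; kBeta below is sqrt(2 / pi).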
template <typename T>
__device__ __forceinline__ T gelu_tanh(const T& x) {
  constexpr float kAlpha = 0.044715f;
  constexpr float kBeta = 0.7978845608028654f;
  float f32_val = detail::to_f32(x);
  const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val))));
  return detail::from_f32<T>(f32_val * cdf);
}

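// Launch configuration shared by the launchers below: one thread block per token, with each
// thread handling vec_size = 16 / sizeof(c_type) contiguous elements per step (128-bit
// vectorized accesses) and the block size capped at 1024 threads.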
void silu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  // Set the device before querying the stream so the stream belongs to input's device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, silu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}

void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu_tanh>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif
    return true;
  });
}

void gelu_and_mul(at::Tensor& out, at::Tensor& input) {
  int d = input.size(-1) / 2;
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
#if USE_ROCM
    sgl_hip::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#else
    flashinfer::activation::act_and_mul_kernel<c_type, gelu>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);
#endif

    return true;
  });
}

#if USE_ROCM
// gelu_quick applies the activation only (no gating), so d is the full last dimension.
void gelu_quick(at::Tensor& out, const at::Tensor& input) {
  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    uint32_t vec_size = 16 / sizeof(c_type);
    dim3 block(std::min(d / vec_size, 1024U));
    sgl_hip::activation::act_only_kernel<c_type, gelu_quick_act>
        <<<grid, block, 0, stream>>>(static_cast<c_type*>(out.data_ptr()), static_cast<c_type*>(input.data_ptr()), d);

    return true;
  });
}
#endif
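
// Usage sketch (illustrative only; assumes these launchers are exposed through a torch
// extension binding). A SwiGLU-style gated-MLP step could look like:
//   at::Tensor input = at::randn({num_tokens, 2 * d}, at::dtype(at::kHalf).device(at::kCUDA));
//   at::Tensor out = at::empty({num_tokens, d}, input.options());
//   silu_and_mul(out, input);  // out = silu(input[..., :d]) * input[..., d:]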