#include "common.h"
|
|
#include "vec.h"
|
|
|
|
namespace {
|
|
|
|
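// Fused "activation and mul" kernel over rows of shape {num_tokens, 2 * dim}:
// the first dim elements of each row are the activation input x, the next dim
// elements are the gate y, and the kernel writes f(x) * y into output.
// `f` is the scalar activation; `vf` is its at::vec::Vectorized<float> twin.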
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,
    const vec_func_t& vf) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int64_t kVecSize = bVec::size();
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;

      int64_t d;
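      // main vectorized loop: kVecSize scalar_t (bf16/fp16) lanes per step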
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
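        // load x and y, widening each scalar_t vector into two fp32 vectors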
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
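        // apply the activation in fp32, then multiply by the gate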
        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);

        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;
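        // narrow the two fp32 vectors back to one scalar_t vector and store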
        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }
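      // scalar tail: handle the remaining dim % kVecSize elements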
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}

} // anonymous namespace

// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
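// Gated SiLU: SiLU(x) = x * sigmoid(x) applied to the first half of the last
// dimension, multiplied by the second half (the SwiGLU-style gating pattern).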
at::Tensor silu_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
  auto sizes = input.sizes().vec();
  int64_t last_dim = input.ndimension() - 1;
  int64_t d = sizes[last_dim] / 2;
  sizes[last_dim] = d;
  int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
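        // SiLU: x * sigmoid(x) = x / (1 + exp(-x))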
        [](float x) { return x / (1.f + std::exp(-x)); },
        [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
  });
  return out;
}
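// Gated GELU with the tanh approximation; same {num_tokens, 2 * d} input to
// {num_tokens, d} output layout as silu_and_mul_cpu.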
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
  auto sizes = input.sizes().vec();
  int64_t last_dim = input.ndimension() - 1;
  int64_t d = sizes[last_dim] / 2;
  sizes[last_dim] = d;
  int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(sizes, input.options());
  const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
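        // tanh approximation of GELU:
        // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))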
        [sqrt_2_div_pi](float x) {
          float x3 = x * x * x;
          float tanh_arg = sqrt_2_div_pi * (x + 0.044715f * x3);
          return 0.5f * x * (1.f + std::tanh(tanh_arg));
        },
        [sqrt_2_div_pi](Vec x) {
          Vec x3 = x * x * x;
          Vec tanh_arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x3);
          return Vec(0.5f) * x * (Vec(1.f) + tanh_arg.tanh());
        });
  });

  return out;
}
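// Gated GELU using the exact erf formulation; same layout as silu_and_mul_cpu.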
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
  auto sizes = input.sizes().vec();
  int64_t last_dim = input.ndimension() - 1;
  int64_t d = sizes[last_dim] / 2;
  sizes[last_dim] = d;
  int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
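        // exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))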
        [inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
        [inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
  });

  return out;
}