#include "common.h"
#include "vec.h"

#include <cmath>
#include <cstdint>
#include <tuple>
#include <vector>

namespace {

/// Gated activation kernel: for every row i and column j in [0, dim),
///
///     output[i * dim + j] = act(input[i * 2 * dim + j]) * input[i * 2 * dim + dim + j]
///
/// The arithmetic is done in float; results are converted back to scalar_t
/// on store.
///
/// @tparam scalar_t    element type of the input/output buffers
/// @tparam func_t      callable float -> float, used for the scalar tail
/// @tparam vec_func_t  callable Vectorized<float> -> Vectorized<float>, used
///                     for the vectorized main loop
/// @param output       dense buffer of num_tokens * dim elements
/// @param input        dense buffer of num_tokens * 2 * dim elements
/// @param num_tokens   number of rows
/// @param dim          number of output columns per row (half the input row width)
/// @param f            scalar activation
/// @param vf           vector activation; expected to compute the same function as f
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,
    const vec_func_t& vf) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int64_t kVecSize = bVec::size();

  // Rows are independent, so the row range is split across threads.
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // Per-row pointers: the first `dim` elements of an input row feed the
      // activation, the second `dim` elements are the multiplicative gate.
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;

      int64_t d;
      // Main loop: kVecSize elements per iteration. The bound is evaluated in
      // signed arithmetic, so the loop is skipped entirely when dim < kVecSize.
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
        // One scalar_t vector widens into two float vectors.
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);

        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);

        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;

        // NOTE(review): the explicit <scalar_t> argument is reconstructed; it
        // cannot be deduced from the two float-vector arguments. Confirm
        // against the declaration in vec.h.
        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }

      // Scalar tail for the remaining dim % kVecSize elements.
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}

}  // anonymous namespace

/// Computes silu(x[..., :d]) * x[..., d:] along the last dimension, where
/// silu(v) = v / (1 + exp(-v)).
///
/// input  : {num_tokens, 2 * d}  (any leading dimensions are flattened into rows)
/// output : {num_tokens, d}      (same leading dimensions, last dimension halved)
///
/// @param input  contiguous tensor with at least one dimension, an even-sized
///               last dimension, and a dtype accepted by
///               AT_DISPATCH_REDUCED_FLOATING_TYPES
/// @return       newly allocated tensor with the same options as `input`
/// @throws       c10::Error (via TORCH_CHECK) when the layout requirements
///               above are not met; the dispatch macro rejects unsupported dtypes
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));

  // The kernel walks raw pointers assuming dense row-major rows of width
  // 2 * d, so reject layouts that would be read incorrectly.
  TORCH_CHECK(input.ndimension() >= 1, "silu_and_mul_cpu: input must have at least one dimension");
  TORCH_CHECK(input.is_contiguous(), "silu_and_mul_cpu: input must be contiguous");

  const int64_t row_width = input.size(-1);
  TORCH_CHECK(
      row_width % 2 == 0, "silu_and_mul_cpu: last dimension of input must be even, got ", row_width);

  auto sizes = input.sizes().vec();
  const int64_t d = row_width / 2;
  sizes.back() = d;

  // Guard the division: a zero-width last dimension means there is no work.
  const int64_t num_tokens = (row_width == 0) ? 0 : input.numel() / row_width;
  at::Tensor out = at::empty(sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        num_tokens,
        d,
        [](float x) { return x / (1.f + std::exp(-x)); },
        [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
  });
  return out;
}