// Source path: sglang_v0.5.2/sglang/sgl-kernel/include/hip/hip_act_and_mul.cuh

/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#pragma once
#include "utils.h"
// Width, in bits, of the chunk each thread handles per vectorized step in the
// kernels below.
#define kBitsToLoad 128
// The same width expressed in bytes (128 / 8 = 16). The kernels divide this by
// sizeof(T) to obtain the number of elements held by one vec_t.
#define kBytesToLoad (kBitsToLoad / 8)
// Adapted from
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
namespace sgl_hip {
namespace activation {
// Gated activation: for the row selected by blockIdx.x, computes
//   out[row * d + i] = Activation(input[row * 2 * d + i]) * input[row * 2 * d + d + i]
// for every i in [0, d).
//
// Template parameters:
//   T          - element type of `input` and `out`.
//   Activation - function applied to each element of the first half of the row.
//
// Parameters:
//   out   - output buffer; row `r` occupies elements [r * d, (r + 1) * d).
//   input - input buffer; row `r` occupies elements [r * 2 * d, (r + 1) * 2 * d),
//           the first d elements are activated, the last d are the multipliers.
//   d     - number of output elements per row.
//
// Launch layout: only blockIdx.x / threadIdx.x / blockDim.x are read and there is
// no loop over blocks, so each row needs its own block (gridDim.x == number of rows).
//
// NOTE(review): vec_t::cast_load / cast_store are declared outside this file; they
// presumably perform kBytesToLoad-wide accesses and therefore need suitably aligned
// addresses (which would constrain `d` and the base pointers) - confirm with callers.
template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * 2 * d;

  // Vectorized part: the threads of the block walk over all d / vec_size full
  // vectors of the row, blockDim.x vectors per step.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
    y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail: the loop above already covered every full vector, so only the
  // last d % vec_size elements are left. Starting exactly there avoids
  // recomputing (and re-storing from a different thread) elements that the
  // vectorized loop has already written.
  const int64_t vec_width = vec_size;
  const int64_t remaining_offset = d - d % vec_width;
#pragma unroll 1
  for (int64_t idx = remaining_offset + thread_idx; idx < d; idx += stride) {
    const T x = input[offset + idx];
    const T y = input[offset + d + idx];
    out[token_idx * d + idx] = Activation(x) * y;
  }
}
// Element-wise activation: for the row selected by blockIdx.x, computes
//   out[row * d + i] = Activation(input[row * d + i])
// for every i in [0, d).
//
// Template parameters:
//   T          - element type of `input` and `out`.
//   Activation - function applied to each element.
//
// Parameters:
//   out   - output buffer; row `r` occupies elements [r * d, (r + 1) * d).
//   input - input buffer with the same row layout as `out`.
//   d     - number of elements per row.
//
// Launch layout: only blockIdx.x / threadIdx.x / blockDim.x are read and there is
// no loop over blocks, so each row needs its own block (gridDim.x == number of rows).
//
// NOTE(review): vec_t::cast_load / cast_store are declared outside this file; they
// presumably perform kBytesToLoad-wide accesses and therefore need suitably aligned
// addresses (which would constrain `d` and the base pointers) - confirm with callers.
template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  const int64_t offset = token_idx * d;

  // Vectorized part: the threads of the block walk over all d / vec_size full
  // vectors of the row, blockDim.x vectors per step.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]);
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail: the loop above already covered every full vector, so only the
  // last d % vec_size elements are left. Starting exactly there avoids
  // recomputing (and re-storing from a different thread) elements that the
  // vectorized loop has already written.
  const int64_t vec_width = vec_size;
  const int64_t remaining_offset = d - d % vec_width;
#pragma unroll 1
  for (int64_t idx = remaining_offset + thread_idx; idx < d; idx += stride) {
    const T x = input[offset + idx];
    out[token_idx * d + idx] = Activation(x);
  }
}
} // namespace activation
} // namespace sgl_hip