/* Copyright 2025 SGLang Team. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#pragma once
|
|
|
|
#include "utils.h"
|
|
|
|
#define kBitsToLoad 128
|
|
#define kBytesToLoad (kBitsToLoad / 8)
|
|
|
|
// Adapted from
|
|
// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29)
|
|
|
|
namespace sgl_hip {
|
|
namespace activation {
|
|
|
|
// Gated-activation kernel: out[token, i] = Activation(input[token, i]) * input[token, d + i].
//
// Expected launch layout: grid.x = num_tokens (one block per token), blockDim.x
// threads cooperating over the hidden dimension `d`. `input` is laid out as
// [num_tokens, 2 * d] (gate half followed by multiplier half); `out` is
// [num_tokens, d]. No shared memory or alignment precondition beyond what
// vec_t::cast_load tolerates — TODO confirm vec_t's alignment requirements.
template <typename T, T (*Activation)(const T&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  // Number of T elements per 128-bit vectorized load.
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  // Start of this token's row in the [num_tokens, 2 * d] input.
  const int64_t offset = token_idx * 2 * d;

  // Vectorized main loop over full vec_size-element chunks.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
    y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail. The vectorized loop above covers exactly
  // (d / vec_size) * vec_size elements, so the tail starts right after them.
  // (The previous bound, `d % (stride * vec_size)`, re-processed up to
  // stride*vec_size - vec_size elements already handled above whenever
  // d / vec_size was not a multiple of blockDim.x — redundant work and
  // same-value duplicate stores. Output is unchanged by this fix.)
  const int64_t tail_start = static_cast<int64_t>(d / vec_size) * vec_size;
#pragma unroll 1
  for (int64_t idx = tail_start + thread_idx; idx < d; idx += stride) {
    T x = input[offset + idx], y = input[offset + d + idx];
    out[token_idx * d + idx] = Activation(x) * y;
  }
}
|
|
|
|
// Element-wise activation kernel: out[token, i] = Activation(input[token, i]).
//
// Expected launch layout: grid.x = num_tokens (one block per token), blockDim.x
// threads cooperating over the hidden dimension `d`. Both `input` and `out`
// are laid out as [num_tokens, d]. No shared memory or alignment precondition
// beyond what vec_t::cast_load tolerates — TODO confirm vec_t's requirements.
template <typename T, T (*Activation)(const T&)>
__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  // Number of T elements per 128-bit vectorized load.
  constexpr uint32_t vec_size = kBytesToLoad / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  // Start of this token's row in the [num_tokens, d] input.
  const int64_t offset = token_idx * d;

  // Vectorized main loop over full vec_size-element chunks.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    sgl_hip::vec_t<T, vec_size> x_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]);
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail. The vectorized loop above covers exactly
  // (d / vec_size) * vec_size elements, so the tail starts right after them.
  // (The previous bound, `d % (stride * vec_size)`, re-processed up to
  // stride*vec_size - vec_size elements already handled above whenever
  // d / vec_size was not a multiple of blockDim.x — redundant work and
  // same-value duplicate stores. Output is unchanged by this fix.)
  const int64_t tail_start = static_cast<int64_t>(d / vec_size) * vec_size;
#pragma unroll 1
  for (int64_t idx = tail_start + thread_idx; idx < d; idx += stride) {
    T x = input[offset + idx];
    out[token_idx * d + idx] = Activation(x);
  }
}
|
|
|
|
} // namespace activation
|
|
} // namespace sgl_hip
|