70 lines
2.3 KiB
Plaintext
70 lines
2.3 KiB
Plaintext
/*
|
|
* Copyright (c) 2024 by FlashInfer team.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef FLASHINFER_ACTIVATION_CUH_
|
|
#define FLASHINFER_ACTIVATION_CUH_
|
|
|
|
#include "math.cuh"
|
|
#include "utils.cuh"
|
|
#include "vec_dtypes.cuh"
|
|
|
|
namespace flashinfer {

namespace activation {

/*!
 * \brief Fused "activation and multiply" kernel: out[t, i] = Activation(x[t, i]) * y[t, i].
 *
 * Each row of \p input holds 2 * d elements: the first d are fed through Activation (x) and
 * the following d are the multipliers (y). Each row of \p out holds d elements.
 *
 * Launch layout: one thread block per row (blockIdx.x selects the row); the threads of the
 * block stride over the row with step blockDim.x. Only the x-dimension of the launch is used.
 *
 * Elements are processed in 16-byte vectors (vec_size = 16 / sizeof(T)) while computing in
 * float; any trailing d % vec_size elements are handled one at a time.
 * NOTE(review): the vectorized accesses presumably require the addressed locations to be
 * 16-byte aligned (which would imply d % vec_size == 0 for rows after the first) -- confirm
 * against vec_t::cast_load / cast_store and the host-side launcher.
 *
 * \tparam T Element type of the input and output buffers.
 * \tparam Activation Element-wise activation, evaluated in float.
 * \param out Output buffer, d elements per row.
 * \param input Input buffer, 2 * d elements per row.
 * \param d Number of output elements per row.
 */
template <typename T, float (*Activation)(const float&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
  constexpr uint32_t vec_size = 16 / sizeof(T);
  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;
  // 64-bit row offsets so that token_idx * 2 * d cannot overflow 32-bit arithmetic.
  const int64_t offset = token_idx * 2 * d;

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Programmatic dependent launch (sm_90+): wait for the prerequisite grid before reading
  // its results.
  asm volatile("griddepcontrol.wait;");
#endif

  // Vectorized main loop: covers elements [0, (d / vec_size) * vec_size) of the row.
#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) {
    vec_t<float, vec_size> x_vec, y_vec, out_vec;
    x_vec.cast_load(input + offset + idx * vec_size);
    y_vec.cast_load(input + offset + d + idx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = Activation(x_vec[i]) * y_vec[i];
    }
    out_vec.cast_store(out + token_idx * d + idx * vec_size);
  }

  // Scalar tail: the loop above already handled every full vector, so only the last
  // d % vec_size elements remain. Restricting the tail to exactly those elements avoids
  // recomputing (and concurrently re-writing) outputs produced by the vectorized loop.
  const int64_t num_tail = d % vec_size;
  const int64_t tail_begin = d - num_tail;
#pragma unroll 1
  for (int64_t idx = thread_idx; idx < num_tail; idx += stride) {
    float x = input[offset + tail_begin + idx],
          y = input[offset + tail_begin + d + idx];
    out[token_idx * d + tail_begin + idx] = Activation(x) * y;
  }

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Programmatic dependent launch (sm_90+): signal that dependent grids may be scheduled.
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

}  // namespace activation
}  // namespace flashinfer
|
|
|
|
#endif // FLASHINFER_ACTIVATION_CUH_
|