/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_NORM_CUH_
#define FLASHINFER_NORM_CUH_

#include <numeric>

#include "flashinfer/utils.cuh"
#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

namespace flashinfer {

namespace norm {

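/*!
 * \brief RMSNorm kernel: one thread block per row of the input.
 *
 * Pass 1 accumulates the per-row sum of squares with a warp shuffle reduction
 * followed by a cross-warp reduction through shared memory; pass 2 reloads the
 * row and scales each element by rsqrt(sum_sq / d + eps) * (weight_bias + weight).
 * weight_bias is 0 for standard RMSNorm and 1 for the Gemma variant below,
 * which parameterizes the scale as (1 + w).
 */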
template <uint32_t VEC_SIZE, typename T>
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
                              const uint32_t d, const uint32_t stride_input,
                              const uint32_t stride_output, float weight_bias, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  // NOTE(Zihao): num_warps is guaranteed to be at most 32
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];

  float sum_sq = 0.f;

#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      sum_sq += float(input_vec[j]) * float(input_vec[j]);
    }
  }

  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> output_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE +
                       thread_id * VEC_SIZE);
    }
  }
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

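/*!
 * \brief Launch RMSNormKernel on batch_size rows of length d.
 *
 * The vector width is the widest aligned width that divides d (the gcd of
 * 16 bytes worth of T and d), and dynamic shared memory holds one float per
 * warp for the reduction. When enable_pdl is true, programmatic dependent
 * launch is enabled via cudaLaunchAttributeProgrammaticStreamSerialization,
 * matching the griddepcontrol instructions inside the kernel.
 */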
template <typename T>
cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                    uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
                    bool enable_pdl = false, cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  float weight_bias = 0.f;

  cudaLaunchConfig_t config;
  config.gridDim = nblks;
  config.blockDim = nthrs;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = RMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
                                            stride_output, weight_bias, eps));
  });
  return cudaSuccess;
}

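/*!
 * \brief Fused residual-add + RMSNorm kernel.
 *
 * Pass 1 computes x = input + residual, stores x back into residual (in T
 * precision), caches x in shared memory in fp32, and accumulates the sum of
 * squares; pass 2 normalizes the cached fp32 copy and writes the result into
 * input in place, avoiding a second global-memory read of the row.
 */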
template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
                                      T* __restrict__ weight, const uint32_t d,
                                      const uint32_t stride_input, const uint32_t stride_residual,
                                      float weight_bias, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];
  float* smem_x = smem + ceil_div(num_warps, 4) * 4;

  float sum_sq = 0.f;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.wait;");
#endif

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    vec_t<T, VEC_SIZE> residual_vec;
    residual_vec.fill(0.f);
    vec_t<float, VEC_SIZE> x_vec;
    x_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
                        thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      float x = float(input_vec[j]);
      x += float(residual_vec[j]);
      sum_sq += x * x;
      residual_vec[j] = (T)x;
      x_vec[j] = x;
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      residual_vec.store(residual + bx * stride_residual + i * num_threads * VEC_SIZE +
                         thread_id * VEC_SIZE);
      x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }

  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<float, VEC_SIZE> x_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    x_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      x_vec.load(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j]));
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.store(input + bx * stride_input + i * num_threads * VEC_SIZE +
                      thread_id * VEC_SIZE);
    }
  }
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  asm volatile("griddepcontrol.launch_dependents;");
#endif
}

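/*!
 * \brief Launch FusedAddRMSNormKernel. Dynamic shared memory holds the warp
 * partial sums, padded to a 16-byte boundary, followed by the d-element fp32
 * copy of the row, so d is bounded by the shared memory available per block.
 */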
template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                            uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
                            bool enable_pdl = false, cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
  float weight_bias = 0.f;

  cudaLaunchConfig_t config;
  config.gridDim = nblks;
  config.blockDim = nthrs;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, residual, weight, d,
                                            stride_input, stride_residual, weight_bias, eps));
  });

  return cudaSuccess;
}

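/*!
 * \brief Gemma-style RMSNorm: identical to RMSNorm except weight_bias is 1,
 * so each element is scaled by (1 + weight), matching Gemma's parameterization.
 */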
template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                         uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
                         bool enable_pdl = false, cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  float weight_bias = 1.f;

  cudaLaunchConfig_t config;
  config.gridDim = nblks;
  config.blockDim = nthrs;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = RMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, weight, output, d, stride_input,
                                            stride_output, weight_bias, eps));
  });
  return cudaSuccess;
}

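/*!
 * \brief Gemma-style fused residual-add + RMSNorm: identical to
 * FusedAddRMSNorm except weight_bias is 1.
 */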
template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                 uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
                                 bool enable_pdl = false, cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 bytes
  const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float);
  float weight_bias = 1.f;

  cudaLaunchConfig_t config;
  config.gridDim = nblks;
  config.blockDim = nthrs;
  config.dynamicSmemBytes = smem_size;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, input, residual, weight, d,
                                            stride_input, stride_residual, weight_bias, eps));
  });

  return cudaSuccess;
}

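// Usage sketch (illustrative only; the buffer names and sizes below are
// assumptions, not part of this header): normalize a row-major
// [batch_size, d] half-precision tensor with contiguous rows, so both
// strides equal d.
//
//   half *x, *w, *y;  // device pointers: input [batch, d], weight [d], output
//   uint32_t batch_size = 8, d = 4096;
//   cudaError_t err = flashinfer::norm::RMSNorm<half>(
//       x, w, y, batch_size, d, /*stride_input=*/d, /*stride_output=*/d,
//       /*eps=*/1e-5f, /*enable_pdl=*/false, stream);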
} // namespace norm

} // namespace flashinfer

#endif // FLASHINFER_NORM_CUH_