/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_POS_ENC_CUH_
#define FLASHINFER_POS_ENC_CUH_

#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>

#include "layout.cuh"
#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

namespace flashinfer {

/*!
 * \brief An enumeration class that defines different modes for applying RoPE
 *   (Rotary Positional Embeddings).
 */
enum class PosEncodingMode {
  // No rotary positional embeddings
  kNone = 0U,
  // Apply Llama-style rope.
  kRoPELlama = 1U,
  // Apply ALiBi bias
  kALiBi = 2U
};

/*!
 * \brief Convert PosEncodingMode to string
 * \param pos_encoding_mode A PosEncodingMode value
 */
inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_mode) {
  switch (pos_encoding_mode) {
    case PosEncodingMode::kNone:
      return "None";
    case PosEncodingMode::kRoPELlama:
      return "Llama";
    case PosEncodingMode::kALiBi:
      return "ALiBi";
    default:
      return "Unknown";
  }
}

__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) {
  int n = math::ptx_exp2((int)math::ptx_log2(num_heads));
  return head_idx < n ? math::ptx_exp2(-8. * float(head_idx + 1) / float(n))
                      : math::ptx_exp2(-4. * float((head_idx + 1 - n) * 2 - 1) / float(n));
}
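
// Reader note (illustrative, not part of the original file): with n the largest power of two
// <= num_heads, head i < n gets slope 2^(-8 * (i + 1) / n); heads beyond n take every other
// slope of the corresponding 2n-head sequence, starting from the first. For example,
// num_heads = 8 gives n = 8 and slopes 2^-1, 2^-2, ..., 2^-8 for heads 0..7.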

/*!
 * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim],
 *   return thread-local vector
 * \tparam vec_size A template integer indicating the vector size used
 *   in the kernel
 * \tparam bdx A template integer indicating the blockDim.x
 * \tparam T A template type indicating the x data type
 * \param x A pointer to the start of x data
 * \param freq A vector of float indicating the thread-local rope frequencies
 * \param offset An integer indicating the offset of the position in RoPE
 */
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope(
    const T* x, const vec_t<float, vec_size>& freq, int32_t offset,
    const uint32_t rotary_dim = vec_size * bdx) {
  vec_t<float, vec_size> permuted_vec, vec;
  vec.cast_load(x + threadIdx.x * vec_size);

  if (threadIdx.x * vec_size < rotary_dim) {
    permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2)
                                    ? threadIdx.x * vec_size + rotary_dim / 2
                                    : threadIdx.x * vec_size - rotary_dim / 2));
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float embed = float(offset) * freq[i];
      float cos, sin;
      __sincosf(embed, &sin, &cos);
      vec[i] =
          vec[i] * cos +
          ((threadIdx.x * vec_size < rotary_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin;
    }
  }
  return vec;
}
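
// Reader note (illustrative, not part of the original file): written per head with
// d = rotary_dim and theta_j = offset * freq_j, the non-interleaved ("rotate-half") rule
// computed above is
//   out[j]         = x[j]         * cos(theta_j) - x[j + d / 2] * sin(theta_j)   for j < d / 2
//   out[j + d / 2] = x[j + d / 2] * cos(theta_j) + x[j]         * sin(theta_j)
// where the callers in this file pass freq_j = rope_theta^(-2 * j / d). Elements with index
// >= rotary_dim are copied through unchanged.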

template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin(
    const T* x, const vec_t<float, vec_size>& cos, const vec_t<float, vec_size>& sin,
    const uint32_t rotary_dim = vec_size * bdx) {
  vec_t<float, vec_size> permuted_vec, vec;
  vec.cast_load(x + threadIdx.x * vec_size);

  if (threadIdx.x * vec_size < rotary_dim) {
    permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2)
                                    ? threadIdx.x * vec_size + rotary_dim / 2
                                    : threadIdx.x * vec_size - rotary_dim / 2));
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      vec[i] =
          vec[i] * cos[i] +
          ((threadIdx.x * vec_size < rotary_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i];
    }
  }
  return vec;
}

/*!
 * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleave,
 *   return thread-local vector.
 * \tparam vec_size A template integer indicating the vector size used
 *   in the kernel
 * \tparam bdx A template integer indicating the blockDim.x
 * \tparam T A template type indicating the x data type
 * \param x A pointer to the start of x data
 * \param freq A vector of float indicating the thread-local rope frequencies
 * \param offset An integer indicating the offset of the position in RoPE
 */
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_interleave(
    const T* x, const vec_t<float, vec_size>& freq, int32_t offset,
    const uint32_t rotary_dim = vec_size * bdx) {
  vec_t<float, vec_size> vec, vec_before;
  vec.cast_load(x + threadIdx.x * vec_size);

  if (threadIdx.x * vec_size < rotary_dim) {
    vec_before = vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float embed = float(offset) * freq[i];
      float cos, sin;
      __sincosf(embed, &sin, &cos);
      vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
    }
  }
  return vec;
}
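
// Reader note (illustrative, not part of the original file): the interleaved rule rotates
// adjacent pairs. With theta_j = offset * freq_{2j} (the callers give both lanes of a pair the
// same frequency), the loop above computes
//   out[2j]     = x[2j]     * cos(theta_j) - x[2j + 1] * sin(theta_j)
//   out[2j + 1] = x[2j + 1] * cos(theta_j) + x[2j]     * sin(theta_j)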

template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_cos_sin_interleave(
    const T* x, const vec_t<float, vec_size>& cos, const vec_t<float, vec_size>& sin,
    const uint32_t rotary_dim = vec_size * bdx) {
  vec_t<float, vec_size> vec, vec_before;
  vec.cast_load(x + threadIdx.x * vec_size);

  if (threadIdx.x * vec_size < rotary_dim) {
    vec_before = vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i];
    }
  }
  return vec;
}

/*
  HACK (ByronHsu): in the interleave mode with cos_sin_cache, we actually only use the first half
  of cos and sin.

  For example, with vec_size = 4, the computation in the kernel is:
    [x1, x2, x3, x4, ...] * [cos1, cos1, cos2, cos2] + [-x2, x1, -x4, x3, ...] * [sin1, sin1, sin2, sin2]
  while the data we load are:
    - loaded vec = [x1, x2, x3, x4]
    - loaded cos = [cos1, cos2, cos3, cos4]
    - loaded sin = [sin1, sin2, sin3, sin4]
  so only the first half of cos and sin is used in the computation.

  However, we argue the additional overhead is acceptable:
  1. Loading the extra elements of cos and sin does not add much overhead. The arithmetic
     intensity is the same as in the non-interleave mode, and each element of cos and sin is
     loaded twice.
  2. We don't want two code paths for the cos and sin vectors in interleave and non-interleave
     modes.
*/
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size>
vec_apply_llama_rope_cos_sin_interleave_reuse_half(const T* x, const vec_t<float, vec_size>& cos,
                                                   const vec_t<float, vec_size>& sin,
                                                   const uint32_t rotary_dim = vec_size * bdx) {
  vec_t<float, vec_size> vec, vec_before;
  vec.cast_load(x + threadIdx.x * vec_size);

  if (threadIdx.x * vec_size < rotary_dim) {
    vec_before = vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      // i / 2 is to get the index of the first half of cos and sin
      vec[i] = vec[i] * cos[i / 2] +
               ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i / 2];
    }
  }
  return vec;
}
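
// Reader note (illustrative, not part of the original file): for vec_size = 4, element i pairs
// with element i ^ 1 and uses cos[i / 2] / sin[i / 2], i.e. cos indices [0, 0, 1, 1]; only the
// first vec_size / 2 loaded cos/sin values are consumed, matching the HACK comment above.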

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
    uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n,
    size_t k_rope_stride_h) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  const uint32_t bdy = blockDim.y;

  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
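    //
    // The loads below assume a row-major cos_sin_cache of shape [max_position, rotary_dim]
    // (naming is illustrative, not part of the original comment), where row p stores
    //   [cos(p * f_0), ..., cos(p * f_{rot/2 - 1}), sin(p * f_0), ..., sin(p * f_{rot/2 - 1})].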
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    if (by < num_qo_heads) {
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin,
                                                                                  rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin,
                                                                                  rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads,
    uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n,
    size_t k_rope_stride_h) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;

  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    // Do not unroll this loop: the number of heads may be large, and unrolling could hurt
    // performance.
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin,
                                                                                  rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin,
                                                                                  rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
          typename QuantType>
__global__ void MLARopeQuantizeKernel(
    DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out,
    QuantType* k_rope_out, QuantType* q_nope_out, QuantType* k_nope_out,
    float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_heads, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h,
    size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
    size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
    size_t k_rope_in_stride, size_t k_nope_in_stride, size_t k_rope_out_stride,
    size_t k_nope_out_stride, float quant_scale_q, float quant_scale_kv) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  uint32_t bdy = blockDim.y;
  constexpr uint32_t rotary_dim = 64;

  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    const int half_rotary_dim = rotary_dim / 2;
    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    if (by < num_heads) {
      // Query RoPE, 64 dim
      // allocate (num_heads,) blocks on blockDim.y
      uint32_t q_head_idx = by;
      DType* q_rope_in_ptr =
          q_rope_in + get_elem_offset_impl(idx, q_head_idx, /*elem_idx=*/0, q_rope_in_stride_n,
                                           q_rope_in_stride_h);
      QuantType* q_rope_out_ptr =
          q_rope_out + get_elem_offset_impl(idx, q_head_idx, /*elem_idx=*/0, q_rope_out_stride_n,
                                            q_rope_out_stride_h);
      vec_t<float, vec_size> q_rope_vec;
      if constexpr (interleave) {
        q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            q_rope_in_ptr, cos, sin, rotary_dim);
      } else {
        q_rope_vec =
            vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rotary_dim);
      }
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
      }
      q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
    } else if (by == num_heads) {
      // k/v RoPE, 64 dim
      // allocate (1,) blocks on blockDim.y
      DType* k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, /*head_idx=*/0, /*elem_idx=*/0,
                                                              k_rope_in_stride, k_rope_in_stride);
      QuantType* k_rope_out_ptr =
          k_rope_out + get_elem_offset_impl(idx, /*head_idx=*/0, /*elem_idx=*/0, k_rope_out_stride,
                                            k_rope_out_stride);
      vec_t<float, vec_size> k_rope_vec;
      if constexpr (interleave) {
        k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            k_rope_in_ptr, cos, sin, rotary_dim);
      } else {
        k_rope_vec =
            vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rotary_dim);
      }
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
      }
      k_rope_vec.cast_store(k_rope_out_ptr + tx * vec_size);
    } else if (by <= num_heads + 8) {
      // K/v Non-RoPE part, 512 dim
      // allocate (8,) blocks on blockDim.y
      uint32_t chunk_idx = (by - num_heads - 1);
      DType* k_nope_in_ptr =
          k_nope_in + get_elem_offset_impl(idx, /*head_idx=*/0, /*elem_idx=*/64 * chunk_idx,
                                           k_nope_in_stride, k_nope_in_stride);
      QuantType* k_nope_out_ptr =
          k_nope_out + get_elem_offset_impl(idx, /*head_idx=*/0, /*elem_idx=*/64 * chunk_idx,
                                            k_nope_out_stride, k_nope_out_stride);
      vec_t<float, vec_size> k_nope_vec;
      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
      }
      k_nope_vec.cast_store(k_nope_out_ptr + tx * vec_size);
    } else {
      // Query Non-RoPE part, 512 dim
      // allocate (num_heads * 8,) blocks on blockDim.y
      uint32_t q_head_idx = (by - num_heads - 8 - 1) / 8;
      uint32_t chunk_idx = (by - num_heads - 8 - 1) % 8;
      DType* q_nope_in_ptr =
          q_nope_in + get_elem_offset_impl(idx, q_head_idx, /*elem_idx=*/64 * chunk_idx,
                                           q_nope_in_stride_n, q_nope_in_stride_h);
      QuantType* q_nope_out_ptr =
          q_nope_out + get_elem_offset_impl(idx, q_head_idx, /*elem_idx=*/64 * chunk_idx,
                                            q_nope_out_stride_n, q_nope_out_stride_h);
      vec_t<float, vec_size> q_nope_vec;
      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
      }
      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsHeadParallelismKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
    float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
  // NOTE: q and q_rope may be the same ptr, and so may k and k_rope
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> freq;
  if (tx * vec_size < rotary_dim) {
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      if constexpr (interleave) {
        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
      } else {
        freq[i] = __powf(rope_rcp_theta,
                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
      }

      float smooth = freq[i] * smooth_a + smooth_b;
      smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
    }
  }

  vec_t<float, vec_size> cos, sin;

  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    if (tx * vec_size < rotary_dim) {
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float embed = float(pos) * freq[i];
        __sincosf(embed, &sin[i], &cos[i]);
      }
    }

    if (by < num_qo_heads) {
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryPosIdsKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n,
    size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n,
    size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a,
    float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
  // NOTE: q and q_rope may be the same ptr, and so may k and k_rope
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> freq;
  if (tx * vec_size < rotary_dim) {
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      if constexpr (interleave) {
        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
      } else {
        freq[i] = __powf(rope_rcp_theta,
                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
      }

      float smooth = freq[i] * smooth_a + smooth_b;
      smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
    }
  }

  vec_t<float, vec_size> cos, sin;

  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];

    if (tx * vec_size < rotary_dim) {
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        float embed = float(pos) * freq[i];
        __sincosf(embed, &sin[i], &cos[i]);
      }
    }

#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);
    }
  }
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
          typename IdType>
__global__ void BatchQKApplyRotaryKernel(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,
    IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
    uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    float smooth_a, float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> freq;
  if (tx * vec_size < rotary_dim) {
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      if constexpr (interleave) {
        freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim));
      } else {
        freq[i] = __powf(rope_rcp_theta,
                         float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim));
      }

      float smooth = freq[i] * smooth_a + smooth_b;
      smooth = max(0.0f, min(1.0f, smooth));  // clamp to [0, 1]
      freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
    }
  }

  if (bx < batch_size * num_qo_heads) {
    // apply rotary to q
    const uint32_t batch_idx = bx / num_qo_heads;
    const uint32_t qo_head_idx = bx % num_qo_heads;
    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
    const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
      vec_t<float, vec_size> q_vec;
      if (i * bdy + ty < seq_len) {
        DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
                                                q_stride_n, q_stride_h);
        DType* q_rope_ptr =
            q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
                                          q_rope_stride_n, q_rope_stride_h);
        if constexpr (interleave) {
          q_vec = vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty,
                                                                 rotary_dim);
        } else {
          q_vec =
              vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty, rotary_dim);
        }
        q_vec.cast_store(q_rope_ptr + tx * vec_size);
      }
    }
  } else {
    // apply rotary to k
    uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads;
    uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads;
    const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx];
    const uint32_t offset = offsets[batch_idx];
#pragma unroll 2
    for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) {
      vec_t<float, vec_size> k_vec;
      if (i * bdy + ty < seq_len) {
        DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
                                                k_stride_n, k_stride_h);
        DType* k_rope_ptr =
            k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
                                          k_rope_stride_n, k_rope_stride_h);
        if constexpr (interleave) {
          k_vec = vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty,
                                                                 rotary_dim);
        } else {
          k_vec =
              vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty, rotary_dim);
        }
        k_vec.cast_store(k_rope_ptr + tx * vec_size);
      }
    }
  }
}

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
  if (interleave) {                                      \
    const bool INTERLEAVE = true;                        \
    __VA_ARGS__                                          \
  } else {                                               \
    const bool INTERLEAVE = false;                       \
    __VA_ARGS__                                          \
  }
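
// Note (illustrative, not part of the original file): each branch defines INTERLEAVE as a
// `const bool` with a constant initializer, so it can be used as a non-type template argument
// inside the macro body, e.g.
//
//   DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
//     auto kernel = BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType,
//                                                  IdType>;
//     // ... launch kernel ...
//   });
//
// which is exactly how the dispatch functions below use it.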

template <typename DType, typename IdType, typename QuantType>
cudaError_t MLARopeQuantize(DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in,
                            QuantType* q_rope_out, QuantType* k_rope_out, QuantType* q_nope_out,
                            QuantType* k_nope_out, float* cos_sin_cache, IdType* pos_ids,
                            uint32_t nnz, uint32_t num_heads, size_t q_rope_in_stride_n,
                            size_t q_rope_in_stride_h, size_t q_nope_in_stride_n,
                            size_t q_nope_in_stride_h, size_t q_rope_out_stride_n,
                            size_t q_rope_out_stride_h, size_t q_nope_out_stride_n,
                            size_t q_nope_out_stride_h, size_t k_rope_in_stride,
                            size_t k_nope_in_stride, size_t k_rope_out_stride,
                            size_t k_nope_out_stride, float quant_scale_q, float quant_scale_kv,
                            bool interleave, cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    constexpr uint32_t rotary_dim = 64;
    constexpr uint32_t vec_size = 16 / sizeof(DType);
    constexpr uint32_t bdx = rotary_dim / vec_size;
    uint32_t num_threads = 128U;
    uint32_t bdy = num_threads / bdx;
    uint32_t nblks_x = (nnz + bdy - 1) / bdy;

    void* args[] = {(void*)&q_rope_in,
                    (void*)&k_rope_in,
                    (void*)&q_nope_in,
                    (void*)&k_nope_in,
                    (void*)&q_rope_out,
                    (void*)&k_rope_out,
                    (void*)&q_nope_out,
                    (void*)&k_nope_out,
                    (void*)&cos_sin_cache,
                    (void*)&pos_ids,
                    (void*)&nnz,
                    (void*)&num_heads,
                    (void*)&q_rope_in_stride_n,
                    (void*)&q_rope_in_stride_h,
                    (void*)&q_nope_in_stride_n,
                    (void*)&q_nope_in_stride_h,
                    (void*)&q_rope_out_stride_n,
                    (void*)&q_rope_out_stride_h,
                    (void*)&q_nope_out_stride_n,
                    (void*)&q_nope_out_stride_h,
                    (void*)&k_rope_in_stride,
                    (void*)&k_nope_in_stride,
                    (void*)&k_rope_out_stride,
                    (void*)&k_nope_out_stride,
                    (void*)&quant_scale_q,
                    (void*)&quant_scale_kv};
    auto kernel = MLARopeQuantizeKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType>;
    dim3 nblks(nblks_x, num_heads + 8 + 1 + num_heads * 8);
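    // Grid-y partition implied by the kernel's branches on blockIdx.y (reader note, derived from
    // MLARopeQuantizeKernel above):
    //   [0, num_heads)             -> query RoPE part, 64 dims per head
    //   num_heads                  -> shared k RoPE part, 64 dims
    //   (num_heads, num_heads + 8] -> k no-PE part, 8 chunks of 64 dims (512 total)
    //   remaining num_heads * 8    -> query no-PE part, one 64-dim chunk per head per block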
    dim3 nthrs(bdx, bdy);
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
  });

  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
    DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_sin_cache, IdType* pos_ids,
    uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim,
    uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    bool interleave, cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      // operate on 16 bytes at a time
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      // how many threads are needed per head_dim
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      // how many threads are needed per block
      uint32_t num_threads = std::max(128U, bdx);
      // how many tokens we can process in a block
      uint32_t bdy = num_threads / bdx;
      // how many blocks are needed to process all tokens
      uint32_t nblks_x = (nnz + bdy - 1) / bdy;
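      // Worked example (illustrative): for DType = half and HEAD_DIM = 128,
      // vec_size = max(16 / 2, 128 / 32) = 8, bdx = 128 / 8 = 16 threads per head,
      // num_threads = 128, bdy = 128 / 16 = 8 tokens per block, and
      // nblks_x = ceil(nnz / 8) blocks along x.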
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&cos_sin_cache,
                      (void*)&pos_ids,
                      (void*)&nnz,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&rotary_dim,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h};
      auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx,
                                                                DType, IdType>;

      int num_blocks_per_sm_0 = 0;
      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;

      if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
        dim3 nblks(nblks_x);
        dim3 nthrs(bdx, bdy);
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
      } else {
        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
        dim3 nthrs(bdx, bdy);
        auto kernel_1 =
            BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel<INTERLEAVE, HEAD_DIM, vec_size,
                                                                     bdx, DType, IdType>;
        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
      }
    });
  });

  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIds(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    bool interleave, float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
  float rope_rcp_scale = 1.0f / rope_scale;
  float rope_rcp_theta = 1.0f / rope_theta;
  float smooth_a = 0.f;
  float smooth_b = 0.f;
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      uint32_t num_threads = std::max(128U, bdx);
      uint32_t bdy = num_threads / bdx;
      uint32_t nblks_x = (nnz + bdy - 1) / bdy;

      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&pos_ids,
                      (void*)&nnz,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&rotary_dim,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h,
                      (void*)&smooth_a,
                      (void*)&smooth_b,
                      (void*)&rope_rcp_scale,
                      (void*)&rope_rcp_theta};
      auto kernel_0 =
          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;

      int num_blocks_per_sm_0 = 0;
      FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
          &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
      uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;
      if (nblks_x >= num_ctas_0) {
        dim3 nblks(nblks_x);
        dim3 nthrs(bdx, bdy);

        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
      } else {
        dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
        dim3 nthrs(bdx, bdy);
        auto kernel_1 = BatchQKApplyRotaryPosIdsHeadParallelismKernel<INTERLEAVE, HEAD_DIM,
                                                                      vec_size, bdx, DType, IdType>;

        FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
      }
    });
  });

  return cudaSuccess;
}
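
// Usage sketch (illustrative only; the buffer names and shapes below are assumptions, not part
// of this header). For contiguous q of shape [nnz, num_qo_heads, head_dim] and k of shape
// [nnz, num_kv_heads, head_dim], the strides follow directly from the shapes:
//
//   // half *d_q, *d_k, *d_q_rope, *d_k_rope; int32_t* d_pos_ids;  // hypothetical device buffers
//   cudaError_t status = BatchQKApplyRotaryPosIds<half, int32_t>(
//       d_q, d_k, d_q_rope, d_k_rope, d_pos_ids, nnz, num_qo_heads, num_kv_heads,
//       /*rotary_dim=*/head_dim, head_dim,
//       /*q_stride_n=*/num_qo_heads * head_dim, /*q_stride_h=*/head_dim,
//       /*k_stride_n=*/num_kv_heads * head_dim, /*k_stride_h=*/head_dim,
//       /*q_rope_stride_n=*/num_qo_heads * head_dim, /*q_rope_stride_h=*/head_dim,
//       /*k_rope_stride_n=*/num_kv_heads * head_dim, /*k_rope_stride_h=*/head_dim,
//       /*interleave=*/false, /*rope_scale=*/1.f, /*rope_theta=*/1e4f, stream);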

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope,
                               IdType* __restrict__ indptr, IdType* __restrict__ offsets,
                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
                               uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n,
                               size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
                               size_t q_rope_stride_n, size_t q_rope_stride_h,
                               size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave,
                               float rope_scale, float rope_theta, cudaStream_t stream = nullptr) {
  float rope_rcp_scale = 1.0f / rope_scale;
  float rope_rcp_theta = 1.0f / rope_theta;
  float smooth_a = 0.f;
  float smooth_b = 0.f;

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      uint32_t num_threads = std::max(128U, bdx);
      uint32_t bdy = num_threads / bdx;
      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
      dim3 nthrs(bdx, bdy);
      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&indptr,
                      (void*)&offsets,
                      (void*)&batch_size,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&rotary_dim,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h,
                      (void*)&smooth_a,
                      (void*)&smooth_b,
                      (void*)&rope_rcp_scale,
                      (void*)&rope_rcp_theta};
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
    });
  });

  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
                                      IdType* __restrict__ indptr, IdType* __restrict__ offsets,
                                      uint32_t batch_size, uint32_t num_qo_heads,
                                      uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim,
                                      size_t q_stride_n, size_t q_stride_h, size_t k_stride_n,
                                      size_t k_stride_h, bool interleave, float rope_scale,
                                      float rope_theta, cudaStream_t stream = nullptr) {
  return BatchQKApplyRotary<DType, IdType>(
      q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim,
      q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_stride_n, q_stride_h, k_stride_n,
      k_stride_h, interleave, rope_scale, rope_theta, stream);
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31Rotary(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr,
    IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
    uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n,
    size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n,
    size_t k_rope_stride_h, bool interleave, float rope_scale, float rope_theta,
    float low_freq_factor, float high_freq_factor, float old_context_length,
    cudaStream_t stream = nullptr) {
  float rope_rcp_scale = 1.0f / rope_scale;
  float rope_rcp_theta = 1.0f / rope_theta;
  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);
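  // Reader note (derivation, not part of the original file): inside the kernel,
  // smooth = clamp(freq * smooth_a + smooth_b, 0, 1). Since a frequency f has wavelength
  // 2 * pi / f, this equals the Llama 3.1 ramp
  //   (old_context_length / wavelength - low_freq_factor) / (high_freq_factor - low_freq_factor),
  // so long-wavelength frequencies are scaled by 1 / rope_scale, short-wavelength frequencies are
  // left unchanged, and those in between are blended linearly.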

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      uint32_t num_threads = std::max(128U, bdx);
      uint32_t bdy = num_threads / bdx;
      dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
      dim3 nthrs(bdx, bdy);
      auto kernel = BatchQKApplyRotaryKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&indptr,
                      (void*)&offsets,
                      (void*)&batch_size,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&rotary_dim,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h,
                      (void*)&smooth_a,
                      (void*)&smooth_b,
                      (void*)&rope_rcp_scale,
                      (void*)&rope_rcp_theta};
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
    });
  });

  return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31RotaryPosIds(
    DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
    bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
    float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) {
  float rope_rcp_scale = 1.0f / rope_scale;
  float rope_rcp_theta = 1.0f / rope_theta;
  float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
  float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);

  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
      constexpr uint32_t bdx = HEAD_DIM / vec_size;
      uint32_t num_threads = std::max(128U, bdx);
      uint32_t bdy = num_threads / bdx;
      dim3 nblks((nnz + bdy - 1) / bdy);
      dim3 nthrs(bdx, bdy);
      auto kernel =
          BatchQKApplyRotaryPosIdsKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
      void* args[] = {(void*)&q,
                      (void*)&k,
                      (void*)&q_rope,
                      (void*)&k_rope,
                      (void*)&pos_ids,
                      (void*)&nnz,
                      (void*)&num_qo_heads,
                      (void*)&num_kv_heads,
                      (void*)&rotary_dim,
                      (void*)&q_stride_n,
                      (void*)&q_stride_h,
                      (void*)&k_stride_n,
                      (void*)&k_stride_h,
                      (void*)&q_rope_stride_n,
                      (void*)&q_rope_stride_h,
                      (void*)&k_rope_stride_n,
                      (void*)&k_rope_stride_h,
                      (void*)&smooth_a,
                      (void*)&smooth_b,
                      (void*)&rope_rcp_scale,
                      (void*)&rope_rcp_theta};
      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
    });
  });

  return cudaSuccess;
}

}  // namespace flashinfer

#endif  // FLASHINFER_POS_ENC_CUH_