sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/layout.cuh

129 lines
4.9 KiB
Plaintext

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_LAYOUT_CUH_
#define FLASHINFER_LAYOUT_CUH_
#include <cstdint>
#include <string>
#include <tuple>
namespace flashinfer {
/*!
 * \brief Dimension ordering used for the Q/K/V tensors
 */
enum class QKVLayout {
  kNHD = 0,  // dimensions ordered as [seq_len, num_heads, head_dim]
  kHND = 1,  // dimensions ordered as [num_heads, seq_len, head_dim]
};
/*!
 * \brief Combine a (position, head, feature) index triple with two strides into a linear offset
 * \param elem_idx Position index; multiplied by stride_n
 * \param head_idx Head index; multiplied by stride_h
 * \param feat_idx Feature index inside a head; added as-is (implicit stride of 1)
 * \param stride_n Number of elements between consecutive positions
 * \param stride_h Number of elements between consecutive heads
 * \return elem_idx * stride_n + head_idx * stride_h + feat_idx, evaluated in size_t
 */
__host__ __device__ __forceinline__ size_t get_elem_offset_impl(size_t elem_idx, size_t head_idx,
                                                                size_t feat_idx, size_t stride_n,
                                                                size_t stride_h) {
  const size_t position_offset = elem_idx * stride_n;
  const size_t head_offset = head_idx * stride_h;
  return position_offset + head_offset + feat_idx;
}
/*!
 * \brief Derive the q and kv strides that correspond to a given KV layout (host only)
 * \param kv_layout Layout of the K/V tensors; selects which formula the kv strides use
 * \param kv_len Length of the K/V sequence; only used for the head stride when the layout
 *   is not kNHD
 * \param num_qo_heads Number of query/output heads
 * \param num_kv_heads Number of key/value heads
 * \param head_dim Number of features per head
 * \return std::tuple of four uint32_t: (q_stride_n, q_stride_h, kv_stride_n, kv_stride_h)
 * \note All products are computed in uint32_t.
 */
__host__ __forceinline__ auto get_qkv_strides(QKVLayout kv_layout, uint32_t kv_len,
                                              uint32_t num_qo_heads, uint32_t num_kv_heads,
                                              uint32_t head_dim) {
  // The q strides do not depend on kv_layout.
  const uint32_t q_stride_n = num_qo_heads * head_dim;
  const uint32_t q_stride_h = head_dim;

  uint32_t kv_stride_n;
  uint32_t kv_stride_h;
  if (kv_layout == QKVLayout::kNHD) {
    // Heads are the inner dimension: one position spans every kv head.
    kv_stride_n = num_kv_heads * head_dim;
    kv_stride_h = head_dim;
  } else {
    // Positions are the inner dimension: one head spans the whole kv sequence.
    kv_stride_n = head_dim;
    kv_stride_h = kv_len * head_dim;
  }
  return std::make_tuple(q_stride_n, q_stride_h, kv_stride_n, kv_stride_h);
}
/*!
 * \brief Lengths, head counts and strides of the q/o and k/v tensors, plus helpers that turn
 *   (position, head, feature) indices into linear element offsets
 */
struct tensor_info_t {
  uint32_t qo_len;        // length of the query/output sequence
  uint32_t kv_len;        // length of the key/value sequence
  uint32_t num_qo_heads;  // number of query/output heads
  uint32_t num_kv_heads;  // number of key/value heads
  uint32_t q_stride_n;    // multiplier applied to qo_idx in get_q_elem_offset
  uint32_t q_stride_h;    // multiplier applied to qo_head_idx in get_q_elem_offset
  uint32_t kv_stride_n;   // multiplier applied to kv_idx in get_kv_elem_offset
  uint32_t kv_stride_h;   // multiplier applied to kv_head_idx in get_kv_elem_offset
  uint32_t head_dim;      // number of features per head

  /*!
   * \brief Build from explicitly provided strides; every argument is stored unchanged
   */
  __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
                                                    uint32_t num_qo_heads, uint32_t num_kv_heads,
                                                    uint32_t q_stride_n, uint32_t q_stride_h,
                                                    uint32_t kv_stride_n, uint32_t kv_stride_h,
                                                    uint32_t head_dim)
      : qo_len(qo_len),
        kv_len(kv_len),
        num_qo_heads(num_qo_heads),
        num_kv_heads(num_kv_heads),
        q_stride_n(q_stride_n),
        q_stride_h(q_stride_h),
        kv_stride_n(kv_stride_n),
        kv_stride_h(kv_stride_h),
        head_dim(head_dim) {}

  /*!
   * \brief Build with strides derived from the KV layout
   *
   * q strides are always (num_qo_heads * head_dim, head_dim). kv strides are
   * (num_kv_heads * head_dim, head_dim) for kNHD and (head_dim, kv_len * head_dim) otherwise.
   * All products are computed in uint32_t.
   */
  __host__ __device__ __forceinline__ tensor_info_t(uint32_t qo_len, uint32_t kv_len,
                                                    uint32_t num_qo_heads, uint32_t num_kv_heads,
                                                    QKVLayout kv_layout, uint32_t head_dim)
      : tensor_info_t(qo_len, kv_len, num_qo_heads, num_kv_heads,
                      /*q_stride_n=*/num_qo_heads * head_dim,
                      /*q_stride_h=*/head_dim,
                      /*kv_stride_n=*/
                      (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim,
                      /*kv_stride_h=*/
                      (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim,
                      head_dim) {}

  /*!
   * \brief Linear offset of a query element, using the stored q strides
   */
  __host__ __device__ __forceinline__ size_t get_q_elem_offset(uint32_t qo_idx,
                                                               uint32_t qo_head_idx,
                                                               uint32_t feat_idx) const {
    return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h);
  }

  /*!
   * \brief Linear offset of an output element
   *
   * Ignores the stored q strides: the strides are always recomputed as
   * (num_qo_heads * head_dim, head_dim).
   */
  __host__ __device__ __forceinline__ size_t get_o_elem_offset(uint32_t qo_idx,
                                                               uint32_t qo_head_idx,
                                                               uint32_t feat_idx) const {
    const uint32_t o_stride_n = num_qo_heads * head_dim;
    const uint32_t o_stride_h = head_dim;
    return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, o_stride_n, o_stride_h);
  }

  /*!
   * \brief Linear offset of a key/value element, using the stored kv strides
   */
  __host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx,
                                                                uint32_t kv_head_idx,
                                                                uint32_t feat_idx) const {
    return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, kv_stride_h);
  }

  /*!
   * \brief Number of qo heads per kv head (integer division; num_kv_heads must be non-zero)
   */
  __host__ __device__ __forceinline__ uint32_t get_group_size() const {
    return num_qo_heads / num_kv_heads;
  }
};
/*!
 * \brief Get the textual name of a QKVLayout value
 * \param layout The QKVLayout to convert
 * \return "NHD" for kNHD, "HND" for kHND, and "Unknown" for any other value
 */
inline std::string QKVLayoutToString(const QKVLayout& layout) {
  if (layout == QKVLayout::kNHD) {
    return "NHD";
  }
  if (layout == QKVLayout::kHND) {
    return "HND";
  }
  // Reached only for a value that matches no declared enumerator.
  return "Unknown";
}
} // namespace flashinfer
#endif // FLASHINFER_LAYOUT_CUH_