1566 lines
50 KiB
C++
1566 lines
50 KiB
C++
#include "common.h"
|
||
#include "gemm.h"
|
||
#include "vec.h"
|
||
|
||
namespace {
|
||
|
||
// [NOTE] TODO list for this kernel:
|
||
// 1. tune the value for BLOCK_N
|
||
// 2. planning for {batches, num_heads, num_kv_splits}
|
||
// and use actual num_kv_splits for small seq length
|
||
// 3. try fast impl of `.tanh()`
|
||
// 4. provide amx kernel for index_gemm_kernel_nn when M = 16
|
||
//
|
||
|
||
#if defined(CPU_CAPABILITY_AVX512)
|
||
// Gathers up to 16 rows of 32 16-bit elements (one 512-bit load each) from
// `src`, selected through `ind`, and writes two packed layouts from the same
// registers:
//
// key: from [N, 32] to [32/2, N, 2]   -> dst0 (always written)
// val: from [N, 32] to [N/2, 32, 2]   -> dst1 (only when convert_v is true)
//
// Parameters:
//   dst0      key output; row k starts at dst0 + k * ld_dst0 * 2
//   dst1      value output; row-pair j starts at dst1 + j * ld_dst1 * 2
//   src       gather base; row r is read from src + ind[r] * ld_src
//   ind       N row indices into src
//   N         number of valid rows; vinputs has 16 slots, so N must be <= 16
//   ld_src    row stride of src, in elements
//   ld_dst0   key leading dimension (store offsets are scaled by 2)
//   ld_dst1   value leading dimension (store offsets are scaled by 2)
//   convert_v also emit the value layout for this 32-wide column block
//
// NOTE(review): the value path always stores all 8 row-pairs (16 rows),
// including the zero padding rows when N < 16, so dst1 must have room for
// them — confirm against the caller's buffer sizing.
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
    scalar_t* __restrict__ dst0,
    scalar_t* __restrict__ dst1,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int ld_src,
    int ld_dst0,
    int ld_dst1,
    bool convert_v) {
  // one 512-bit register per gathered row
  __m512i vinputs[16];

  int n = 0;
  for (; n < N; ++n) {
    vinputs[n] = _mm512_loadu_si512(src + ind[n] * ld_src);
  }
  // padding with zero to avoid uninitialized vectors
  for (; n < 16; ++n) {
    vinputs[n] = _mm512_set1_epi32(0);
  }

  // pack value, skip 64 elems for deepseek
  // handle 2 vectors at a time from [2, 32] to [32, 2]
  if (convert_v) {
    for (int n = 0; n < 16; n += 2) {
      __m512i d0, d1;
      std::tie(d0, d1) = transpose_2x32_16bit(vinputs[n], vinputs[n + 1]);
      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2, d0);
      _mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2 + 32, d1);
    }
  }

  // pack key
  // the rows are treated as 16 x 32-bit lanes (pairs of 16-bit elements);
  // vinputs is transposed in place and then stored row by row
  transpose_16x16_32bit(vinputs);

  // only the first N 32-bit lanes of each transposed row hold real data
  const __mmask16 vmask = (1 << N) - 1;
  for (int k = 0; k < 16; ++k) {
    _mm512_mask_storeu_epi32(dst0 + k * ld_dst0 * 2, vmask, vinputs[k]);
  }
}
|
||
#endif
|
||
|
||
// [NOTE]: MLA vnni format conversion
//
// here we apply same strategy as `FlashMLA`:
// each kv_cache is loaded once and packed twice (L2 cache hit)
//
// * for key: from [N, K/2, 2] to [K/2, N, 2]
// * for value: from [N/2, 2, Kv] to [N/2, Kv, 2]
//
// Row r of the gathered input lives at src + ind[r] * ld_src.
// dst0 receives the key layout (row stride ld_dst0 * 2 elements), and
// dst1 receives the value layout (row-pair stride ld_dst1 * 2 elements).
// With an odd N, the scalar path pairs the last row with zeros. The AVX512
// path works on 16-row groups through pack_vnni_Nx32 and assumes K and Kv
// are multiples of 32.
template <typename scalar_t, typename index_t>
void pack_vnni(
    scalar_t* __restrict__ dst0,
    scalar_t* __restrict__ dst1,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int K,
    int Kv,
    int ld_src,
    int ld_dst0,
    int ld_dst1) {
#if defined(CPU_CAPABILITY_AVX512)
  const int row_blocks = div_up(N, 16);
  const int col_blocks_k = K / 32;   // no remainder
  const int col_blocks_v = Kv / 32;  // no remainder

  for (int rb = 0; rb < row_blocks; ++rb) {
    const int row0 = rb * 16;
    const int rows = std::min(N - row0, 16);
    for (int cb = 0; cb < col_blocks_k; ++cb) {
      // handle 16x512bits each block
      const int col0 = cb * 32;
      pack_vnni_Nx32<scalar_t, index_t>(
          /* dst0 */ dst0 + (col0 >> 1) * ld_dst0 * 2 + row0 * 2,
          /* dst1 */ dst1 + (row0 >> 1) * ld_dst1 * 2 + col0 * 2,
          /* src */ src + col0,
          /* ind */ ind + row0,
          /* N */ rows,
          /* ld_src */ ld_src,
          /* ld_dst0 */ ld_dst0,
          /* ld_dst1 */ ld_dst1,
          /* cvt_v */ cb < col_blocks_v);
    }
  }
#else
  // key: element (n, 2k + d) of the gathered rows lands in dst0[k][n][d]
  for (int n = 0; n < N; ++n) {
    const scalar_t* row = src + ind[n] * ld_src;
    for (int k = 0; k < K / 2; ++k) {
      scalar_t* out = dst0 + k * ld_dst0 * 2 + n * 2;
      out[0] = row[k * 2 + 0];
      out[1] = row[k * 2 + 1];
    }
  }

  // value: rows 2j and 2j+1 are interleaved into dst1[j][k][0..1]
  const int num_pairs = N >> 1;
  for (int j = 0; j < num_pairs; ++j) {
    const scalar_t* first = src + ind[2 * j + 0] * ld_src;
    const scalar_t* second = src + ind[2 * j + 1] * ld_src;
    scalar_t* out = dst1 + j * ld_dst1 * 2;
    for (int k = 0; k < Kv; ++k) {
      out[k * 2 + 0] = first[k];
      out[k * 2 + 1] = second[k];
    }
  }

  // odd N: the trailing row is paired with zeros
  if (N % 2 != 0) {
    const scalar_t* last = src + ind[N - 1] * ld_src;
    scalar_t* out = dst1 + num_pairs * ld_dst1 * 2;
    for (int k = 0; k < Kv; ++k) {
      out[k * 2 + 0] = last[k];
      out[k * 2 + 1] = 0;
    }
  }
#endif
}
|
||
|
||
template <typename scalar_t>
|
||
inline void fill_stub(scalar_t* __restrict__ out, float val, int64_t size) {
|
||
using Vec = at::vec::Vectorized<scalar_t>;
|
||
constexpr int kVecSize = Vec::size();
|
||
const Vec data_vec = Vec(static_cast<scalar_t>(val));
|
||
int64_t d = 0;
|
||
#pragma GCC unroll 4
|
||
for (; d <= size - kVecSize; d += kVecSize) {
|
||
data_vec.store(out + d);
|
||
}
|
||
if (size - d > 0) {
|
||
data_vec.store(out + d, size - d);
|
||
}
|
||
}
|
||
|
||
template <typename scalar_t>
|
||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
|
||
using bVec = at::vec::Vectorized<scalar_t>;
|
||
using fVec = at::vec::Vectorized<float>;
|
||
constexpr int kVecSize = bVec::size();
|
||
const fVec s_fvec = fVec(s);
|
||
int64_t d = 0;
|
||
#pragma GCC unroll 4
|
||
for (; d <= size - kVecSize; d += kVecSize) {
|
||
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
|
||
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
|
||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||
out_bvec.store(out + d);
|
||
}
|
||
for (; d < size; ++d) {
|
||
out[d] = static_cast<scalar_t>(acc[d] * s);
|
||
}
|
||
}
|
||
|
||
template <typename scalar_t>
|
||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
|
||
using bVec = at::vec::Vectorized<scalar_t>;
|
||
constexpr int kVecSize = bVec::size();
|
||
int64_t d = 0;
|
||
#pragma GCC unroll 4
|
||
for (; d <= size - kVecSize; d += kVecSize) {
|
||
bVec out_bvec = bVec::loadu(src + d);
|
||
out_bvec.store(out + d);
|
||
}
|
||
for (; d < size; ++d) {
|
||
out[d] = src[d];
|
||
}
|
||
}
|
||
|
||
template <typename scalar_t, int BLOCK_N>
|
||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
|
||
static_assert(BLOCK_N % 32 == 0);
|
||
using bVec = at::vec::Vectorized<scalar_t>;
|
||
using fVec = at::vec::Vectorized<float>;
|
||
|
||
constexpr int COLS = BLOCK_N / 16;
|
||
auto store = [&](auto i) {
|
||
constexpr int col = i % COLS;
|
||
// for COLS = 2, 4 use 512bit store
|
||
if constexpr (col % 2 == 0) {
|
||
fVec a_fvec0 = fVec::loadu(input + col * 16);
|
||
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
|
||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||
out_bvec.store(out + col * 16);
|
||
}
|
||
};
|
||
Unroll<COLS>{}(store);
|
||
}
|
||
|
||
// GEMM handles query @ key (indexed) x scale
|
||
// A : [M, K]
|
||
// B : [N, K] indexed
|
||
// C : [M, N]
|
||
//
|
||
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
|
||
struct tinygemm_kernel_nt {
|
||
static inline void apply(
|
||
const scalar_t* __restrict__ A,
|
||
const scalar_t* __restrict__ B,
|
||
float* __restrict__ C,
|
||
const index_t* __restrict__ indices,
|
||
float scale,
|
||
int64_t lda,
|
||
int64_t ldb,
|
||
int64_t ldc,
|
||
int64_t K,
|
||
int64_t max_tokens) {
|
||
for (int64_t m = 0; m < BLOCK_M; ++m) {
|
||
for (int64_t n = 0; n < BLOCK_N; ++n) {
|
||
float sum = 0.f;
|
||
int64_t b_idx = indices[n];
|
||
TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
|
||
for (int64_t k = 0; k < K; ++k) {
|
||
sum += scale * static_cast<float>(A[m * lda + k]) * static_cast<float>(B[b_idx * ldb + k]);
|
||
}
|
||
C[m * ldc + n] = sum;
|
||
}
|
||
}
|
||
}
|
||
};
|
||
|
||
#if defined(CPU_CAPABILITY_AVX512)
|
||
// AVX512-BF16 specialization of tinygemm_kernel_nt for at::BFloat16.
// For a BLOCK_M x BLOCK_N tile it computes
//   C[row][col] = scale * dot(A[row, :], B[indices[col], :])
// with one 16-float accumulator per output element in vc[]. Each accumulator
// is scaled and horizontally reduced at the end.
// K is consumed 32 BFloat16 elements per step via _mm512_dpbf16_ps. A trailing
// K % 32 chunk is handled with masked, zero-filling loads.
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      float scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N;

    __m512bh va;
    __m512bh vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale = _mm512_set1_ps(scale);

    // zero all accumulators
    auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
    Unroll<ROWS * COLS>{}(loadc);

    // for main loop
    // `i` enumerates the tile as (row, col). A is loaded when col == 0 and each
    // gathered B row is loaded when row == 0; later rows reuse va/vb, assuming
    // Unroll visits i in increasing order (TODO confirm in the Unroll helper).
    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = (__m512bh)(_mm512_loadu_si512(A + row * lda + k));
      }
      if constexpr (row == 0) {
        // prefetch the next gathered B row at the same k offset
        if constexpr (col + 1 < COLS) {
          int64_t b_idx_prefetch = indices[col + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + k, _MM_HINT_T0);
        }
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        vb[col] = (__m512bh)(_mm512_loadu_si512(B + b_idx * ldb + k));
      }
      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
    };

    // for remainder
    // same as `compute`, but only the lanes selected by `mask` are loaded
    // (the rest are zeroed), so the tail adds nothing past K
    auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = (__m512bh)(_mm512_maskz_loadu_epi16(mask, A + row * lda + k));
      }
      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        vb[col] = (__m512bh)(_mm512_maskz_loadu_epi16(mask, B + b_idx * ldb + k));
      }
      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
    };

    int64_t k = 0;
    for (; k <= K - 32; k += 32) {
      Unroll<ROWS * COLS>{}(compute, k);
    }
    int64_t count = K - k;
    if (count > 0) {
      // count is in [1, 31] here, so the mask selects the first `count` elements
      __mmask32 mask = (1ULL << count) - 1;
      Unroll<ROWS * COLS>{}(compute2, k, mask);
    }

    // scale each accumulator and reduce its lanes into one output value
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
|
||
#endif
|
||
|
||
// Runs tinygemm_kernel_nt on the MB_SIZE x NB_SIZE tile whose top-left
// corner is (mb_start, nb_start). It expands in a scope that must provide
// scalar_t, index_t, A, B, C, indices, scale, lda, ldb, ldc, K, max_tokens,
// mb_start and nb_start. B is not offset, because the columns are selected
// through `indices + nb_start`.
#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
|
||
|
||
// this is used when N isn't multiple of 16,
|
||
// N corresponds to `head_size_v` which should be 16x
|
||
template <typename scalar_t, typename index_t>
|
||
inline void tinygemm_kernel_nn_scalar(
|
||
const float* __restrict__ A,
|
||
const scalar_t* __restrict__ B,
|
||
float* __restrict__ C,
|
||
const index_t* __restrict__ indices,
|
||
const float* __restrict__ scale,
|
||
int64_t M,
|
||
int64_t N,
|
||
int64_t K,
|
||
int64_t lda,
|
||
int64_t ldb,
|
||
int64_t ldc,
|
||
int64_t max_tokens) {
|
||
for (int64_t m = 0; m < M; ++m) {
|
||
for (int64_t n = 0; n < N; ++n) {
|
||
C[m * ldc + n] *= scale[m];
|
||
for (int64_t k = 0; k < K; ++k) {
|
||
int64_t b_idx = indices[k];
|
||
TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
|
||
C[m * ldc + n] += A[m * lda + k] * static_cast<float>(B[b_idx * ldb + n]);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// GEMM handles v' * scale + attn @ value (indexed)
|
||
// A : [M, K]
|
||
// B : [K, N] indexed
|
||
// C :[M, N]
|
||
//
|
||
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
|
||
struct tinygemm_kernel_nn {
|
||
static inline void apply(
|
||
const float* __restrict__ A,
|
||
const scalar_t* __restrict__ B,
|
||
float* __restrict__ C,
|
||
const index_t* __restrict__ indices,
|
||
const float* __restrict__ scale,
|
||
int64_t lda,
|
||
int64_t ldb,
|
||
int64_t ldc,
|
||
int64_t K,
|
||
int64_t max_tokens) {
|
||
tinygemm_kernel_nn_scalar(A, B, C, indices, scale, BLOCK_M, BLOCK_N, K, lda, ldb, ldc, max_tokens);
|
||
}
|
||
};
|
||
|
||
#if defined(CPU_CAPABILITY_AVX512)
|
||
// AVX512 specialization of tinygemm_kernel_nn for at::BFloat16 values.
// For a BLOCK_M x BLOCK_N tile (COLS = BLOCK_N / 16 float vectors per row) it
// computes:
//   C[row][:] = C[row][:] * scale[row] + sum_k A[row][k] * B[indices[k]][:]
// Accumulation is in fp32 registers. B rows are widened from BFloat16 with
// CVT_BF16_TO_FP32.
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const float* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      const float* __restrict__ scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;

    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale;

    // load the C tile and pre-multiply it by the per-row scale; vscale is
    // refreshed at col == 0 of each row
    auto loadc = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      // NOTE(review): the -Warray-bounds suppression presumably silences a
      // false positive on scale[row] in the unrolled code — confirm
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
      if constexpr (col == 0) {
        vscale = _mm512_set1_ps(scale[row]);
      }
#pragma GCC diagnostic pop
      vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
      vc[i] = _mm512_mul_ps(vc[i], vscale);
    };
    Unroll<ROWS * COLS>{}(loadc);

    // one rank-1 update per k: broadcast A[row][k] and multiply it with the
    // gathered row indices[k] of B (loaded when row == 0, reused by later rows)
    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = _mm512_set1_ps(A[row * lda + k]);
      }
      if constexpr (row == 0) {
        // prefetch the gathered B row that the next k will use
        if (k + 1 < K) {
          int64_t b_idx_prefetch = indices[k + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
        }
        int64_t b_idx = indices[k];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");

        // for COLS = 2, 4, 6, 8 use 512 bit load
        // for COLS = 1, 3, 5, 7 use 256 bit load
        if constexpr (COLS % 2 == 0) {
          // one 512-bit load covers columns col and col + 1
          if constexpr (col % 2 == 0) {
            __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
            vb[col + 0] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
            vb[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
          }
        } else {
          __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
          vb[col] = CVT_BF16_TO_FP32(b16);
        }
      }
      vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
    };

    for (int64_t k = 0; k < K; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    // write the updated tile back to C
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
|
||
#endif
|
||
|
||
// Runs tinygemm_kernel_nn on the MB_SIZE x NB_SIZE tile whose top-left
// corner is (mb_start, nb_start). It expands in a scope that must provide
// scalar_t, index_t, A, B, C, indices, scale, lda, ldb, ldc, K, max_tokens,
// mb_start and nb_start. Unlike the NT variant, `indices` is passed whole
// (the kernel indexes it by k), and B is offset by nb_start columns instead.
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, \
      B + nb_start, \
      C + mb_start * ldc + nb_start, \
      indices, \
      scale + mb_start, \
      lda, \
      ldb, \
      ldc, \
      K, \
      max_tokens);
|
||
|
||
template <typename scalar_t, typename index_t>
|
||
void index_gemm_kernel_nt(
|
||
const scalar_t* __restrict__ A,
|
||
const scalar_t* __restrict__ B,
|
||
float* __restrict__ C,
|
||
const index_t* __restrict__ indices,
|
||
float scale,
|
||
int64_t M,
|
||
int64_t N,
|
||
int64_t K,
|
||
int64_t lda,
|
||
int64_t ldb,
|
||
int64_t ldc,
|
||
int64_t max_tokens) {
|
||
// pattern: 1-8-8
|
||
if (M == 1) {
|
||
constexpr int64_t BLOCK_N = 8;
|
||
const int64_t NB = div_up(N, BLOCK_N);
|
||
int64_t mb_start = 0, lda = 1, ldc = 1;
|
||
|
||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||
int64_t nb_start = nb * BLOCK_N;
|
||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||
|
||
switch (nb_size) {
|
||
case 1:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
|
||
break;
|
||
case 2:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
|
||
break;
|
||
case 3:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
|
||
break;
|
||
case 4:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
|
||
break;
|
||
case 5:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
|
||
break;
|
||
case 6:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
|
||
break;
|
||
case 7:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 7);
|
||
break;
|
||
case 8:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 8);
|
||
break;
|
||
default:
|
||
TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size");
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
// pattern: 1-6-24
|
||
constexpr int64_t BLOCK_M = 4;
|
||
constexpr int64_t BLOCK_N = 6;
|
||
const int64_t MB = div_up(M, BLOCK_M);
|
||
const int64_t NB = div_up(N, BLOCK_N);
|
||
|
||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||
int64_t mb_start = mb * BLOCK_M;
|
||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||
int64_t nb_start = nb * BLOCK_N;
|
||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||
|
||
switch (mb_size << 4 | nb_size) {
|
||
// mb_size = 1
|
||
case 0x11:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
|
||
break;
|
||
case 0x12:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
|
||
break;
|
||
case 0x13:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
|
||
break;
|
||
case 0x14:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
|
||
break;
|
||
case 0x15:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
|
||
break;
|
||
case 0x16:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
|
||
break;
|
||
// mb_size = 2
|
||
case 0x21:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 1);
|
||
break;
|
||
case 0x22:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 2);
|
||
break;
|
||
case 0x23:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 3);
|
||
break;
|
||
case 0x24:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 4);
|
||
break;
|
||
case 0x25:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 5);
|
||
break;
|
||
case 0x26:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(2, 6);
|
||
break;
|
||
// mb_size = 3
|
||
case 0x31:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 1);
|
||
break;
|
||
case 0x32:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 2);
|
||
break;
|
||
case 0x33:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 3);
|
||
break;
|
||
case 0x34:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 4);
|
||
break;
|
||
case 0x35:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 5);
|
||
break;
|
||
case 0x36:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(3, 6);
|
||
break;
|
||
// mb_size = 4
|
||
case 0x41:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 1);
|
||
break;
|
||
case 0x42:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 2);
|
||
break;
|
||
case 0x43:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 3);
|
||
break;
|
||
case 0x44:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 4);
|
||
break;
|
||
case 0x45:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 5);
|
||
break;
|
||
case 0x46:
|
||
LAUNCH_TINYGEMM_KERNEL_NT(4, 6);
|
||
break;
|
||
default:
|
||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
template <typename scalar_t, typename index_t>
|
||
void index_gemm_kernel_nn(
|
||
const float* __restrict__ A,
|
||
const scalar_t* __restrict__ B,
|
||
float* __restrict__ C,
|
||
const index_t* __restrict__ indices,
|
||
float* __restrict__ scale,
|
||
int64_t M,
|
||
int64_t N,
|
||
int64_t K,
|
||
int64_t lda,
|
||
int64_t ldb,
|
||
int64_t ldc,
|
||
int64_t max_tokens) {
|
||
constexpr int kVecSize = 16;
|
||
if ((N & (kVecSize - 1)) != 0) {
|
||
tinygemm_kernel_nn_scalar(A, B, C, indices, scale, M, N, K, lda, ldb, ldc, max_tokens);
|
||
return;
|
||
}
|
||
|
||
// pattern: 1-8-8
|
||
if (M == 1) {
|
||
constexpr int64_t BLOCK_N = 8 * kVecSize;
|
||
const int64_t NB = div_up(N, BLOCK_N);
|
||
int64_t mb_start = 0, lda = 1, ldc = 1;
|
||
|
||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||
int64_t nb_start = nb * BLOCK_N;
|
||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||
|
||
switch (nb_size >> 4) {
|
||
case 1:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||
break;
|
||
case 2:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||
break;
|
||
case 3:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||
break;
|
||
case 4:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||
break;
|
||
case 5:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
|
||
break;
|
||
case 6:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
|
||
break;
|
||
case 7:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 112);
|
||
break;
|
||
case 8:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 128);
|
||
break;
|
||
default:
|
||
TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size");
|
||
}
|
||
}
|
||
return;
|
||
}
|
||
|
||
constexpr int64_t BLOCK_M = 4;
|
||
constexpr int64_t BLOCK_N = 6 * kVecSize;
|
||
const int64_t MB = div_up(M, BLOCK_M);
|
||
const int64_t NB = div_up(N, BLOCK_N);
|
||
|
||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||
int64_t mb_start = mb * BLOCK_M;
|
||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||
int64_t nb_start = nb * BLOCK_N;
|
||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||
|
||
switch (mb_size << 4 | nb_size >> 4) {
|
||
// mb_size = 1
|
||
case 0x11:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||
break;
|
||
case 0x12:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||
break;
|
||
case 0x13:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||
break;
|
||
case 0x14:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||
break;
|
||
case 0x15:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
|
||
break;
|
||
case 0x16:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
|
||
break;
|
||
// mb_size = 2
|
||
case 0x21:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
|
||
break;
|
||
case 0x22:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||
break;
|
||
case 0x23:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
|
||
break;
|
||
case 0x24:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||
break;
|
||
case 0x25:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 80);
|
||
break;
|
||
case 0x26:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(2, 96);
|
||
break;
|
||
// mb_size = 3
|
||
case 0x31:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
|
||
break;
|
||
case 0x32:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||
break;
|
||
case 0x33:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
|
||
break;
|
||
case 0x34:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||
break;
|
||
case 0x35:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 80);
|
||
break;
|
||
case 0x36:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(3, 96);
|
||
break;
|
||
// mb_size = 4
|
||
case 0x41:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
|
||
break;
|
||
case 0x42:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||
break;
|
||
case 0x43:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
|
||
break;
|
||
case 0x44:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||
break;
|
||
case 0x45:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 80);
|
||
break;
|
||
case 0x46:
|
||
LAUNCH_TINYGEMM_KERNEL_NN(4, 96);
|
||
break;
|
||
default:
|
||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
template <typename scalar_t>
|
||
void decode_set_kv_buffer(
|
||
scalar_t* __restrict__ k_buffer,
|
||
scalar_t* __restrict__ v_buffer,
|
||
const scalar_t* __restrict__ key,
|
||
const scalar_t* __restrict__ value,
|
||
const int64_t* __restrict__ loc,
|
||
int64_t batches,
|
||
int64_t num_heads_kv,
|
||
int64_t head_size,
|
||
int64_t head_size_v,
|
||
int64_t k_strideN,
|
||
int64_t k_strideH,
|
||
int64_t v_strideN,
|
||
int64_t v_strideH,
|
||
int64_t nk_strideN,
|
||
int64_t nk_strideH,
|
||
int64_t nv_strideN,
|
||
int64_t nv_strideH,
|
||
bool is_mla) {
|
||
at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
|
||
int64_t bs{0}, head_kv_id{0};
|
||
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
|
||
|
||
for (int64_t i = begin; i < end; i++) {
|
||
int64_t loc_val = loc[bs];
|
||
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
|
||
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
|
||
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
|
||
if (!is_mla) {
|
||
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
|
||
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
|
||
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
|
||
}
|
||
|
||
// move to the next index
|
||
data_index_step(bs, batches, head_kv_id, num_heads_kv);
|
||
}
|
||
});
|
||
}
|
||
|
||
template <typename scalar_t>
|
||
void decode_accumulate_kv_splits(
|
||
scalar_t* __restrict__ output,
|
||
float* __restrict__ attn_logits,
|
||
int64_t batches,
|
||
int64_t num_heads,
|
||
int64_t head_size_v,
|
||
int64_t num_kv_splits,
|
||
int64_t l_stride1,
|
||
int64_t l_stride2) {
|
||
using Vec = at::vec::Vectorized<float>;
|
||
|
||
// parallel on [batches, num_heads]
|
||
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
|
||
// NB: here we use logits[b][h][0] as acc, since
|
||
// for the first kv split (kv_id == 0):
|
||
// m_delta = std::exp(-inf) = 0
|
||
// e_logic = std::exp(0) = 1
|
||
// acc = acc * m_delta + tv * e_logic = tv
|
||
for (int64_t i = begin; i < end; ++i) {
|
||
float* __restrict__ acc = attn_logits + i * l_stride1;
|
||
|
||
float s_prime = 0.f;
|
||
float m_prime = -std::numeric_limits<scalar_t>::infinity();
|
||
|
||
// update acc with from each kv_split
|
||
for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
|
||
float* __restrict__ tv = acc + kv_id * l_stride2;
|
||
const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
|
||
|
||
float m_i = std::max(tlogic, m_prime);
|
||
float m_delta = std::exp(m_prime - m_i);
|
||
float e_logic = std::exp(tlogic - m_i);
|
||
if (kv_id != 0) {
|
||
at::vec::map2<float>(
|
||
[m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
|
||
acc,
|
||
acc,
|
||
tv,
|
||
head_size_v);
|
||
}
|
||
|
||
s_prime = s_prime * m_delta + e_logic;
|
||
m_prime = m_i;
|
||
}
|
||
|
||
copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
|
||
}
|
||
});
|
||
}
|
||
|
||
template <typename scalar_t, typename index_t, int64_t BLOCK_N>
|
||
void decode_attention_kernel_impl(
|
||
scalar_t* __restrict__ output,
|
||
float* __restrict__ attn_logits,
|
||
const scalar_t* __restrict__ query,
|
||
const scalar_t* __restrict__ k_buffer,
|
||
const scalar_t* __restrict__ v_buffer,
|
||
const index_t* __restrict__ req_to_token,
|
||
const int64_t* __restrict__ req_pool_indices,
|
||
const int64_t* __restrict__ seq_lens,
|
||
int64_t batches,
|
||
int64_t num_heads,
|
||
int64_t head_size,
|
||
int64_t head_size_v,
|
||
int64_t num_kv_splits,
|
||
int64_t k_strideN,
|
||
int64_t k_strideH,
|
||
int64_t v_strideN,
|
||
int64_t v_strideH,
|
||
float scaling,
|
||
float logit_cap,
|
||
int64_t max_num_reqs,
|
||
int64_t max_context_len,
|
||
int64_t max_total_num_tokens) {
|
||
using Vec = at::vec::Vectorized<float>;
|
||
|
||
// strides
|
||
const int64_t q_strideM = num_heads * head_size;
|
||
const int64_t q_strideH = head_size;
|
||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||
const int64_t l_stride2 = head_size_v + 1;
|
||
|
||
const bool has_logit_cap = logit_cap > 0;
|
||
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
|
||
|
||
// parallel on [batches, num_heads, num_kv_splits]
|
||
at::parallel_for(0, batches * num_heads * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
|
||
int64_t bs{0}, head_id{0}, kv_id{0};
|
||
data_index_init(begin, bs, batches, head_id, num_heads, kv_id, num_kv_splits);
|
||
|
||
// s_prime and s_delta
|
||
alignas(64) float s_i[BLOCK_N];
|
||
float* __restrict__ s_delta = s_i;
|
||
|
||
for (int64_t i = begin; i < end; ++i) {
|
||
// get query
|
||
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + head_id * q_strideH;
|
||
|
||
// get key/value
|
||
int64_t seq_len_kv = seq_lens[bs];
|
||
int64_t req_pool_id = req_pool_indices[bs];
|
||
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
|
||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||
|
||
const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
|
||
const int64_t kv_start = kv_id * SPLIT_SIZE;
|
||
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
|
||
|
||
float m_prime = -std::numeric_limits<float>::infinity();
|
||
float s_prime = 0.f;
|
||
|
||
// get v_prime, and init to zero
|
||
float* __restrict__ v_prime = attn_logits + i * (head_size_v + 1);
|
||
fill_stub(v_prime, 0.f, head_size_v);
|
||
|
||
// loop over K and V sequence with BLOCK_N
|
||
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
|
||
int64_t n_size = std::min(BLOCK_N, kv_end - n);
|
||
|
||
// calculate s_i <- scale * Q @ K
|
||
index_gemm_kernel_nt<scalar_t, index_t>(
|
||
/* A */ q_ptr,
|
||
/* B */ k_buffer + head_id * k_strideH,
|
||
/* C */ s_i,
|
||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||
/* scl */ scaling,
|
||
/* M */ 1,
|
||
/* N */ n_size,
|
||
/* K */ head_size,
|
||
/* lda */ 1,
|
||
/* ldb */ k_strideN,
|
||
/* ldc */ 1,
|
||
/* mtt */ max_total_num_tokens);
|
||
|
||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||
if (has_logit_cap) {
|
||
at::vec::map<float>(
|
||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||
s_i,
|
||
s_i,
|
||
n_size);
|
||
}
|
||
|
||
// m_i: max value per row
|
||
float m_i = at::vec::reduce_all<float>([](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i, n_size);
|
||
m_i = std::max(m_i, m_prime);
|
||
|
||
// m_delta <- exp(m' - m_i)
|
||
float m_delta = std::exp(m_prime - m_i);
|
||
|
||
// s_delta <- exp(s_i - m_i)
|
||
at::vec::map<float>([m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta, s_i, n_size);
|
||
|
||
// s' <- s' * m_delta + sum(s_delta)
|
||
s_prime *= m_delta;
|
||
s_prime += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta, n_size);
|
||
|
||
m_prime = m_i;
|
||
|
||
// calculate V' <- s_delta @ V + V' * m_delta
|
||
index_gemm_kernel_nn<scalar_t, index_t>(
|
||
/* A */ s_delta,
|
||
/* B */ v_buffer + head_id * v_strideH,
|
||
/* C */ v_prime,
|
||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||
/* scl */ &m_delta,
|
||
/* M */ 1,
|
||
/* N */ head_size_v,
|
||
/* K */ n_size,
|
||
/* lda */ 1,
|
||
/* ldb */ v_strideN,
|
||
/* ldc */ 1,
|
||
/* mtt */ max_total_num_tokens);
|
||
} // loop with KV blocks
|
||
|
||
// only update v' when kv_split_size > 0
|
||
if (kv_end > kv_start) {
|
||
float s = 1 / s_prime;
|
||
at::vec::map<float>([s](Vec out) { return out * Vec(s); }, v_prime, v_prime, head_size_v);
|
||
|
||
v_prime[head_size_v] = m_prime + std::log(s_prime);
|
||
}
|
||
|
||
// move to the next index
|
||
data_index_step(bs, batches, head_id, num_heads, kv_id, num_kv_splits);
|
||
}
|
||
});
|
||
|
||
decode_accumulate_kv_splits(
|
||
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
|
||
} // MHA
|
||
|
||
template <typename scalar_t, typename index_t, int64_t BLOCK_N>
|
||
void decode_attention_mla_kernel_impl(
|
||
scalar_t* __restrict__ output,
|
||
float* __restrict__ attn_logits,
|
||
const scalar_t* __restrict__ query,
|
||
const scalar_t* __restrict__ k_buffer,
|
||
const scalar_t* __restrict__ v_buffer,
|
||
const index_t* __restrict__ req_to_token,
|
||
const int64_t* __restrict__ req_pool_indices,
|
||
const int64_t* __restrict__ seq_lens,
|
||
scalar_t* __restrict__ buffer,
|
||
int64_t batches,
|
||
int64_t num_heads,
|
||
int64_t head_size,
|
||
int64_t head_size_v,
|
||
int64_t num_kv_splits,
|
||
int64_t k_strideN,
|
||
int64_t k_strideH,
|
||
int64_t v_strideN,
|
||
int64_t v_strideH,
|
||
float scaling,
|
||
float logit_cap,
|
||
int64_t max_num_reqs,
|
||
int64_t max_context_len,
|
||
int64_t max_total_num_tokens,
|
||
int64_t buffer_size_per_thread) {
|
||
using Vec = at::vec::Vectorized<float>;
|
||
|
||
// block length for heads
|
||
const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
|
||
|
||
// strides
|
||
const int64_t q_strideM = num_heads * head_size;
|
||
const int64_t q_strideH = head_size;
|
||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||
const int64_t l_stride2 = head_size_v + 1;
|
||
|
||
TORCH_CHECK(logit_cap == 0.f, "decode MLA: expect no logit_cap.");
|
||
|
||
// partition the heads into blocks for parallel
|
||
const int64_t num_blocks = div_up(num_heads, BLOCK_H);
|
||
|
||
// parallel on [batches, num_blocks, num_kv_splits]
|
||
at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
|
||
int64_t bs{0}, block_id{0}, kv_id{0};
|
||
data_index_init(begin, bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
|
||
|
||
int tid = at::get_thread_num();
|
||
scalar_t* __restrict__ Btmp0 = buffer + tid * buffer_size_per_thread;
|
||
scalar_t* __restrict__ Btmp1 = Btmp0 + BLOCK_N * head_size;
|
||
|
||
// init Btmp1 just once for each thread to prevent NaN
|
||
// Btmp0 is not needed as it computes full K every single time
|
||
fill_stub(Btmp1, 0.f, BLOCK_N * head_size_v);
|
||
|
||
alignas(64) float s_i[BLOCK_H * BLOCK_N];
|
||
float* __restrict__ s_delta = s_i;
|
||
alignas(64) scalar_t s_delta2[BLOCK_H * BLOCK_N];
|
||
|
||
alignas(64) float s_prime[BLOCK_H];
|
||
alignas(64) float m_prime[BLOCK_H];
|
||
alignas(64) float m_delta[BLOCK_H];
|
||
|
||
for (int64_t i = begin; i < end; ++i) {
|
||
const int64_t h_start = block_id * BLOCK_H;
|
||
const int64_t h_end = std::min(block_id * BLOCK_H + BLOCK_H, num_heads);
|
||
const int64_t h_size = h_end - h_start;
|
||
|
||
// get query
|
||
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
|
||
|
||
int64_t seq_len_kv = seq_lens[bs];
|
||
int64_t req_pool_id = req_pool_indices[bs];
|
||
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
|
||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||
|
||
const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
|
||
const int64_t kv_start = kv_id * SPLIT_SIZE;
|
||
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
|
||
|
||
fill_stub(s_prime, 0.f, BLOCK_H);
|
||
fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);
|
||
|
||
// get v_prime, and init to zero
|
||
float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
|
||
}
|
||
|
||
// loop over K and V sequence with BLOCK_N
|
||
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
|
||
int64_t n_size = std::min(BLOCK_N, kv_end - n);
|
||
const int64_t padded_n_size = div_up(int(n_size), TILE_K) * TILE_K;
|
||
|
||
// get key and pack
|
||
pack_vnni<scalar_t, index_t>(
|
||
/* dst0 */ Btmp0,
|
||
/* dst1 */ Btmp1,
|
||
/* src */ k_buffer + /* head_kv_id */ 0 * k_strideH,
|
||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||
/* N */ n_size,
|
||
/* K */ head_size,
|
||
/* Kv */ head_size_v,
|
||
/* ld_src */ k_strideN,
|
||
/* ld_dst0 */ BLOCK_N,
|
||
/* ld_dst1 */ head_size_v);
|
||
|
||
// calculate s_i <- Q @ K
|
||
at::native::cpublas::brgemm(
|
||
/* M */ h_size,
|
||
/* N */ n_size,
|
||
/* K */ head_size,
|
||
/* lda */ q_strideH,
|
||
/* ldb */ BLOCK_N,
|
||
/* ldc */ BLOCK_N,
|
||
/* add_C */ false,
|
||
/* A */ q_ptr,
|
||
/* B */ Btmp0,
|
||
/* C */ s_i);
|
||
|
||
const Vec scale_vec = Vec(scaling);
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
// s_i <- s_i * scale
|
||
at::vec::map<float>(
|
||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
|
||
|
||
// m_i: max value per row
|
||
float m_i = at::vec::reduce_all<float>(
|
||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
|
||
m_i = std::max(m_i, m_prime[h]);
|
||
|
||
// m_delta <- exp(m' - m_i)
|
||
m_delta[h] = std::exp(m_prime[h] - m_i);
|
||
|
||
// s_delta <- exp(s_i - m_i)
|
||
at::vec::map<float>(
|
||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
|
||
|
||
// s' <- s' * m_delta + sum(s_delta)
|
||
s_prime[h] *= m_delta[h];
|
||
s_prime[h] += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);
|
||
|
||
m_prime[h] = m_i;
|
||
|
||
// v' <- v' * m_delta
|
||
float scale_m = m_delta[h];
|
||
at::vec::map<float>(
|
||
[scale_m](Vec x) { return x * Vec(scale_m); },
|
||
v_prime + h * l_stride1,
|
||
v_prime + h * l_stride1,
|
||
head_size_v);
|
||
|
||
// pad s_delta with 0 first and then convert to scalar_t
|
||
fill_stub(s_delta + h * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + h * BLOCK_N, s_delta + h * BLOCK_N);
|
||
}
|
||
|
||
// calculate V' <- s_delta @ V + V'
|
||
at::native::cpublas::brgemm(
|
||
/* M */ h_size,
|
||
/* N */ head_size_v,
|
||
/* K */ padded_n_size, // n_size
|
||
/* lda */ BLOCK_N,
|
||
/* ldb */ head_size_v,
|
||
/* ldc */ l_stride1,
|
||
/* add_C */ true,
|
||
/* A */ s_delta2,
|
||
/* B */ Btmp1,
|
||
/* C */ v_prime);
|
||
} // loop with KV blocks
|
||
|
||
// only update v' when kv_split_size > 0
|
||
if (kv_end > kv_start) {
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
float s = 1 / s_prime[h];
|
||
at::vec::map<float>(
|
||
[s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
|
||
(v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
|
||
}
|
||
}
|
||
|
||
// move to the next index
|
||
data_index_step(bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
|
||
}
|
||
at::native::cpublas::brgemm_release();
|
||
});
|
||
|
||
decode_accumulate_kv_splits(
|
||
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
|
||
} // MLA
|
||
|
||
template <typename scalar_t, typename index_t, int64_t BLOCK_N>
|
||
void decode_attention_grouped_kernel_impl(
|
||
scalar_t* __restrict__ output,
|
||
float* __restrict__ attn_logits,
|
||
const scalar_t* __restrict__ query,
|
||
const scalar_t* __restrict__ k_buffer,
|
||
const scalar_t* __restrict__ v_buffer,
|
||
const index_t* __restrict__ req_to_token,
|
||
const int64_t* __restrict__ req_pool_indices,
|
||
const int64_t* __restrict__ seq_lens,
|
||
int64_t batches,
|
||
int64_t num_heads,
|
||
int64_t num_heads_kv,
|
||
int64_t head_size,
|
||
int64_t head_size_v,
|
||
int64_t num_kv_splits,
|
||
int64_t k_strideN,
|
||
int64_t k_strideH,
|
||
int64_t v_strideN,
|
||
int64_t v_strideH,
|
||
float scaling,
|
||
float logit_cap,
|
||
int64_t max_num_reqs,
|
||
int64_t max_context_len,
|
||
int64_t max_total_num_tokens) {
|
||
using Vec = at::vec::Vectorized<float>;
|
||
|
||
// block length for heads
|
||
// we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
|
||
// use smaller BLOCK_H when batches is small to utilize all cores
|
||
constexpr int64_t kBLOCK_H = 16;
|
||
const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
|
||
|
||
// strides
|
||
const int64_t q_strideM = num_heads * head_size;
|
||
const int64_t q_strideH = head_size;
|
||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||
const int64_t l_stride2 = head_size_v + 1;
|
||
|
||
const bool has_logit_cap = logit_cap > 0;
|
||
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
|
||
|
||
// partition the heads into blocks for parallel
|
||
const int64_t num_groups = num_heads / num_heads_kv;
|
||
const int64_t num_blocks = div_up(num_groups, BLOCK_H);
|
||
|
||
// parallel on [batches, num_heads_kv, num_blocks, num_kv_splits]
|
||
at::parallel_for(0, batches * num_heads_kv * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
|
||
int64_t bs{0}, head_kv_id{0}, block_id{0}, kv_id{0};
|
||
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
|
||
|
||
alignas(64) float s_i[BLOCK_H * BLOCK_N];
|
||
float* __restrict__ s_delta = s_i;
|
||
|
||
alignas(64) float s_prime[BLOCK_H];
|
||
alignas(64) float m_prime[BLOCK_H];
|
||
alignas(64) float m_delta[BLOCK_H];
|
||
|
||
for (int64_t i = begin; i < end; ++i) {
|
||
const int64_t h_start = head_kv_id * num_groups + block_id * BLOCK_H;
|
||
const int64_t h_end = head_kv_id * num_groups + std::min(block_id * BLOCK_H + BLOCK_H, num_groups);
|
||
const int64_t h_size = h_end - h_start;
|
||
|
||
// get query
|
||
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
|
||
|
||
int64_t seq_len_kv = seq_lens[bs];
|
||
int64_t req_pool_id = req_pool_indices[bs];
|
||
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
|
||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||
|
||
const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
|
||
const int64_t kv_start = kv_id * SPLIT_SIZE;
|
||
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
|
||
|
||
fill_stub(s_prime, 0.f, BLOCK_H);
|
||
fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);
|
||
|
||
// get v_prime, and init to zero
|
||
float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
|
||
}
|
||
|
||
// loop over K and V sequence with BLOCK_N
|
||
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
|
||
int64_t n_size = std::min(BLOCK_N, kv_end - n);
|
||
|
||
// calculate Q @ K
|
||
index_gemm_kernel_nt<scalar_t, index_t>(
|
||
/* A */ q_ptr,
|
||
/* B */ k_buffer + head_kv_id * k_strideH,
|
||
/* C */ s_i,
|
||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||
/* scl */ scaling,
|
||
/* M */ h_size,
|
||
/* N */ n_size,
|
||
/* K */ head_size,
|
||
/* lda */ q_strideH,
|
||
/* ldb */ k_strideN,
|
||
/* ldc */ BLOCK_N,
|
||
/* mtt */ max_total_num_tokens);
|
||
|
||
if (has_logit_cap) {
|
||
at::vec::map<float>(
|
||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||
s_i,
|
||
s_i,
|
||
BLOCK_H * BLOCK_N);
|
||
}
|
||
|
||
// update the scaling coefficients
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
// m_i: max value per row
|
||
float m_i = at::vec::reduce_all<float>(
|
||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
|
||
m_i = std::max(m_i, m_prime[h]);
|
||
|
||
// m_delta <- exp(m' - m_i)
|
||
m_delta[h] = std::exp(m_prime[h] - m_i);
|
||
|
||
// s_delta <- exp(s_i - m_i)
|
||
at::vec::map<float>(
|
||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
|
||
|
||
// s' <- s' * m_delta + sum(s_delta)
|
||
s_prime[h] *= m_delta[h];
|
||
s_prime[h] += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);
|
||
|
||
m_prime[h] = m_i;
|
||
}
|
||
|
||
// calculate V' <- s_delta @ V + V' * m_delta
|
||
index_gemm_kernel_nn<scalar_t, index_t>(
|
||
/* A */ s_delta,
|
||
/* B */ v_buffer + head_kv_id * v_strideH,
|
||
/* C */ v_prime,
|
||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||
/* scl */ m_delta,
|
||
/* M */ h_size,
|
||
/* N */ head_size_v,
|
||
/* K */ n_size,
|
||
/* lda */ BLOCK_N,
|
||
/* ldb */ v_strideN,
|
||
/* ldc */ l_stride1,
|
||
/* mtt */ max_total_num_tokens);
|
||
} // loop with KV blocks
|
||
|
||
// only update v' when kv_split_size > 0
|
||
if (kv_end > kv_start) {
|
||
for (int64_t h = 0; h < h_size; ++h) {
|
||
float s = 1 / s_prime[h];
|
||
at::vec::map<float>(
|
||
[s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
|
||
(v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
|
||
}
|
||
}
|
||
|
||
// move to the next index
|
||
data_index_step(bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
|
||
}
|
||
});
|
||
|
||
decode_accumulate_kv_splits(
|
||
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
|
||
} // GQA/MQA
|
||
|
||
} // anonymous namespace
|
||
|
||
// query: [num_tokens, num_heads, head_size]
|
||
// output: [num_tokens, num_heads, head_size]
|
||
// k_buffer: [max_total_num_tokens, num_heads, head_size]
|
||
// v_buffer: [max_total_num_tokens, num_heads, head_size_v]
|
||
// attn_logits: [num_seqs, num_heads, num_kv_splits, head_size_v + 1]
|
||
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
|
||
// req_pool_indices: [num_seqs] int64
|
||
// seq_lens: [num_seqs] int64
|
||
//
|
||
void decode_attention_cpu(
|
||
at::Tensor& query,
|
||
at::Tensor& k_buffer,
|
||
at::Tensor& v_buffer,
|
||
at::Tensor& output,
|
||
at::Tensor& key,
|
||
at::Tensor& value,
|
||
at::Tensor& loc,
|
||
at::Tensor& attn_logits,
|
||
at::Tensor& req_to_token,
|
||
at::Tensor& req_pool_indices,
|
||
at::Tensor& seq_lens,
|
||
double sm_scale,
|
||
double logit_cap) {
|
||
RECORD_FUNCTION(
|
||
"sgl-kernel::decode_attention_cpu",
|
||
std::vector<c10::IValue>(
|
||
{query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
|
||
|
||
CHECK_INPUT(query);
|
||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||
// for MLA, key and value shares the same storage and value could be non-contiguous
|
||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
|
||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(value);
|
||
CHECK_DIM(3, query);
|
||
CHECK_DIM(3, k_buffer);
|
||
CHECK_DIM(3, v_buffer);
|
||
CHECK_DIM(3, key);
|
||
CHECK_DIM(3, value);
|
||
CHECK_DIM(1, loc);
|
||
|
||
int64_t num_seqs = seq_lens.size(0);
|
||
int64_t max_num_reqs = req_to_token.size(0);
|
||
int64_t max_context_len = req_to_token.size(1);
|
||
int64_t max_total_num_tokens = k_buffer.size(0);
|
||
|
||
int64_t num_heads = query.size(1);
|
||
int64_t num_heads_kv = k_buffer.size(1);
|
||
int64_t head_size = query.size(2);
|
||
int64_t head_size_v = v_buffer.size(2);
|
||
|
||
int64_t num_kv_splits = attn_logits.size(2);
|
||
|
||
CHECK_EQ(loc.numel(), num_seqs);
|
||
CHECK_EQ(attn_logits.size(0), num_seqs);
|
||
CHECK_EQ(attn_logits.size(1), num_heads);
|
||
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
||
CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
|
||
|
||
// strides for k_buffer and v_buffer
|
||
int64_t k_strideN = k_buffer.stride(0);
|
||
int64_t k_strideH = k_buffer.stride(1);
|
||
int64_t v_strideN = v_buffer.stride(0);
|
||
int64_t v_strideH = v_buffer.stride(1);
|
||
// strides for new key and value
|
||
int64_t nk_strideN = key.stride(0);
|
||
int64_t nk_strideH = key.stride(1);
|
||
int64_t nv_strideN = value.stride(0);
|
||
int64_t nv_strideH = value.stride(1);
|
||
|
||
// check index data types
|
||
const auto index_dtype = req_to_token.scalar_type();
|
||
TORCH_CHECK(
|
||
index_dtype == at::kInt || index_dtype == at::kLong,
|
||
"decode: expect req_to_token to be int32 or int64, got ",
|
||
index_dtype);
|
||
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "decode: expect req_lens to be int64, got ", seq_lens.scalar_type());
|
||
TORCH_CHECK(
|
||
req_pool_indices.scalar_type() == at::kLong,
|
||
"decode: expect req_pool_indices to be int64, got ",
|
||
req_pool_indices.scalar_type());
|
||
|
||
// check if we have MLA here
|
||
void* k_buffer_data = k_buffer.data_ptr();
|
||
void* v_buffer_data = v_buffer.data_ptr();
|
||
const bool is_mla = (k_buffer_data == v_buffer_data) && (num_heads_kv == 1) && (head_size == head_size_v + 64);
|
||
|
||
// block length for k_buffer and v_buffer
|
||
constexpr int BLOCK_N = 256;
|
||
|
||
// buffer for packing k_cache and v_cache
|
||
int num_threads = at::get_num_threads();
|
||
int64_t size_per_thread = is_mla ? BLOCK_N * head_size + BLOCK_N * head_size_v : 0;
|
||
auto buffer = at::empty({num_threads, size_per_thread}, k_buffer.options());
|
||
|
||
AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
|
||
AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
|
||
// update the kv buffer
|
||
decode_set_kv_buffer(
|
||
(scalar_t*)k_buffer_data,
|
||
(scalar_t*)v_buffer_data,
|
||
key.data_ptr<scalar_t>(),
|
||
value.data_ptr<scalar_t>(),
|
||
loc.data_ptr<int64_t>(),
|
||
num_seqs,
|
||
num_heads_kv,
|
||
head_size,
|
||
head_size_v,
|
||
k_strideN,
|
||
k_strideH,
|
||
v_strideN,
|
||
v_strideH,
|
||
nk_strideN,
|
||
nk_strideH,
|
||
nv_strideN,
|
||
nv_strideH,
|
||
is_mla);
|
||
|
||
if (num_heads == num_heads_kv) {
|
||
// MHA
|
||
decode_attention_kernel_impl<scalar_t, index_t, BLOCK_N>(
|
||
output.data_ptr<scalar_t>(),
|
||
attn_logits.data_ptr<float>(),
|
||
query.data_ptr<scalar_t>(),
|
||
(const scalar_t*)k_buffer_data,
|
||
(const scalar_t*)v_buffer_data,
|
||
req_to_token.data_ptr<index_t>(),
|
||
req_pool_indices.data_ptr<int64_t>(),
|
||
seq_lens.data_ptr<int64_t>(),
|
||
num_seqs,
|
||
num_heads,
|
||
head_size,
|
||
head_size_v,
|
||
num_kv_splits,
|
||
k_strideN,
|
||
k_strideH,
|
||
v_strideN,
|
||
v_strideH,
|
||
sm_scale,
|
||
logit_cap,
|
||
max_num_reqs,
|
||
max_context_len,
|
||
max_total_num_tokens);
|
||
} else if (is_mla) {
|
||
// MLA
|
||
decode_attention_mla_kernel_impl<scalar_t, index_t, BLOCK_N>(
|
||
output.data_ptr<scalar_t>(),
|
||
attn_logits.data_ptr<float>(),
|
||
query.data_ptr<scalar_t>(),
|
||
(const scalar_t*)k_buffer_data,
|
||
(const scalar_t*)v_buffer_data,
|
||
req_to_token.data_ptr<index_t>(),
|
||
req_pool_indices.data_ptr<int64_t>(),
|
||
seq_lens.data_ptr<int64_t>(),
|
||
buffer.data_ptr<scalar_t>(),
|
||
num_seqs,
|
||
num_heads,
|
||
head_size,
|
||
head_size_v,
|
||
num_kv_splits,
|
||
k_strideN,
|
||
k_strideH,
|
||
v_strideN,
|
||
v_strideH,
|
||
sm_scale,
|
||
logit_cap,
|
||
max_num_reqs,
|
||
max_context_len,
|
||
max_total_num_tokens,
|
||
size_per_thread);
|
||
} else {
|
||
// GQA/MQA
|
||
decode_attention_grouped_kernel_impl<scalar_t, index_t, BLOCK_N>(
|
||
output.data_ptr<scalar_t>(),
|
||
attn_logits.data_ptr<float>(),
|
||
query.data_ptr<scalar_t>(),
|
||
(const scalar_t*)k_buffer_data,
|
||
(const scalar_t*)v_buffer_data,
|
||
req_to_token.data_ptr<index_t>(),
|
||
req_pool_indices.data_ptr<int64_t>(),
|
||
seq_lens.data_ptr<int64_t>(),
|
||
num_seqs,
|
||
num_heads,
|
||
num_heads_kv,
|
||
head_size,
|
||
head_size_v,
|
||
num_kv_splits,
|
||
k_strideN,
|
||
k_strideH,
|
||
v_strideN,
|
||
v_strideH,
|
||
sm_scale,
|
||
logit_cap,
|
||
max_num_reqs,
|
||
max_context_len,
|
||
max_total_num_tokens);
|
||
}
|
||
});
|
||
});
|
||
}
|