// sglang_v0.5.2/flashinfer_0.3.1/csrc/xqa/mha_components.cuh
//
// 165 lines
// 5.4 KiB
// Plaintext

#pragma once
#include "mma.cuh"
#include "utils.cuh"
// 2x2 tile of floats. It is the element type of WarpAccT below and is reinterpreted as
// float[2][2] when handed to mma() in computeRowSumF8.
using InstAcc = Array2D<float, 2, 2>;
// Grid of InstAcc tiles with exactDiv(m, 16) rows and exactDiv(n, 8) columns.
// NOTE(review): exactDiv presumably rejects m / n that are not multiples of 16 / 8 - confirm in
// utils.cuh.
template <uint32_t m, uint32_t n>
using WarpAccT = Array2D<InstAcc, exactDiv(m, 16), exactDiv(n, 8)>;
// Masks accumulator columns that lie outside the half-open range [validColBeg, validColEnd).
//
// For the calling lane, element (i, j) of instruction tile (m, n) is treated as column
//   8 * n + InstAcc::cols * (laneId() % 4) + j.
// Every element whose column is outside the valid range is overwritten with
// mha::numeric_limits<float>::lowest(), for all tile rows m and all in-tile rows i held by the
// lane. Elements whose column is inside the range are left untouched.
//
// warp:        not read by this function.
// acc:         accumulator tiles, modified in place.
// validColBeg: first valid column (inclusive).
// validColEnd: one past the last valid column (exclusive).
template <uint32_t accRows, uint32_t accCols>
__device__ inline void applyMask(Warp const& warp, Array2D<InstAcc, accRows, accCols>& acc,
                                 uint32_t validColBeg, uint32_t validColEnd) {
  // Position of this lane inside its group of 4 consecutive lanes; selects the column pair.
  uint32_t const idxInQuad = laneId() % 4;
#pragma unroll
  for (uint32_t n = 0; n < acc.cols; n++) {
#pragma unroll
    for (uint32_t j = 0; j < InstAcc::cols; j++) {
      uint32_t const col = 8 * n + InstAcc::cols * idxInQuad + j;
      bool const isMasked = col < validColBeg || col >= validColEnd;
      if (!isMasked) {
        continue;
      }
      // The column is invalid for every row, so clear it in all tiles of this tile column.
#pragma unroll
      for (uint32_t m = 0; m < acc.rows; m++) {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; i++) {
          acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
        }
      }
    }
  }
}
// Per-row float values for a tileM-row tile, 4 entries per divUp(tileM, warp_size) chunk.
// replicateForQuad() produces this layout and dedupFromQuad() consumes it.
template <uint32_t tileM>
using QuadRegRowMaxT =
Vec<float, divUp(tileM, warp_size) * 4>; // data is replicated across 4 threads in a MMA quad.
// Per-row float values for a tileM-row tile, one entry per divUp(tileM, warp_size) chunk.
// dedupFromQuad() produces this layout and replicateForQuad() consumes it.
template <uint32_t tileM>
using ThrdRegRowMaxT =
Vec<float, divUp(tileM, warp_size)>; // unlike QuadRegRowMax, not replicated.
// One 32-bit word per divUp(tileM, warp_size) chunk.
// NOTE(review): "UR" presumably means uniform registers; not used in the visible code - confirm.
template <uint32_t tileM>
using UniformRescaleMaskT = Vec<uint32_t, divUp(tileM, warp_size)>; // uniform and stored in UR
// Number of 4-lane groups (quads) in one warp.
inline constexpr uint32_t quadPerWarp = warp_size / 4;
// Fetches a single entry of the quad-replicated layout without materialising the whole vector.
//
// idxMat8 is the reduced row index in 8-row unit. It selects register src[idxMat8 / 4] and the
// block of quadPerWarp source lanes numbered idxMat8 % 4; within that block the calling lane
// reads from the lane whose offset equals its own quad index (laneId() / 4).
//
// Uses a full-mask __shfl_sync, so it must be called by all lanes of a converged warp.
// warp is not read.
template <uint32_t n>
__device__ inline float replicateValForQuad(Warp const& warp, Vec<float, n> const& src,
                                            uint32_t idxMat8) {
  assertWarpConverged();
  uint32_t const srcReg = idxMat8 / 4;
  uint32_t const srcLaneBlock = idxMat8 % 4;
  uint32_t const srcLane = quadPerWarp * srcLaneBlock + laneId() / 4;
  return __shfl_sync(~0U, src[srcReg], srcLane);
}
// Expands a non-replicated per-lane vector into the quad-replicated layout.
//
// Output entry k (0 <= k < 4 * n) is read from register src[k / 4] of lane
// quadPerWarp * (k % 4) + laneId() / 4, so the 4 lanes sharing the same laneId() / 4 all
// receive identical values. Each entry is cross-checked against replicateValForQuad() when
// assertions are enabled.
//
// Uses full-mask __shfl_sync, so it must be called by all lanes of a converged warp.
// warp is only forwarded to the assertion helper.
template <uint32_t n>
__device__ inline QuadRegRowMaxT<n * warp_size> replicateForQuad(Warp const& warp,
                                                                 Vec<float, n> const& src) {
  assertWarpConverged();
  QuadRegRowMaxT<n * warp_size> dst{};
  uint32_t const quadIdx = laneId() / 4;
#pragma unroll
  for (uint32_t k = 0; k < 4 * src.size; k++) {
    uint32_t const srcReg = k / 4;
    uint32_t const srcLaneBlock = k % 4;
    dst[k] = __shfl_sync(~0U, src[srcReg], quadPerWarp * srcLaneBlock + quadIdx);
    assert(dst[k] == replicateValForQuad(warp, src, k));
  }
  return dst;
}
// Collapses a quad-replicated vector back into the non-replicated per-lane layout; the debug
// refcheck at the end asserts that replicateForQuad() applied to the result reproduces src.
//
// Lane L keeps, for each output register r, the value src[r * 4 + L / 8] as held by lane
// 4 * (L % 8), i.e. the first lane of quad (L % 8).
//
// Uses full-mask __shfl_sync, so it must be called by all lanes of the warp.
// warp is only forwarded to the debug refcheck.
template <uint32_t n>
__device__ inline ThrdRegRowMaxT<warp_size * exactDiv(n, 4)> dedupFromQuad(
    Warp const& warp, Vec<float, n> const& src) {
#ifndef NDEBUG
  // Debug-only precondition: each lane agrees with the first lane of its quad.
  for (uint32_t k = 0; k < src.size; k++) {
    assert(src[k] == __shfl_sync(~0U, src[k], laneId() / 4 * 4));
  }
#endif
  using Result = ThrdRegRowMaxT<warp_size * exactDiv(n, 4)>;
  Result dst{};
  uint32_t const self = laneId();
  uint32_t const ownedSlot = self / 8;     // which of the 4 replicated slots this lane keeps
  uint32_t const srcLane = 4 * (self % 8); // first lane of the quad that holds the value
#pragma unroll
  for (uint32_t r = 0; r < dst.size; r++) {
#pragma unroll
    for (uint32_t slot = 0; slot < 4; slot++) {
      // The shuffle is executed by every lane for every slot; only the store is selective.
      float const fetched = __shfl_sync(~0U, src[r * 4 + slot], srcLane);
      dst[r] = (ownedSlot == slot) ? fetched : dst[r];
    }
  }
#ifndef NDEBUG  // refcheck
  auto const rep = replicateForQuad(warp, dst);
#pragma unroll
  for (uint32_t k = 0; k < n; k++) {
    assert(src[k] == rep[k]);
    __syncwarp();
  }
#endif
  return dst;
}
// Computes per-row sums of an fp8 (e4m3) tile by multiplying it with an all-ones B operand.
//
// For each 16-row block i, mma<__nv_fp8_e4m3> is issued once per pair of adjacent src columns
// (src(i, 2k) and src(i, 2k + 1) are reinterpreted together as one uint32_t[2][2] A fragment)
// against a B fragment whose fp8 elements are all 1, with acc(i, 0) as the accumulator.
// NOTE(review): this relies on mma() accumulating into its first argument - confirm in mma.cuh.
// Column 0 of each accumulator tile is then copied into a quad-replicated vector and reduced to
// the per-lane layout by dedupFromQuad(). Debug builds assert that both accumulator columns
// match and that all lanes of a quad agree.
//
// warp: forwarded to dedupFromQuad().
// src:  fp8 A fragments, exactDiv(tileM, 16) x exactDiv(tileN, 16) tiles of 2 packed words each.
// Must be called by all lanes of the warp (mma and full-mask shuffles).
template <uint32_t tileM, uint32_t tileN>
__device__ inline ThrdRegRowMaxT<tileM> computeRowSumF8(
    Warp const& warp,
    Array2D<Array2D<uint32_t, 2, 1>, exactDiv(tileM, 16), exactDiv(tileN, 16)> const& src) {
  using WarpAcc = WarpAccT<tileM, 8>;
  WarpAcc acc{};
  // Four fp8 ones packed into one 32-bit word; used for both words of the B fragment.
  Vec<__nv_fp8x2_e4m3, 2> const bWord = {__nv_fp8x2_e4m3{float2{1, 1}},
                                         __nv_fp8x2_e4m3{float2{1, 1}}};
  uint32_t const b[2][1] = {reinterpret_cast<uint32_t const&>(bWord),
                            reinterpret_cast<uint32_t const&>(bWord)};
#pragma unroll
  for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
    for (uint32_t k = 0; k < exactDiv(src.cols, 2); k++) {
      mma<__nv_fp8_e4m3>(reinterpret_cast<float(&)[2][2]>(acc(i, 0)),
                         reinterpret_cast<uint32_t const(&)[2][2]>(src(i, k * 2)), b);
    }
  }
  // Value-initialise: rowSum has divUp(tileM, warp_size) * 4 entries but only
  // WarpAcc::rows * InstAcc::rows of them are written below (e.g. 2 of 4 when tileM == 16).
  // The rest must not be indeterminate because dedupFromQuad() shuffles every entry.
  QuadRegRowMaxT<tileM> rowSum{};
#pragma unroll
  for (uint32_t i = 0; i < WarpAcc::rows; i++) {
#pragma unroll
    for (uint32_t m = 0; m < InstAcc::rows; m++) {
#ifndef NDEBUG
      assert(acc(i, 0)(m, 0) == acc(i, 0)(m, 1));
      assert(acc(i, 0)(m, 0) == __shfl_sync(~0U, acc(i, 0)(m, 0), laneId() / 4 * 4));
#endif
      rowSum[i * InstAcc::rows + m] = acc(i, 0)(m, 0);
    }
  }
  return dedupFromQuad(warp, rowSum);
}
// Computes per-row sums of a float accumulator tile.
//
// Step 1: each lane adds up, per row register (m, i), all the elements it holds for that row,
//         visiting tile columns n and in-tile columns j in increasing order. The first element
//         is assigned rather than added.
// Step 2: partial sums are combined with __shfl_xor_sync using masks 2 and then 1, i.e. across
//         the 4 lanes that share the same laneId() / 4.
// Step 3: the resulting quad-replicated vector is reduced to the per-lane layout by
//         dedupFromQuad().
//
// warp: forwarded to dedupFromQuad().
// src:  accumulator tiles to be summed along columns.
// Must be called by all lanes of the warp (full-mask shuffles).
template <uint32_t tileM, uint32_t tileN>
__device__ inline ThrdRegRowMaxT<tileM> computeRowSumF32(Warp const& warp,
                                                         WarpAccT<tileM, tileN> const& src) {
  // Zero-initialised so entries that no accumulator row maps to stay well defined.
  QuadRegRowMaxT<tileM> rowSum{};
#pragma unroll
  for (uint32_t n = 0; n < src.cols; n++) {
#pragma unroll
    for (uint32_t j = 0; j < InstAcc::cols; j++) {
#pragma unroll
      for (uint32_t m = 0; m < src.rows; m++) {
#pragma unroll
        for (uint32_t i = 0; i < InstAcc::rows; i++) {
          if (n == 0 && j == 0) {
            rowSum[m * InstAcc::rows + i] = src(m, n)(i, j);
          } else {
            rowSum[m * InstAcc::rows + i] += src(m, n)(i, j);
          }
        }
      }
    }
  }
  // Butterfly reduction over the two low lane-index bits (xor 2, then xor 1).
#pragma unroll
  for (uint32_t mask = 2; mask != 0; mask /= 2) {
#pragma unroll
    for (uint32_t i = 0; i < rowSum.size; i++) {
      rowSum[i] += __shfl_xor_sync(~0U, rowSum[i], mask);
    }
  }
  return dedupFromQuad(warp, rowSum);
}