sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/comm/trtllm_allreduce_fusion.cuh

#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif
#include <cuda/std/optional>
#include <tuple>
#include <type_traits>
#include "../exception.h"
#include "../fp4_layout.cuh"
#include "../logging.h"
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
namespace flashinfer {
namespace trtllm_allreduce_fusion {
using flashinfer::QuantizationSFLayout;
namespace details {
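// Tuning constants shared by the fused all-reduce kernels:
// - CVT_FP4_ELTS_PER_THREAD / CVT_FP4_SF_VEC_SIZE: each thread quantizes 8 elements and
//   every 16 elements share one FP4 scale factor (see utils::cvt_warp_fp16_to_fp4).
// - kBytesPerAccess: every load/store is a 16-byte vectorized access.
// - kOneShotMaxToken: token counts up to this threshold prefer the one-shot Lamport path
//   (see use_oneshot()).
// - kBarrierFlagCount: number of per-block flag slots in each rank's barrier-flag region
//   (see Barrier::sync()).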
static constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
static constexpr int CVT_FP4_SF_VEC_SIZE = 16;
static constexpr int kBytesPerAccess = 16;
static constexpr int kOneShotMaxToken = 128;
static constexpr int kBarrierFlagCount = 256;
} // namespace details
namespace maths {
// ============================== Cast ==============================
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
return val;
}
template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast<float2, float>(float val) {
return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, float>(float val) {
return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast<half2, half>(half val) {
return __half2half2(val);
}
template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val) {
union {
int8_t int8[2];
int16_t int16;
};
union {
half fp16;
int16_t int16_in;
};
fp16 = val;
asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) {
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val) {
union {
int8_t int8[2];
int16_t int16;
};
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
return int8[0];
}
template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) {
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = cuda_cast<int8_t>(val.x);
int8[1] = cuda_cast<int8_t>(val.y);
return int16;
}
template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) {
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) {
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
return make_float2(int8[0], int8[1]);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val) {
return static_cast<float>(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val) {
return static_cast<float>(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val) {
return static_cast<float>(val);
}
template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
return __bfloat162float(val);
}
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = __low2float(val);
f_val.y = __high2float(val);
return f_val;
#else
return __bfloat1622float2(val);
#endif
}
template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) {
return bf1622float2(val);
}
template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) {
return __float2half(__bfloat162float(val));
}
inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
float2 f_val;
f_val.x = max(min(__low2float(val), 127.f), -128.f);
f_val.y = max(min(__high2float(val), 127.f), -128.f);
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
return int16;
#else
val = __hmin2(val, make_bfloat162(127., 127.));
val = __hmax2(val, make_bfloat162(-128., -128.));
union {
int8_t int8[2];
int16_t int16;
};
int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
return int16;
#endif
}
template <>
__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) {
return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) {
return __float2bfloat16(__half2float(val));
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
__nv_bfloat162 val2;
val2.x = val;
val2.y = val;
return val2;
#else
return __bfloat162bfloat162(val);
#endif
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) {
return bf162bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) {
return __float2bfloat162_rn(val);
}
inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return __floats2bfloat162_rn(val.x, val.y);
#else
return __float22bfloat162_rn(val);
#endif
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) {
return float22bf162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) {
union {
int8_t int8[2];
int16_t int16;
};
int16 = val;
__nv_bfloat162 res;
res.x = cuda_cast<__nv_bfloat16>(int8[0]);
res.y = cuda_cast<__nv_bfloat16>(int8[1]);
return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) {
return float22bf162(__half22float2(val));
}
// ============================== Abs ==============================
template <typename T>
__device__ inline T cuda_abs(T val) {
assert(false);
return {};
}
template <>
__device__ inline float cuda_abs(float val) {
return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val) {
return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val) {
return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val) {
return __habs2(val);
}
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) {
return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) {
return __habs2(val);
}
#endif
// ============================== Max ==============================
template <typename To, typename Ti>
__device__ inline To cuda_max(Ti val) {
return cuda_cast<To>(val);
};
template <>
__device__ inline float cuda_max(float2 val) {
return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val) {
return __hmax(val.x, val.y);
}
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmax(val.x, val.y);
#else
assert(0);
asm volatile("brkpt;\n" ::);
return __nv_bfloat16(0);
#endif
}
// Binary maximum: compute the max of two values.
template <typename T>
__device__ inline T cuda_max(T val1, T val2) {
return (val1 > val2) ? val1 : val2;
}
template <>
__device__ inline float2 cuda_max(float2 val1, float2 val2) {
float2 out;
out.x = fmaxf(val1.x, val2.x);
out.y = fmaxf(val1.y, val2.y);
return out;
}
template <>
__device__ inline half2 cuda_max(half2 val1, half2 val2) {
return __hmax2(val1, val2);
}
template <>
__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) {
return __hmax2(val1, val2);
}
// ============================== Reciprocal ==============================
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
} // namespace maths
namespace utils {
#define FINAL_MASK 0xffffffff
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
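// Block-wide sum of NUM values per thread: a per-warp shuffle reduction, the warp partials
// staged in shared memory, then a second warp-level reduction combines them. On return,
// every thread of warp 0 holds the full block sum in val[]; threads in other warps hold 0.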
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
inline int getSMVersion() {
int device{-1};
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
inline int getSMRegisters() {
int device{-1};
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int regs_per_block;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&regs_per_block, cudaDevAttrMaxRegistersPerBlock, device));
return regs_per_block;
}
inline __device__ int64_t get_sf_out_offset_128x4(std::optional<int> batchIdx, int mIdx, int kIdx,
std::optional<int> numRows, int numCols) {
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
// batched tensor
// SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
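// Worked example (illustrative): for mIdx = 5, kIdx = 6, numCols = 7168 and no batch index,
// innerKIdx = 2, innerMIdx = 0, outerMIdx = 5, kTileIdx = 1, mTileIdx = 0, so the returned
// offset is 1 * 512 + 5 * 16 + 0 * 4 + 2 * 1 = 594.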
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4 * innerKStride; // 4
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * innerMStride; // 16
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * outerMStride; // 512
// SF vector size 16. We round the "numCols" up to a multiple of 64.
int factor = details::CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int32_t mTileIdx = mIdx / (32 * 4);
int64_t mTileStride = numKTiles * kTileStride;
// Each SF block has 128 rows so pad rows to the multiple of 128.
int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128;
int64_t bTileStride = numMTiles * mTileStride;
// Compute the global offset.
int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride +
kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride +
innerKIdx * innerKStride;
return SFOffset;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
int colIdx, std::optional<int> numRows,
int numCols, SFType* SFout,
QuantizationSFLayout layout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
if (layout == QuantizationSFLayout::SWIZZLED_128x4) {
// SF vector index (16 elements share one SF in the K dimension).
// numRows and numCols are unpadded.
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numCols);
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} else if (layout == QuantizationSFLayout::LINEAR) {
// Linear row-major layout, no padding required.
int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE;
int64_t mTileStride = numKTiles;
int64_t BTileStride = numRows.value_or(0) * mTileStride;
int64_t SFOffset = batchIdx.value_or(0) * BTileStride + rowIdx * mTileStride + KTileIdx;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} else {
return nullptr;
}
}
#endif
return nullptr;
}
__forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
uint32_t val0 = c0;
uint32_t val1 = c1;
uint32_t val2 = c2;
uint32_t val3 = c3;
return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
}
#if CUDA_VERSION >= 12080
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// NOTE: on architectures without the native e2m1 cvt instruction (pre-SM100), fall back to
// __nv_cvt_float2_to_fp4x2.
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]),
"f"(array[6]), "f"(array[7]));
return val;
#else
uint32_t val;
__nv_fp4x2_storage_t vals[4];
#pragma unroll
for (int i = 0; i < 4; i++) {
vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest);
}
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
return val;
#endif
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
"f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
uint32_t val;
__nv_fp4x2_storage_t vals[4];
#pragma unroll
for (int i = 0; i < 4; i++) {
vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest);
}
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
return val;
#endif
}
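// Both overloads above produce identical packing: byte i of the returned uint32_t holds
// elements 2*i and 2*i+1, with byte 0 in the least significant position (see pack_bytes and
// the mov.b32 pack in the PTX path).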
// Quantize 8 half/bfloat16 values (one vec_t) into 8 packed e2m1 values (one uint32_t);
// the FP8 scale factor is written to SFout (if non-null).
template <typename T, uint32_t VEC_SIZE, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = maths::cuda_abs(get_vec2_element(vec, 0));
#pragma unroll
for (int i = 1; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = maths::cuda_max(localMax, maths::cuda_abs(get_vec2_element(vec, i)));
}
// Get the absolute maximum among all 16 values (two threads).
localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(maths::cuda_max(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = SFScaleVal * (vecMax * maths::reciprocal_approximate_ftz(6.0f));
// 8-bit representation of the SF.
uint8_t fp8SFVal;
if constexpr (UE8M0_SF) {
#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080)
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
#else
#error "FP8 E8M0 support requires CUDA 12.8 or newer."
#endif
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
fp8SFVal = tmp.__x;
SFValue = static_cast<float>(tmp);
}
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)
float outputScale = SFValue != 0 ? maths::reciprocal_approximate_ftz(
SFValue * maths::reciprocal_approximate_ftz(SFScaleVal))
: 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[details::CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<T, half>) {
fp2Vals[i] = __half22float2(get_vec2_element(vec, i));
} else {
fp2Vals[i] = __bfloat1622float2(get_vec2_element(vec, i));
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}
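// Summary of the recipe above: two adjacent threads cover one 16-element SF vector. The
// stored scale factor is SF = fp8(SFScaleVal * amax / 6.0f), 6.0 being the largest e2m1
// value, and every element is emitted as e2m1(x * SFScaleVal / fp32(SF)).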
#endif
} // namespace utils
template <typename T, uint32_t VEC_SIZE>
__device__ __forceinline__ vec_t<T, VEC_SIZE> vec_add(const vec_t<T, VEC_SIZE>& a,
const vec_t<T, VEC_SIZE>& b) {
vec_t<T, VEC_SIZE> ret;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
ret[i] = static_cast<float>(a[i]) + static_cast<float>(b[i]);
}
return ret;
}
enum class AllReduceFusionPattern : int {
kAllReduce = 0,
kARResidualRMSNorm = 1,
kARResidualRMSNormFP8Quant = 2,
kARResidualRMSNormFP4Quant = 3,
// Unlike the plain FP8/FP4 quant patterns above, the NormOut variants below also write the
// RMSNorm result to norm_out.
kARResidualRMSNormOutFP8Quant = 4,
kARResidualRMSNormOutFP4Quant = 5
};
enum class QuantType : int {
kNone = 0,
kFP8 = 1,
kFP4 = 2,
};
template <AllReduceFusionPattern Pattern>
struct FusionPatternTraits;
#define DEFINE_FUSION_PATTERN_TRAITS(pattern, hasAllReduceOut, hasResidual, hasResidualOut, \
hasRMSNorm, hasNormOut, quantType) \
template <> \
struct FusionPatternTraits<pattern> { \
static constexpr bool kHasAllReduceOut = hasAllReduceOut; \
static constexpr bool kHasResidual = hasResidual; \
static constexpr bool kHasResidualOut = hasResidualOut; \
static constexpr bool kHasRMSNorm = hasRMSNorm; \
static constexpr bool kHasNormOut = hasNormOut; \
static constexpr QuantType kQuantType = quantType; \
};
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kAllReduce, true, false, false, false, false,
QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNorm, false, true, true, true,
true, QuantType::kNone);
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormFP8Quant, false, true, true,
true, false, QuantType::kFP8);
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormFP4Quant, false, true, true,
true, false, QuantType::kFP4);
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, false, true,
true, true, true, QuantType::kFP8);
DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, false, true,
true, true, true, QuantType::kFP4);
#undef DEFINE_FUSION_PATTERN_TRAITS
template <AllReduceFusionPattern Pattern>
constexpr bool HasResidual = FusionPatternTraits<Pattern>::kHasResidual;
template <AllReduceFusionPattern Pattern>
constexpr bool HasRMSNorm = FusionPatternTraits<Pattern>::kHasRMSNorm;
template <AllReduceFusionPattern Pattern>
constexpr bool HasAllReduceOut = FusionPatternTraits<Pattern>::kHasAllReduceOut;
template <AllReduceFusionPattern Pattern>
constexpr bool HasResidualOut = FusionPatternTraits<Pattern>::kHasResidualOut;
template <AllReduceFusionPattern Pattern>
constexpr bool HasNormOut = FusionPatternTraits<Pattern>::kHasNormOut;
template <AllReduceFusionPattern Pattern>
constexpr QuantType GetQuantType = FusionPatternTraits<Pattern>::kQuantType;
template <typename T>
struct AllReduceFusionParams {
int nranks;
int rank;
int size;
int hidden_dim;
void** workspace;
void* allreduce_in;
void* allreduce_out;
void* residual_in;
void* residual_out;
void* norm_out;
void* quant_out;
void* scale_out;
void* rms_gamma;
float rms_eps;
float* scale_factor;
bool use_oneshot;
QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED_128x4;
cudaStream_t stream;
AllReduceFusionPattern pattern;
bool trigger_completion_at_end = true;
};
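// Workspace layout, as read by SyncComm and LamportComm below:
// - workspace[0 .. NRanks): per-rank two-shot comm buffers (SyncComm::comm_bufs).
// - workspace[NRanks .. 2*NRanks): per-rank barrier-flag arrays of at least
//   kBarrierFlagCount * NRanks ints (SyncComm::barrier_flags).
// - workspace[2*NRanks .. 3*NRanks): per-rank Lamport buffers, 3 * comm_size bytes each
//   (LamportComm::data_bufs).
// - workspace[3*NRanks]: a small int array holding [0] the block-arrival counter, [1] the
//   sync flag, [2] the Lamport flag, [3] the Lamport buffer size in bytes and [4] the clear
//   size in elements of T.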
template <int NRanks>
struct SyncComm {
__device__ __forceinline__ SyncComm(void** workspace) {
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[1];
flag_value = *flag_ptr;
for (int r = 0; r < NRanks; ++r) {
comm_bufs[r] = workspace[r];
barrier_flags[r] = workspace[NRanks + r];
}
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd(counter_ptr, 1);
}
}
__device__ __forceinline__ void update(int new_flag_value) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
while (*reinterpret_cast<int volatile*>(counter_ptr) != gridDim.x) {
}
*flag_ptr = new_flag_value;
*counter_ptr = 0;
}
}
int* counter_ptr;
int* flag_ptr;
void* comm_bufs[NRanks];
void* barrier_flags[NRanks];
int flag_value;
};
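// LamportComm triple-buffers the one-shot data exchange: data_offset = flag % 3 selects the
// buffer written by the current invocation, while clear_offset = (flag + 2) % 3 points at
// the buffer written by the previous invocation, which is re-initialized to the
// negative-zero sentinel while the current one fills. update() advances the flag modulo 3
// once all blocks of the grid have arrived.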
template <int NRanks>
struct LamportComm {
__device__ __forceinline__ LamportComm(void** workspace, int rank) {
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
for (int r = 0; r < NRanks; ++r) {
data_bufs[r] = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + r]) +
static_cast<int64_t>(data_offset) * comm_size;
}
clear_buf = reinterpret_cast<uint8_t*>(workspace[2 * NRanks + rank]) + clear_offset * comm_size;
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd(counter_ptr, 1);
}
}
__device__ __forceinline__ void update(int new_clear_size) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
while (*reinterpret_cast<int volatile*>(counter_ptr) != gridDim.x) {
}
*flag_ptr = (flag_value + 1) % 3;
*clear_ptr = new_clear_size;
*counter_ptr = 0;
}
}
int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int flag_value;
};
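// Barrier implements an all-to-all flag exchange between ranks. Flags cycle through
// 0 -> 1 -> 2; on sync(), the first NRanks threads of each block publish the new flag value
// to their slots on every peer (the blocks collectively cover all kBarrierFlagCount block
// slots, so stale flags from a launch with a different grid size cannot cause an ABA match)
// and then spin until the corresponding peer's flag leaves the previous value.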
template <int NRanks>
class Barrier {
public:
__device__ __forceinline__ Barrier(int rank, SyncComm<NRanks> const& comm) {
if (threadIdx.x < NRanks) {
m_flag_value = comm.flag_value;
int current_rank = rank;
int target_rank = threadIdx.x;
m_target_flag = reinterpret_cast<int*>(comm.barrier_flags[target_rank]) + current_rank;
m_current_flag = reinterpret_cast<int*>(comm.barrier_flags[current_rank]) +
blockIdx.x * NRanks + target_rank;
}
}
__device__ __forceinline__ void sync() {
__syncthreads();
if (threadIdx.x < NRanks) {
m_flag_value = next_flag(m_flag_value);
// To avoid the ABA problem, we need to synchronize the correct flag value to all
// barrier_flags, even if the corresponding CTA has not been launched.
for (int flag_idx = blockIdx.x; flag_idx < details::kBarrierFlagCount;
flag_idx += gridDim.x) {
st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
}
while (ld_flag(m_current_flag) == prev_flag(m_flag_value)) {
}
}
__syncthreads();
}
protected:
__device__ __forceinline__ void st_flag(int* addr, int flag) {
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(addr));
}
__device__ __forceinline__ int ld_flag(int* addr) {
int flag;
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(addr));
return flag;
}
__device__ __forceinline__ int next_flag(int flag) { return flag == 2 ? 0 : flag + 1; }
__device__ __forceinline__ int prev_flag(int flag) { return flag == 0 ? 2 : flag - 1; }
public:
int m_flag_value;
private:
int* m_target_flag;
int* m_current_flag;
};
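// FusedOp is the per-access epilogue applied to each reduced 16-byte vector. Depending on
// the compile-time pattern it: writes allreduce_out, adds the residual (and writes
// residual_out), applies RMSNorm with rms_gamma (and writes norm_out), and finally
// quantizes to FP8 (quant_out) or FP4 (quant_out plus scale factors in scale_out), in that
// order.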
template <AllReduceFusionPattern Pattern, typename T>
class FusedOp {
static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T);
public:
__device__ __forceinline__ FusedOp(AllReduceFusionParams<T> const& params, int access_id,
int access_id_in_token)
: m_params(params), m_access_id(access_id), m_access_id_in_token(access_id_in_token) {
if constexpr (HasRMSNorm<Pattern>) {
m_gamma_val.load(reinterpret_cast<T*>(params.rms_gamma) + m_access_id_in_token * VEC_SIZE);
}
if constexpr (HasResidual<Pattern>) {
m_residual_val.load(reinterpret_cast<T*>(params.residual_in) + m_access_id * VEC_SIZE);
}
if constexpr (GetQuantType<Pattern> == QuantType::kFP8) {
m_scale_factor = 1.f / *(params.scale_factor);
} else if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
m_scale_factor = *(params.scale_factor);
}
}
// template <typename T>
__device__ __forceinline__ void update(int access_id) {
if (m_access_id != access_id) {
m_access_id = access_id;
if constexpr (HasResidual<Pattern>) {
m_residual_val.load(reinterpret_cast<T*>(m_params.residual_in) + m_access_id * VEC_SIZE);
}
}
}
// template <typename T, uint32_t VEC_SIZE>
__device__ __forceinline__ void operator()(vec_t<T, VEC_SIZE> val, int token_id) {
if constexpr (HasAllReduceOut<Pattern>) {
val.store(reinterpret_cast<T*>(m_params.allreduce_out) + m_access_id * VEC_SIZE);
}
if constexpr (HasResidual<Pattern>) {
val = vec_add<T, VEC_SIZE>(val, m_residual_val);
if constexpr (HasResidualOut<Pattern>) {
val.store(reinterpret_cast<T*>(m_params.residual_out) + m_access_id * VEC_SIZE);
}
}
if constexpr (HasRMSNorm<Pattern>) {
val = rms_norm(val, m_gamma_val);
if constexpr (HasNormOut<Pattern>) {
val.store(reinterpret_cast<T*>(m_params.norm_out) + m_access_id * VEC_SIZE);
}
}
#if CUDA_VERSION >= 12080
if constexpr (GetQuantType<Pattern> == QuantType::kFP4) {
// NOTE(Yingyi): might update later
auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset<uint32_t, 2>(
std::nullopt /* batchIdx */, token_id, m_access_id_in_token, std::nullopt /* numRows */,
m_params.hidden_dim, reinterpret_cast<uint32_t*>(m_params.scale_out), m_params.layout);
reinterpret_cast<uint32_t*>(m_params.quant_out)[m_access_id] =
utils::cvt_warp_fp16_to_fp4<T, VEC_SIZE>(val, m_scale_factor, sf_out);
} else
#endif
if constexpr (GetQuantType<Pattern> == QuantType::kFP8) {
using PackedQuantizedType = std::conditional_t<std::is_same_v<T, float>, float, float2>;
PackedQuantizedType ret;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
reinterpret_cast<__nv_fp8_e4m3*>(&ret)[i] = static_cast<__nv_fp8_e4m3>(
static_cast<float>(reinterpret_cast<T*>(&val)[i]) * m_scale_factor);
}
reinterpret_cast<PackedQuantizedType*>(m_params.quant_out)[m_access_id] = ret;
} else {
static_assert(GetQuantType<Pattern> == QuantType::kNone, "Invalid quant type");
}
}
protected:
__device__ __forceinline__ vec_t<T, VEC_SIZE> rms_norm(vec_t<T, VEC_SIZE> const& residual,
vec_t<T, VEC_SIZE> const& gamma) {
__shared__ float s_val;
vec_t<T, VEC_SIZE> norm_out;
float acc = 0.f;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
float v = static_cast<float>(reinterpret_cast<T const*>(&residual)[i]);
acc += v * v;
}
utils::blockReduceSumV2<float, 1>(&acc);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
if (cluster.num_blocks() > 1) {
if (threadIdx.x == 0) {
s_val = acc;
acc = 0.f;
}
cluster.sync();
if (threadIdx.x == 0) {
for (int i = 0; i < cluster.num_blocks(); ++i) {
acc += *cluster.map_shared_rank(&s_val, i);
}
}
cluster.sync();
}
#endif
if (threadIdx.x == 0) {
s_val = rsqrtf(acc / m_params.hidden_dim + m_params.rms_eps);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
reinterpret_cast<T*>(&norm_out)[i] =
static_cast<T>(static_cast<float>(reinterpret_cast<T const*>(&residual)[i]) * s_val *
static_cast<float>(reinterpret_cast<T const*>(&gamma)[i]));
}
return norm_out;
}
private:
AllReduceFusionParams<T> const& m_params;
int m_access_id;
int m_access_id_in_token;
float m_scale_factor;
vec_t<T, VEC_SIZE> m_residual_val;
vec_t<T, VEC_SIZE> m_gamma_val;
};
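// The helpers below implement the Lamport "not yet written" sentinel: buffers are
// initialized to -0.0, remove_neg_zero() flushes -0.0 in the payload to +0.0 so real data
// can never be mistaken for the sentinel, has_neg_zero() is the readiness poll used by the
// one-shot kernel, and set_neg_zero() rewrites an access-sized chunk back to the sentinel.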
template <typename T>
struct neg_zero {
static constexpr T value = -T(0);
};
template <>
struct neg_zero<half> {
static constexpr unsigned short neg_zero_bits = 0x8000U;
static constexpr __half value = __half_raw{neg_zero_bits};
};
template <>
struct neg_zero<nv_bfloat16> {
static constexpr unsigned short neg_zero_bits = 0x8000U;
static constexpr __nv_bfloat16 value = __nv_bfloat16_raw{neg_zero_bits};
};
template <>
struct neg_zero<float> {
static constexpr unsigned int neg_zero_bits = 0x80000000U;
static constexpr float value = -0.0f;
};
template <typename T>
__device__ static constexpr T neg_zero_v = neg_zero<T>::value;
template <typename T>
__device__ bool is_negative_zero(T) {
return false;
}
// float specialization
template <>
__device__ bool is_negative_zero<float>(float x) {
return (__float_as_int(x) == 0x80000000);
}
// double specialization
template <>
__device__ bool is_negative_zero<double>(double x) {
return (__double_as_longlong(x) == 0x8000000000000000ULL);
}
// __half specialization
template <>
__device__ bool is_negative_zero<__half>(__half x) {
return (__half_as_ushort(x) == 0x8000);
}
// __nv_bfloat16 specialization
template <>
__device__ bool is_negative_zero<__nv_bfloat16>(__nv_bfloat16 x) {
return (__bfloat16_as_ushort(x) == 0x8000);
}
template <typename T, uint32_t VEC_SIZE>
__device__ __forceinline__ bool has_neg_zero(const vec_t<T, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
if (is_negative_zero(vec[i])) {
return true;
}
}
return false;
}
template <typename T, uint32_t VEC_SIZE>
__device__ __forceinline__ void remove_neg_zero(vec_t<T, VEC_SIZE>& vec) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
vec[i] = (is_negative_zero(vec[i])) ? static_cast<T>(0.f) : vec[i];
}
}
template <typename T>
__device__ __forceinline__ void set_neg_zero(T* addr) {
vec_t<T, details::kBytesPerAccess / sizeof(T)> val;
val.fill(neg_zero_v<T>);
val.store_global_volatile(addr);
}
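// Element-wise sum of NRanks vectors. With Fp32Acc (statically rejected for T = float) the
// accumulation is done in float and cast back to T once at the end; otherwise partial sums
// are rounded back to T after every rank (see vec_add).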
template <typename T, uint32_t VEC_SIZE, int NRanks, bool Fp32Acc>
__device__ __forceinline__ vec_t<T, VEC_SIZE> allreduce_sum(vec_t<T, VEC_SIZE>* vals) {
if constexpr (Fp32Acc) {
static_assert(!std::is_same_v<T, float>);
float acc_f32[VEC_SIZE];
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc_f32[i] = static_cast<float>(reinterpret_cast<T*>(&vals[0])[i]);
}
#pragma unroll
for (int r = 1; r < NRanks; ++r) {
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc_f32[i] += static_cast<float>(reinterpret_cast<T*>(&vals[r])[i]);
}
}
vec_t<T, VEC_SIZE> acc;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
acc[i] = static_cast<T>(acc_f32[i]);
}
return acc;
} else {
vec_t<T, VEC_SIZE> acc = vals[0];
#pragma unroll
for (int r = 1; r < NRanks; ++r) {
acc = vec_add<T, VEC_SIZE>(acc, vals[r]);
}
return acc;
}
}
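// IndexHelper maps the launch geometry onto tokens: on SM90+ one thread-block cluster
// processes one token (token_id = cluster rank in the grid, access_id_in_token = thread
// rank within the cluster); otherwise one block processes one token. Each "access" is one
// kBytesPerAccess-wide vector.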
template <typename T>
class IndexHelper {
public:
__device__ __forceinline__ IndexHelper(AllReduceFusionParams<T> const& params) {
static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
namespace cg = cooperative_groups;
cg::cluster_group cluster = cg::this_cluster();
cg::grid_group grid = cg::this_grid();
token_id = grid.cluster_rank();
access_id_in_token = cluster.thread_rank();
token_stride = grid.num_clusters();
#else
token_id = blockIdx.x;
access_id_in_token = threadIdx.x;
token_stride = gridDim.x;
#endif
access_id = token_id * params.hidden_dim / VEC_SIZE + access_id_in_token;
access_stride = token_stride * params.hidden_dim / VEC_SIZE;
tot_access = params.size / VEC_SIZE;
}
int token_id;
int access_id_in_token;
int token_stride;
int access_id;
int access_stride;
int tot_access;
};
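// One-shot path (small token counts): every rank pushes its sanitized input (negative zeros
// removed) into the Lamport buffer of every peer, clears the buffer used by the previous
// invocation back to the sentinel, then polls its own buffer until all NRanks contributions
// of an access have arrived, sums them and applies the fused epilogue. The -0.0 sentinel
// replaces an explicit barrier.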
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc,
bool TriggerCompletionAtEnd = true>
__global__ void allreduce_fusion_kernel_oneshot_lamport(AllReduceFusionParams<T> params) {
static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T);
IndexHelper<T> index_helper(params);
int token_id = index_helper.token_id;
int access_id_in_token = index_helper.access_id_in_token;
int token_stride = index_helper.token_stride;
int access_id = index_helper.access_id;
int access_stride = index_helper.access_stride;
int tot_access = index_helper.tot_access;
vec_t<T, VEC_SIZE> clear_vec;
clear_vec.fill(neg_zero_v<T>);
FusedOp<Pattern, T> fused_op(params, access_id, access_id_in_token);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
if constexpr (!TriggerCompletionAtEnd) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
LamportComm<NRanks> comm(params.workspace, params.rank);
int clear_access = comm.clear_size / VEC_SIZE;
for (int idx = access_id; idx < tot_access; idx += access_stride) {
vec_t<T, VEC_SIZE> val;
val.load(reinterpret_cast<T*>(params.allreduce_in) + idx * VEC_SIZE);
remove_neg_zero<T, VEC_SIZE>(val);
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
// Push data to other ranks
val.store(reinterpret_cast<T*>(comm.data_bufs[r]) +
(params.rank * tot_access + idx) * VEC_SIZE);
}
}
for (int idx = access_id; idx < clear_access; idx += access_stride) {
// Clear comm buffer that previous kernel used
clear_vec.store(reinterpret_cast<T*>(comm.clear_buf) + idx * VEC_SIZE);
}
for (int idx = access_id, tidx = token_id; idx < tot_access;
idx += access_stride, tidx += token_stride) {
fused_op.update(idx);
vec_t<T, VEC_SIZE> vals[NRanks];
bool done = false;
while (!done) {
done = true;
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
// LDG.128 from local rank
vals[r].load_global_volatile(reinterpret_cast<T*>(comm.data_bufs[params.rank]) +
(r * tot_access + idx) * VEC_SIZE);
done &= !has_neg_zero<T, VEC_SIZE>(vals[r]);
}
}
vec_t<T, VEC_SIZE> sum_val = allreduce_sum<T, VEC_SIZE, NRanks, Fp32Acc>(vals);
fused_op(sum_val, tidx);
}
comm.update(params.size * NRanks);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if constexpr (TriggerCompletionAtEnd) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
}
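// Two-shot path (larger token counts): each rank first copies its full input into its own
// comm buffer and synchronizes; it then reduces its assigned token slice across all ranks'
// buffers and broadcasts the result into the second half (offset tot_access) of every
// rank's buffer (reduce-scatter + all-gather); after a second barrier, every rank reads the
// reduced values for all slices from its own buffer and applies the fused epilogue.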
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc>
__global__ void allreduce_fusion_kernel_twoshot_sync(AllReduceFusionParams<T> params,
std::array<int, NRanks> begin_tokens,
std::array<int, NRanks> token_num_per_ranks) {
static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T);
IndexHelper<T> index_helper(params);
int token_id = index_helper.token_id;
int access_id_in_token = index_helper.access_id_in_token;
int token_stride = index_helper.token_stride;
int access_id = index_helper.access_id;
int access_stride = index_helper.access_stride;
int tot_access = index_helper.tot_access;
FusedOp<Pattern, T> fused_op(params, access_id, access_id_in_token);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
SyncComm<NRanks> comm(params.workspace);
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / VEC_SIZE;
int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / VEC_SIZE;
for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride) {
reinterpret_cast<float4*>(comm.comm_bufs[params.rank])[idx] =
reinterpret_cast<float4*>(params.allreduce_in)[idx];
}
}
Barrier<NRanks> barrier(params.rank, comm);
barrier.sync();
int comm_access_id = access_id + begin_tokens[params.rank] * params.hidden_dim / VEC_SIZE;
int comm_tot_access =
(begin_tokens[params.rank] + token_num_per_ranks[params.rank]) * params.hidden_dim / VEC_SIZE;
for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride) {
vec_t<T, VEC_SIZE> vals[NRanks];
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
vals[r].load(reinterpret_cast<T*>(comm.comm_bufs[r]) + idx * VEC_SIZE);
}
vec_t<T, VEC_SIZE> sum_val = allreduce_sum<T, VEC_SIZE, NRanks, Fp32Acc>(vals);
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
sum_val.store(reinterpret_cast<T*>(comm.comm_bufs[r]) + (tot_access + idx) * VEC_SIZE);
}
}
barrier.sync();
#pragma unroll
for (int r = 0; r < NRanks; ++r) {
int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / VEC_SIZE;
int comm_token_id = token_id + begin_tokens[r];
int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / VEC_SIZE;
for (int idx = comm_access_id, tidx = comm_token_id; idx < comm_tot_access;
idx += access_stride, tidx += token_stride) {
fused_op.update(idx);
vec_t<T, VEC_SIZE> sum_val;
sum_val.load(reinterpret_cast<T*>(comm.comm_bufs[params.rank]) +
(tot_access + idx) * VEC_SIZE);
fused_op(sum_val, tidx);
}
}
comm.update(barrier.m_flag_value);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
inline int get_sm_count() {
static int sm_count = 0;
if (sm_count == 0) {
int device_id;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device_id));
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));
}
return sm_count;
}
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc,
bool TriggerCompletionAtEnd = true>
cudaError_t launch_oneshot_lamport(AllReduceFusionParams<T> const& params,
cudaLaunchConfig_t& cfg) {
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
&cfg,
allreduce_fusion_kernel_oneshot_lamport<Pattern, T, NRanks, Fp32Acc, TriggerCompletionAtEnd>,
params));
return cudaSuccess;
}
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc,
bool TriggerCompletionAtEnd = true>
int get_registers_per_thread_oneshot() {
auto kernel =
allreduce_fusion_kernel_oneshot_lamport<Pattern, T, NRanks, Fp32Acc, TriggerCompletionAtEnd>;
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
return attr.numRegs;
}
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc>
cudaError_t launch_twoshot_sync(AllReduceFusionParams<T> const& params, cudaLaunchConfig_t& cfg,
std::array<int, NRanks> begin_tokens,
std::array<int, NRanks> token_num_per_ranks) {
FLASHINFER_CUDA_CALL(
cudaLaunchKernelEx(&cfg, allreduce_fusion_kernel_twoshot_sync<Pattern, T, NRanks, Fp32Acc>,
params, begin_tokens, token_num_per_ranks));
return cudaSuccess;
}
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc>
int get_registers_per_thread_twoshot() {
auto kernel = allreduce_fusion_kernel_twoshot_sync<Pattern, T, NRanks, Fp32Acc>;
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
return attr.numRegs;
}
inline bool use_oneshot(int token_num) { return token_num <= details::kOneShotMaxToken; }
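// Launch heuristic: one cluster per token (one-shot) or per token of this rank's slice
// (two-shot). cluster_size starts at 8 on SM90+ (1 otherwise) and is halved until it
// divides threads_per_token; the block size starts at threads_per_token / cluster_size and
// is doubled (halving cluster_size) until it reaches at least 128 threads, and further
// while the grid would oversubscribe the SMs and the per-block register budget allows
// another doubling. The grid size is a multiple of cluster_size capped by the SM count.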
template <AllReduceFusionPattern Pattern, typename T, int NRanks, bool Fp32Acc>
cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& params,
bool launch_with_pdl) {
static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T);
FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "params.size % params.hidden_dim != 0");
FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "params.hidden_dim % VEC_SIZE != 0");
static int SM = utils::getSMVersion();
int token_num = params.size / params.hidden_dim;
bool oneshot = params.use_oneshot;
int cluster_num = token_num;
std::array<int, NRanks> begin_tokens, token_num_per_ranks;
if (!oneshot) {
int remaining_token = token_num % NRanks;
int token_num_per_rank = token_num / NRanks;
cluster_num = token_num_per_rank;
if (remaining_token) {
cluster_num++;
}
for (int r = 0; r < NRanks; ++r) {
begin_tokens[r] = r * token_num_per_rank + (remaining_token > r ? r : remaining_token);
token_num_per_ranks[r] = token_num_per_rank + (remaining_token > r ? 1 : 0);
}
}
int threads_per_token = params.hidden_dim / VEC_SIZE;
int cluster_size;
if (SM >= 90) {
cluster_size = 8;
} else {
cluster_size = 1;
}
while (threads_per_token % cluster_size != 0 && cluster_size > 1) {
cluster_size /= 2;
}
int threads_per_block = threads_per_token / cluster_size;
while (threads_per_block < 128 && cluster_size >= 2) {
threads_per_block *= 2;
cluster_size /= 2;
}
int sm_count = get_sm_count();
int registers_per_thread;
if (oneshot) {
if (params.trigger_completion_at_end) {
registers_per_thread = get_registers_per_thread_oneshot<Pattern, T, NRanks, Fp32Acc, true>();
} else {
registers_per_thread = get_registers_per_thread_oneshot<Pattern, T, NRanks, Fp32Acc, false>();
}
} else {
registers_per_thread = get_registers_per_thread_twoshot<Pattern, T, NRanks, Fp32Acc>();
}
static int max_registers = -1;
if (max_registers < 0) {
max_registers = utils::getSMRegisters();
}
int max_threads_per_block = min(max_registers / registers_per_thread, 1024);
while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
threads_per_block <= max_threads_per_block / 2) {
threads_per_block *= 2;
cluster_size /= 2;
}
FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
"two-shot path requires threads_per_block >= nranks");
int block_size = threads_per_block;
FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
"block_size > 1024 or cluster_size <= 0");
int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
cudaLaunchConfig_t cfg;
cudaLaunchAttribute attribute[2];
cfg.gridDim = grid_size;
cfg.blockDim = block_size;
cfg.dynamicSmemBytes = 0;
cfg.stream = params.stream;
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = launch_with_pdl ? 1 : 0;
attribute[1].id = cudaLaunchAttributeClusterDimension;
attribute[1].val.clusterDim.x = cluster_size;
attribute[1].val.clusterDim.y = 1;
attribute[1].val.clusterDim.z = 1;
cfg.attrs = attribute;
cfg.numAttrs = SM >= 90 ? 2 : 0;
if (oneshot) {
bool trigger_completion_at_end = params.trigger_completion_at_end;
if (trigger_completion_at_end) {
FLASHINFER_CUDA_CALL(
(launch_oneshot_lamport<Pattern, T, NRanks, Fp32Acc, true>(params, cfg)));
} else {
FLASHINFER_CUDA_CALL(
(launch_oneshot_lamport<Pattern, T, NRanks, Fp32Acc, false>(params, cfg)));
}
} else {
FLASHINFER_CUDA_CALL((launch_twoshot_sync<Pattern, T, NRanks, Fp32Acc>(
params, cfg, begin_tokens, token_num_per_ranks)));
}
return cudaSuccess;
}
template <typename T>
cudaError_t allreduce_fusion_op(AllReduceFusionParams<T> const& params, bool launch_with_pdl,
bool fp32_acc) {
#define DISPATCH_ACC_TYPE(T, Pattern, NRanks) \
if constexpr (std::is_same_v<T, float>) { \
return allreduce_fusion_kernel_launcher<Pattern, T, NRanks, false>(params, launch_with_pdl); \
} else { \
if (fp32_acc) { \
return allreduce_fusion_kernel_launcher<Pattern, T, NRanks, true>(params, launch_with_pdl); \
} else { \
return allreduce_fusion_kernel_launcher<Pattern, T, NRanks, false>(params, launch_with_pdl); \
} \
}
#define DISPATCH_PATTERN(T, NRanks) \
switch (params.pattern) { \
case AllReduceFusionPattern::kAllReduce: \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kAllReduce, NRanks); \
break; \
case AllReduceFusionPattern::kARResidualRMSNorm: \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNorm, NRanks); \
break; \
case AllReduceFusionPattern::kARResidualRMSNormFP8Quant: \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP8Quant, NRanks); \
break; \
case AllReduceFusionPattern::kARResidualRMSNormFP4Quant: \
if constexpr (!std::is_same_v<T, float> && CUDA_VERSION >= 12080) { \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP4Quant, NRanks); \
} else { \
FLASHINFER_CHECK(CUDA_VERSION >= 12080, "FP4Quant requires CUDA 12.8 or higher"); \
FLASHINFER_CHECK(false, "FP4Quant pattern cannot work with DType=float"); \
} \
break; \
case AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant: \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, NRanks); \
break; \
case AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant: \
if constexpr (!std::is_same_v<T, float> && CUDA_VERSION >= 12080) { \
DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, NRanks); \
} else { \
FLASHINFER_CHECK(CUDA_VERSION >= 12080, "OutFP4Quant requires CUDA 12.8 or higher"); \
FLASHINFER_CHECK(false, "OutFP4Quant pattern cannot work with DType=float"); \
} \
break; \
default: \
FLASHINFER_CHECK(false, "Unsupported allreduce fusion pattern"); \
}
switch (params.nranks) {
case 2:
DISPATCH_PATTERN(T, 2);
break;
case 4:
DISPATCH_PATTERN(T, 4);
break;
case 8:
DISPATCH_PATTERN(T, 8);
break;
case 16:
DISPATCH_PATTERN(T, 16);
break;
default:
FLASHINFER_ERROR(
"allreduce_fusion_kernel: unsupported ranks number! Supported ranks: 2, 4, 8, 16.");
}
}
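// Example call (illustrative sketch only; variable names are placeholders and the caller is
// responsible for allocating and exchanging the workspace pointers described above):
//
//   AllReduceFusionParams<half> params{};
//   params.nranks = 8;
//   params.rank = rank;
//   params.size = num_tokens * hidden_dim;
//   params.hidden_dim = hidden_dim;
//   params.workspace = workspace_ptrs;
//   params.allreduce_in = x;
//   params.residual_in = residual;
//   params.residual_out = residual_out;
//   params.norm_out = norm_out;
//   params.rms_gamma = gamma;
//   params.rms_eps = 1e-5f;
//   params.use_oneshot = use_oneshot(num_tokens);
//   params.stream = stream;
//   params.pattern = AllReduceFusionPattern::kARResidualRMSNorm;
//   auto status = allreduce_fusion_op<half>(params, /*launch_with_pdl=*/true,
//                                           /*fp32_acc=*/true);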
} // namespace trtllm_allreduce_fusion
} // namespace flashinfer