#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif
#include <array>
#include <cstdint>
#include <optional>
#include <type_traits>

#include "../exception.h"
#include "../fp4_layout.cuh"
#include "../logging.h"
#include "../utils.cuh"
#include "../vec_dtypes.cuh"

namespace flashinfer {
namespace trtllm_allreduce_fusion {

using flashinfer::QuantizationSFLayout;

namespace details {
static constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
static constexpr int CVT_FP4_SF_VEC_SIZE = 16;
static constexpr int kBytesPerAccess = 16;
static constexpr int kOneShotMaxToken = 128;
static constexpr int kBarrierFlagCount = 256;
}  // namespace details

namespace maths {
//
// ============================== Cast ==============================
template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  return val;
}

template <>
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
  return make_float2(val.x, val.y);
}

template <>
__device__ inline float2 cuda_cast<float2, float>(float val) {
  return make_float2(val, val);
}

template <>
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
  return __half22float2(val);
}

template <>
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
  return __float22half2_rn(val);
}

template <>
__device__ inline half2 cuda_cast<half2, float>(float val) {
  return __float2half2_rn(val);
}

template <>
__device__ inline half2 cuda_cast<half2, half>(half val) {
  return __half2half2(val);
}

template <>
__device__ inline int8_t cuda_cast<int8_t, half>(half val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  union {
    half fp16;
    int16_t int16_in;
  };
  fp16 = val;
  asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
  return int8[0];
}

template <>
__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int8[0] = cuda_cast<int8_t>(val.x);
  int8[1] = cuda_cast<int8_t>(val.y);
  return int16;
}

template <>
__device__ inline int8_t cuda_cast<int8_t, float>(float val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
  return int8[0];
}

template <>
__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int8[0] = cuda_cast<int8_t>(val.x);
  int8[1] = cuda_cast<int8_t>(val.y);
  return int16;
}

template <>
__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int16 = val;
  return make_half2(int8[0], int8[1]);
}

template <>
__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) {
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int16 = val;
  return make_float2(int8[0], int8[1]);
}

template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, int32_t>(int32_t val) {
  return static_cast<float>(val);
}

template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, int8_t>(int8_t val) {
  return static_cast<float>(val);
}

template <>
__device__ inline int8_t cuda_cast<int8_t, __nv_bfloat16>(__nv_bfloat16 val) {
  return static_cast<float>(val);
}

template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
  return __bfloat162float(val);
}

inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float2 f_val;
  f_val.x = __low2float(val);
  f_val.y = __high2float(val);
  return f_val;
#else
  return __bfloat1622float2(val);
#endif
}

template <>
__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) {
  return bf1622float2(val);
}

template <>
__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) {
  return __float2half(__bfloat162float(val));
}

inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  float2 f_val;
  f_val.x = max(min(__low2float(val), 127.f), -128.f);
  f_val.y = max(min(__high2float(val), 127.f), -128.f);
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int8[0] =
static_cast(static_cast(f_val.x)); int8[1] = static_cast(static_cast(f_val.y)); return int16; #else val = __hmin2(val, make_bfloat162(127., 127.)); val = __hmax2(val, make_bfloat162(-128., -128.)); union { int8_t int8[2]; int16_t int16; }; int8[0] = static_cast(static_cast(val.x)); int8[1] = static_cast(static_cast(val.y)); return int16; #endif } template <> __device__ inline int16_t cuda_cast(__nv_bfloat162 val) { return bf1622int16(val); } template <> __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) { return __float2bfloat16(val); } template <> __device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) { return __float2bfloat16(__half2float(val)); } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 __nv_bfloat162 val2; val2.x = val; val2.y = val; return val2; #else return __bfloat162bfloat162(val); #endif } template <> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) { return bf162bf162(val); } template <> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) { return __float2bfloat162_rn(val); } inline __device__ __nv_bfloat162 float22bf162(const float2 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 return __floats2bfloat162_rn(val.x, val.y); #else return __float22bfloat162_rn(val); #endif } template <> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) { return float22bf162(val); } template <> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) { union { int8_t int8[2]; int16_t int16; }; int16 = val; __nv_bfloat162 res; res.x = cuda_cast<__nv_bfloat16>(int8[0]); res.y = cuda_cast<__nv_bfloat16>(int8[1]); return res; } template <> __device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) { return float22bf162(__half22float2(val)); } // // ============================== Abs ============================== template __device__ inline T cuda_abs(T val) { assert(false); return {}; } template <> __device__ inline float cuda_abs(float val) { return fabs(val); } template <> __device__ inline float2 cuda_abs(float2 val) { return make_float2(fabs(val.x), fabs(val.y)); } template <> __device__ inline half cuda_abs(half val) { return __habs(val); } template <> __device__ inline half2 cuda_abs(half2 val) { return __habs2(val); } #if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) template <> __device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) { return __habs(val); } template <> __device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) { return __habs2(val); } #endif // // ============================== Max ============================== template __device__ inline To cuda_max(Ti val) { return cuda_cast(val); }; template <> __device__ inline float cuda_max(float2 val) { return fmaxf(val.x, val.y); } template <> __device__ inline half cuda_max(half2 val) { return __hmax(val.x, val.y); } template <> __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) return __hmax(val.x, val.y); #else assert(0); asm volatile("brkpt;\n" ::); return __nv_bfloat16(0); #endif } // Binary maximum: compute the max of two values. template __device__ inline T cuda_max(T val1, T val2) { return (val1 > val2) ? 
val1 : val2; } template <> __device__ inline float2 cuda_max(float2 val1, float2 val2) { float2 out; out.x = fmaxf(val1.x, val2.x); out.y = fmaxf(val1.y, val2.y); return out; } template <> __device__ inline half2 cuda_max(half2 val1, half2 val2) { return __hmax2(val1, val2); } template <> __device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) { return __hmax2(val1, val2); } // // ============================== Reciprocal ============================== // Fast reciprocal. inline __device__ float reciprocal_approximate_ftz(float a) { float b; asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); return b; } } // namespace maths namespace utils { #define FINAL_MASK 0xffffffff template __inline__ __device__ T warpReduceSumV2(T* val) { #pragma unroll for (int i = 0; i < NUM; i++) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); } return (T)(0.0f); } template __inline__ __device__ T blockReduceSumV2(T* val) { static __shared__ T shared[NUM][33]; int lane = threadIdx.x & 0x1f; int wid = threadIdx.x >> 5; warpReduceSumV2(val); if (lane == 0) { #pragma unroll for (int i = 0; i < NUM; i++) { shared[i][wid] = val[i]; } } __syncthreads(); bool is_mask = threadIdx.x < (blockDim.x / 32.f); #pragma unroll for (int i = 0; i < NUM; i++) { val[i] = is_mask ? shared[i][lane] : (T)(0.0f); } warpReduceSumV2(val); return (T)0.0f; } inline int getSMVersion() { int device{-1}; FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); int sm_major = 0; int sm_minor = 0; FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } inline int getSMRegisters() { int device{-1}; FLASHINFER_CUDA_CALL(cudaGetDevice(&device)); int regs_per_block; FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(®s_per_block, cudaDevAttrMaxRegistersPerBlock, device)); return regs_per_block; } inline __device__ int64_t get_sf_out_offset_128x4(std::optional batchIdx, int mIdx, int kIdx, std::optional numRows, int numCols) { // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] // batched tensor // SF layout [numBTiles, numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] // --> index [bTileIdx, mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] int32_t innerKIdx = (kIdx % 4); int64_t innerKStride = 1; int32_t innerMIdx = (mIdx % (32 * 4)) / 32; int64_t innerMStride = 4 * innerKStride; // 4 // M tile layout [32, 4] is column-major. int32_t outerMIdx = (mIdx % 32); int64_t outerMStride = 4 * innerMStride; // 16 int32_t kTileIdx = (kIdx / 4); int64_t kTileStride = 32 * outerMStride; // 512 // SF vector size 16. We round the "numCols" up to a multiple of 64. int factor = details::CVT_FP4_SF_VEC_SIZE * 4; int32_t numKTiles = (numCols + factor - 1) / factor; int32_t mTileIdx = mIdx / (32 * 4); int64_t mTileStride = numKTiles * kTileStride; // Each SF block has 128 rows so pad rows to the multiple of 128. int32_t numMTiles = (numRows.value_or(0) + 128 - 1) / 128; int64_t bTileStride = numMTiles * mTileStride; // Compute the global offset. 
int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride; return SFOffset; } template __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional batchIdx, int rowIdx, int colIdx, std::optional numRows, int numCols, SFType* SFout, QuantizationSFLayout layout) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || CVT_FP4_NUM_THREADS_PER_SF == 2); // One pair of threads write one SF to global memory. // TODO: stage through smem for packed STG.32 // is it better than STG.8 from 4 threads ? if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { if (layout == QuantizationSFLayout::SWIZZLED_128x4) { // SF vector index (16 elements share one SF in the K dimension). // numRows and numCols are unpadded. int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; int32_t mIdx = rowIdx; auto SFOffset = get_sf_out_offset_128x4(batchIdx, mIdx, kIdx, numRows, numCols); return reinterpret_cast(SFout) + SFOffset; } else if (layout == QuantizationSFLayout::LINEAR) { // Linear row-major layout, no padding required. int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; int32_t numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE; int64_t mTileStride = numKTiles; int64_t BTileStride = numRows.value_or(0) * mTileStride; int64_t SFOffset = batchIdx.value_or(0) * BTileStride + rowIdx * mTileStride + KTileIdx; return reinterpret_cast(SFout) + SFOffset; } else { return nullptr; } } #endif return nullptr; } __forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) { uint32_t val0 = c0; uint32_t val1 = c1; uint32_t val2 = c2; uint32_t val3 = c3; return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0; } #if CUDA_VERSION >= 12080 // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2 inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t val; asm volatile( "{\n" ".reg .b8 byte0;\n" ".reg .b8 byte1;\n" ".reg .b8 byte2;\n" ".reg .b8 byte3;\n" "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" "}" : "=r"(val) : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); return val; #else uint32_t val; __nv_fp4x2_storage_t vals[4]; #pragma unroll for (int i = 0; i < 4; i++) { vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest); } val = pack_bytes(vals[0], vals[1], vals[2], vals[3]); return val; #endif } // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). 
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t val; asm volatile( "{\n" ".reg .b8 byte0;\n" ".reg .b8 byte1;\n" ".reg .b8 byte2;\n" ".reg .b8 byte3;\n" "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" "}" : "=r"(val) : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); return val; #else uint32_t val; __nv_fp4x2_storage_t vals[4]; #pragma unroll for (int i = 0; i < 4; i++) { vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest); } val = pack_bytes(vals[0], vals[1], vals[2], vals[3]); return val; #endif } // Quantizes the provided PackedVec into the uint32_t output template __device__ uint32_t cvt_warp_fp16_to_fp4(vec_t& vec, float SFScaleVal, uint8_t* SFout) { #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) // Get absolute maximum values among the local 8 values. auto localMax = maths::cuda_abs(get_vec2_element(vec, 0)); #pragma unroll for (int i = 1; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) { localMax = maths::cuda_max(localMax, maths::cuda_abs(get_vec2_element(vec, i))); } // Get the absolute maximum among all 16 values (two threads). localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); // Get the final absolute maximum values. float vecMax = float(maths::cuda_max(localMax.x, localMax.y)); // Get the SF (max value of the vector / max value of e2m1). // maximum value of e2m1 = 6.0. // TODO: use half as compute data type. float SFValue = SFScaleVal * (vecMax * maths::reciprocal_approximate_ftz(6.0f)); // 8 bits representation of the SF. uint8_t fp8SFVal; // Write the SF to global memory (STG.8). if constexpr (UE8M0_SF) { #if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080) __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); SFValue = static_cast(tmp); fp8SFVal = tmp.__x; #else #error "FP8 E8M0 support requires CUDA 12.8 or newer." #endif } else { // Here SFValue is always positive, so E4M3 is the same as UE4M3. __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); fp8SFVal = tmp.__x; SFValue = static_cast(tmp); } // Get the output scale. // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) float outputScale = SFValue != 0 ? maths::reciprocal_approximate_ftz( SFValue * maths::reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; if (SFout) { // Write the SF to global memory (STG.8). *SFout = fp8SFVal; } // Convert the input to float. float2 fp2Vals[details::CVT_FP4_ELTS_PER_THREAD / 2]; #pragma unroll for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) { if constexpr (std::is_same_v) { fp2Vals[i] = __half22float2(get_vec2_element(vec, i)); } else { fp2Vals[i] = __bfloat1622float2(get_vec2_element(vec, i)); } fp2Vals[i].x *= outputScale; fp2Vals[i].y *= outputScale; } // Convert to e2m1 values. uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); // Write the e2m1 values to global memory. 
return e2m1Vec; #else return 0; #endif } #endif } // namespace utils template __device__ __forceinline__ vec_t vec_add(const vec_t& a, const vec_t& b) { vec_t ret; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { ret[i] = static_cast(a[i]) + static_cast(b[i]); } return ret; } enum class AllReduceFusionPattern : int { kAllReduce = 0, kARResidualRMSNorm = 1, kARResidualRMSNormFP8Quant = 2, kARResidualRMSNormFP4Quant = 3, // The difference between these two and the standard version is that the NormOut version outputs // the result of the norm. kARResidualRMSNormOutFP8Quant = 4, kARResidualRMSNormOutFP4Quant = 5 }; enum class QuantType : int { kNone = 0, kFP8 = 1, kFP4 = 2, }; template struct FusionPatternTraits; #define DEFINE_FUSION_PATTERN_TRAITS(pattern, hasAllReduceOut, hasResidual, hasResidualOut, \ hasRMSNorm, hasNormOut, quantType) \ template <> \ struct FusionPatternTraits { \ static constexpr bool kHasAllReduceOut = hasAllReduceOut; \ static constexpr bool kHasResidual = hasResidual; \ static constexpr bool kHasResidualOut = hasResidualOut; \ static constexpr bool kHasRMSNorm = hasRMSNorm; \ static constexpr bool kHasNormOut = hasNormOut; \ static constexpr QuantType kQuantType = quantType; \ }; DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kAllReduce, true, false, false, false, false, QuantType::kNone); DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNorm, false, true, true, true, true, QuantType::kNone); DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormFP8Quant, false, true, true, true, false, QuantType::kFP8); DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormFP4Quant, false, true, true, true, false, QuantType::kFP4); DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, false, true, true, true, true, QuantType::kFP8); DEFINE_FUSION_PATTERN_TRAITS(AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, false, true, true, true, true, QuantType::kFP4); #undef DEFINE_FUSION_PATTERN_TRAITS template constexpr bool HasResidual = FusionPatternTraits::kHasResidual; template constexpr bool HasRMSNorm = FusionPatternTraits::kHasRMSNorm; template constexpr bool HasAllReduceOut = FusionPatternTraits::kHasAllReduceOut; template constexpr bool HasResidualOut = FusionPatternTraits::kHasResidualOut; template constexpr bool HasNormOut = FusionPatternTraits::kHasNormOut; template constexpr QuantType GetQuantType = FusionPatternTraits::kQuantType; template struct AllReduceFusionParams { int nranks; int rank; int size; int hidden_dim; void** workspace; void* allreduce_in; void* allreduce_out; void* residual_in; void* residual_out; void* norm_out; void* quant_out; void* scale_out; void* rms_gamma; float rms_eps; float* scale_factor; bool use_oneshot; QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED_128x4; cudaStream_t stream; AllReduceFusionPattern pattern; bool trigger_completion_at_end = true; }; template struct SyncComm { __device__ __forceinline__ SyncComm(void** workspace) { counter_ptr = &reinterpret_cast(workspace[NRanks * 3])[0]; flag_ptr = &reinterpret_cast(workspace[NRanks * 3])[1]; flag_value = *flag_ptr; for (int r = 0; r < NRanks; ++r) { comm_bufs[r] = workspace[r]; barrier_flags[r] = workspace[NRanks + r]; } __syncthreads(); if (threadIdx.x == 0) { atomicAdd(counter_ptr, 1); } } __device__ __forceinline__ void update(int new_flag_value) { if (blockIdx.x == 0 && threadIdx.x == 0) { while (*reinterpret_cast(counter_ptr) != gridDim.x) { } *flag_ptr = new_flag_value; *counter_ptr 
= 0; } } int* counter_ptr; int* flag_ptr; void* comm_bufs[NRanks]; void* barrier_flags[NRanks]; int flag_value; }; template struct LamportComm { __device__ __forceinline__ LamportComm(void** workspace, int rank) { counter_ptr = &reinterpret_cast(workspace[NRanks * 3])[0]; flag_ptr = &reinterpret_cast(workspace[NRanks * 3])[2]; clear_ptr = &reinterpret_cast(workspace[NRanks * 3])[4]; flag_value = *flag_ptr; int comm_size = reinterpret_cast(workspace[NRanks * 3])[3]; clear_size = *clear_ptr; int data_offset = flag_value % 3; int clear_offset = (flag_value + 2) % 3; for (int r = 0; r < NRanks; ++r) { data_bufs[r] = reinterpret_cast(workspace[2 * NRanks + r]) + static_cast(data_offset) * comm_size; } clear_buf = reinterpret_cast(workspace[2 * NRanks + rank]) + clear_offset * comm_size; __syncthreads(); if (threadIdx.x == 0) { atomicAdd(counter_ptr, 1); } } __device__ __forceinline__ void update(int new_clear_size) { if (blockIdx.x == 0 && threadIdx.x == 0) { while (*reinterpret_cast(counter_ptr) != gridDim.x) { } *flag_ptr = (flag_value + 1) % 3; *clear_ptr = new_clear_size; *counter_ptr = 0; } } int* counter_ptr; int* flag_ptr; int* clear_ptr; uint8_t* data_bufs[NRanks]; uint8_t* clear_buf; int clear_size; int flag_value; }; template class Barrier { public: __device__ __forceinline__ Barrier(int rank, SyncComm const& comm) { if (threadIdx.x < NRanks) { m_flag_value = comm.flag_value; int current_rank = rank; int target_rank = threadIdx.x; m_target_flag = reinterpret_cast(comm.barrier_flags[target_rank]) + current_rank; m_current_flag = reinterpret_cast(comm.barrier_flags[current_rank]) + blockIdx.x * NRanks + target_rank; } } __device__ __forceinline__ void sync() { __syncthreads(); if (threadIdx.x < NRanks) { m_flag_value = next_flag(m_flag_value); // To avoid the ABA problem, we need to synchronize the correct flag value to all // barrier_flags, even if the corresponding CTA has not been launched. for (int flag_idx = blockIdx.x; flag_idx < details::kBarrierFlagCount; flag_idx += gridDim.x) { st_flag(m_target_flag + flag_idx * NRanks, m_flag_value); } while (ld_flag(m_current_flag) == prev_flag(m_flag_value)) { } } __syncthreads(); } protected: __device__ __forceinline__ void st_flag(int* addr, int flag) { asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(addr)); } __device__ __forceinline__ int ld_flag(int* addr) { int flag; asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(addr)); return flag; } __device__ __forceinline__ int next_flag(int flag) { return flag == 2 ? 0 : flag + 1; } __device__ __forceinline__ int prev_flag(int flag) { return flag == 0 ? 
2 : flag - 1; } public: int m_flag_value; private: int* m_target_flag; int* m_current_flag; }; template class FusedOp { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); public: __device__ __forceinline__ FusedOp(AllReduceFusionParams const& params, int access_id, int access_id_in_token) : m_params(params), m_access_id(access_id), m_access_id_in_token(access_id_in_token) { if constexpr (HasRMSNorm) { m_gamma_val.load(reinterpret_cast(params.rms_gamma) + m_access_id_in_token * VEC_SIZE); } if constexpr (HasResidual) { m_residual_val.load(reinterpret_cast(params.residual_in) + m_access_id * VEC_SIZE); } if constexpr (GetQuantType == QuantType::kFP8) { m_scale_factor = 1.f / *(params.scale_factor); } else if constexpr (GetQuantType == QuantType::kFP4) { m_scale_factor = *(params.scale_factor); } } // template __device__ __forceinline__ void update(int access_id) { if (m_access_id != access_id) { m_access_id = access_id; if constexpr (HasResidual) { m_residual_val.load(reinterpret_cast(m_params.residual_in) + m_access_id * VEC_SIZE); } } } // template __device__ __forceinline__ void operator()(vec_t val, int token_id) { if constexpr (HasAllReduceOut) { val.store(reinterpret_cast(m_params.allreduce_out) + m_access_id * VEC_SIZE); } if constexpr (HasResidual) { val = vec_add(val, m_residual_val); if constexpr (HasResidualOut) { val.store(reinterpret_cast(m_params.residual_out) + m_access_id * VEC_SIZE); } } if constexpr (HasRMSNorm) { val = rms_norm(val, m_gamma_val); if constexpr (HasNormOut) { val.store(reinterpret_cast(m_params.norm_out) + m_access_id * VEC_SIZE); } } #if CUDA_VERSION >= 12080 if constexpr (GetQuantType == QuantType::kFP4) { // NOTE(Yingyi): might update later auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset( std::nullopt /* batchIdx */, token_id, m_access_id_in_token, std::nullopt /* numRows */, m_params.hidden_dim, reinterpret_cast(m_params.scale_out), m_params.layout); reinterpret_cast(m_params.quant_out)[m_access_id] = utils::cvt_warp_fp16_to_fp4(val, m_scale_factor, sf_out); } else #endif if constexpr (GetQuantType == QuantType::kFP8) { using PackedQuantizedType = std::conditional_t, float, float2>; PackedQuantizedType ret; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { reinterpret_cast<__nv_fp8_e4m3*>(&ret)[i] = static_cast<__nv_fp8_e4m3>( static_cast(reinterpret_cast(&val)[i]) * m_scale_factor); } reinterpret_cast(m_params.quant_out)[m_access_id] = ret; } else { static_assert(GetQuantType == QuantType::kNone, "Invalid quant type"); } } protected: __device__ __forceinline__ vec_t rms_norm(vec_t const& residual, vec_t const& gamma) { __shared__ float s_val; vec_t norm_out; float acc = 0.f; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { float v = static_cast(reinterpret_cast(&residual)[i]); acc += v * v; } utils::blockReduceSumV2(&acc); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace cg = cooperative_groups; cg::cluster_group cluster = cg::this_cluster(); if (cluster.num_blocks() > 1) { if (threadIdx.x == 0) { s_val = acc; acc = 0.f; } cluster.sync(); if (threadIdx.x == 0) { for (int i = 0; i < cluster.num_blocks(); ++i) { acc += *cluster.map_shared_rank(&s_val, i); } } cluster.sync(); } #endif if (threadIdx.x == 0) { s_val = rsqrtf(acc / m_params.hidden_dim + m_params.rms_eps); } __syncthreads(); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { reinterpret_cast(&norm_out)[i] = static_cast(static_cast(reinterpret_cast(&residual)[i]) * s_val * static_cast(reinterpret_cast(&gamma)[i])); } return norm_out; } private: 
AllReduceFusionParams const& m_params; int m_access_id; int m_access_id_in_token; float m_scale_factor; vec_t m_residual_val; vec_t m_gamma_val; }; template struct neg_zero { static constexpr T value = -T(0); }; template <> struct neg_zero { static constexpr unsigned short neg_zero_bits = 0x8000U; static constexpr __half value = __half_raw{neg_zero_bits}; }; template <> struct neg_zero { static constexpr unsigned short neg_zero_bits = 0x8000U; static constexpr __nv_bfloat16 value = __nv_bfloat16_raw{neg_zero_bits}; }; template <> struct neg_zero { static constexpr unsigned int neg_zero_bits = 0x80000000U; static constexpr float value = -0.0f; }; template __device__ static constexpr T neg_zero_v = neg_zero::value; template __device__ bool is_negative_zero(T) { return false; } // float specialization template <> __device__ bool is_negative_zero(float x) { return (__float_as_int(x) == 0x80000000); } // double specialization template <> __device__ bool is_negative_zero(double x) { return (__double_as_longlong(x) == 0x8000000000000000ULL); } // __half specialization template <> __device__ bool is_negative_zero<__half>(__half x) { return (__half_as_ushort(x) == 0x8000); } // __nv_bfloat16 specialization template <> __device__ bool is_negative_zero<__nv_bfloat16>(__nv_bfloat16 x) { return (__bfloat16_as_ushort(x) == 0x8000); } template __device__ __forceinline__ bool has_neg_zero(const vec_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { if (is_negative_zero(vec[i])) { return true; } } return false; } template __device__ __forceinline__ void remove_neg_zero(vec_t& vec) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { vec[i] = (is_negative_zero(vec[i])) ? static_cast(0.f) : vec[i]; } } template __device__ __forceinline__ void set_neg_zero(T* addr) { vec_t val; val.fill(neg_zero_v); val.store_global_volatile(addr); } template __device__ __forceinline__ vec_t allreduce_sum(vec_t* vals) { if constexpr (Fp32Acc) { static_assert(!std::is_same_v); float acc_f32[VEC_SIZE]; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { acc_f32[i] = static_cast(reinterpret_cast(&vals[0])[i]); } #pragma unroll for (int r = 1; r < NRanks; ++r) { #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { acc_f32[i] += static_cast(reinterpret_cast(&vals[r])[i]); } } vec_t acc; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { acc[i] = static_cast(acc_f32[i]); } return acc; } else { vec_t acc = vals[0]; #pragma unroll for (int r = 1; r < NRanks; ++r) { acc = vec_add(acc, vals[r]); } return acc; } } template class IndexHelper { public: __device__ __forceinline__ IndexHelper(AllReduceFusionParams const& params) { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) namespace cg = cooperative_groups; cg::cluster_group cluster = cg::this_cluster(); cg::grid_group grid = cg::this_grid(); token_id = grid.cluster_rank(); access_id_in_token = cluster.thread_rank(); token_stride = grid.num_clusters(); #else token_id = blockIdx.x; access_id_in_token = threadIdx.x; token_stride = gridDim.x; #endif access_id = token_id * params.hidden_dim / VEC_SIZE + access_id_in_token; access_stride = token_stride * params.hidden_dim / VEC_SIZE; tot_access = params.size / VEC_SIZE; } int token_id; int access_id_in_token; int token_stride; int access_id; int access_stride; int tot_access; }; template __global__ void allreduce_fusion_kernel_oneshot_lamport(AllReduceFusionParams params) { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); 
IndexHelper index_helper(params); int token_id = index_helper.token_id; int access_id_in_token = index_helper.access_id_in_token; int token_stride = index_helper.token_stride; int access_id = index_helper.access_id; int access_stride = index_helper.access_stride; int tot_access = index_helper.tot_access; vec_t clear_vec; clear_vec.fill(neg_zero_v); FusedOp fused_op(params, access_id, access_id_in_token); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); if constexpr (!TriggerCompletionAtEnd) { cudaTriggerProgrammaticLaunchCompletion(); } #endif LamportComm comm(params.workspace, params.rank); int clear_access = comm.clear_size / VEC_SIZE; for (int idx = access_id; idx < tot_access; idx += access_stride) { vec_t val; val.load(reinterpret_cast(params.allreduce_in) + idx * VEC_SIZE); remove_neg_zero(val); #pragma unroll for (int r = 0; r < NRanks; ++r) { // Push data to other ranks val.store(reinterpret_cast(comm.data_bufs[r]) + (params.rank * tot_access + idx) * VEC_SIZE); } } for (int idx = access_id; idx < clear_access; idx += access_stride) { // Clear comm buffer that previous kernel used clear_vec.store(reinterpret_cast(comm.clear_buf) + idx * VEC_SIZE); } for (int idx = access_id, tidx = token_id; idx < tot_access; idx += access_stride, tidx += token_stride) { fused_op.update(idx); vec_t vals[NRanks]; bool done = false; while (!done) { done = true; #pragma unroll for (int r = 0; r < NRanks; ++r) { // LDG.128 from local rank vals[r].load_global_volatile(reinterpret_cast(comm.data_bufs[params.rank]) + (r * tot_access + idx) * VEC_SIZE); done &= !has_neg_zero(vals[r]); } } vec_t sum_val = allreduce_sum(vals); fused_op(sum_val, tidx); } comm.update(params.size * NRanks); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (TriggerCompletionAtEnd) { cudaTriggerProgrammaticLaunchCompletion(); } #endif } template __global__ void allreduce_fusion_kernel_twoshot_sync(AllReduceFusionParams params, std::array begin_tokens, std::array token_num_per_ranks) { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); IndexHelper index_helper(params); int token_id = index_helper.token_id; int access_id_in_token = index_helper.access_id_in_token; int token_stride = index_helper.token_stride; int access_id = index_helper.access_id; int access_stride = index_helper.access_stride; int tot_access = index_helper.tot_access; FusedOp fused_op(params, access_id, access_id_in_token); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif SyncComm comm(params.workspace); #pragma unroll for (int r = 0; r < NRanks; ++r) { int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / VEC_SIZE; int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / VEC_SIZE; for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride) { reinterpret_cast(comm.comm_bufs[params.rank])[idx] = reinterpret_cast(params.allreduce_in)[idx]; } } Barrier barrier(params.rank, comm); barrier.sync(); int comm_access_id = access_id + begin_tokens[params.rank] * params.hidden_dim / VEC_SIZE; int comm_tot_access = (begin_tokens[params.rank] + token_num_per_ranks[params.rank]) * params.hidden_dim / VEC_SIZE; for (int idx = comm_access_id; idx < comm_tot_access; idx += access_stride) { vec_t vals[NRanks]; #pragma unroll for (int r = 0; r < NRanks; ++r) { vals[r].load(reinterpret_cast(comm.comm_bufs[r]) + idx * VEC_SIZE); } vec_t sum_val = allreduce_sum(vals); #pragma unroll for (int r = 0; r < 
NRanks; ++r) { sum_val.store(reinterpret_cast(comm.comm_bufs[r]) + (tot_access + idx) * VEC_SIZE); } } barrier.sync(); #pragma unroll for (int r = 0; r < NRanks; ++r) { int comm_access_id = access_id + begin_tokens[r] * params.hidden_dim / VEC_SIZE; int comm_token_id = token_id + begin_tokens[r]; int comm_tot_access = (begin_tokens[r] + token_num_per_ranks[r]) * params.hidden_dim / VEC_SIZE; for (int idx = comm_access_id, tidx = comm_token_id; idx < comm_tot_access; idx += access_stride, tidx += token_stride) { fused_op.update(idx); vec_t sum_val; sum_val.load(reinterpret_cast(comm.comm_bufs[params.rank]) + (tot_access + idx) * VEC_SIZE); fused_op(sum_val, tidx); } } comm.update(barrier.m_flag_value); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif } int get_sm_count() { static int sm_count = 0; if (sm_count == 0) { int device_id; FLASHINFER_CUDA_CALL(cudaGetDevice(&device_id)); FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id)); } return sm_count; } template cudaError_t launch_oneshot_lamport(AllReduceFusionParams const& params, cudaLaunchConfig_t& cfg) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( &cfg, allreduce_fusion_kernel_oneshot_lamport, params)); return cudaSuccess; } template int get_registers_per_thread_oneshot() { auto kernel = allreduce_fusion_kernel_oneshot_lamport; cudaFuncAttributes attr; cudaFuncGetAttributes(&attr, kernel); return attr.numRegs; } template cudaError_t launch_twoshot_sync(AllReduceFusionParams const& params, cudaLaunchConfig_t& cfg, std::array begin_tokens, std::array token_num_per_ranks) { FLASHINFER_CUDA_CALL( cudaLaunchKernelEx(&cfg, allreduce_fusion_kernel_twoshot_sync, params, begin_tokens, token_num_per_ranks)); return cudaSuccess; } template int get_registers_per_thread_twoshot() { auto kernel = allreduce_fusion_kernel_twoshot_sync; cudaFuncAttributes attr; cudaFuncGetAttributes(&attr, kernel); return attr.numRegs; } bool use_oneshot(int token_num) { return token_num <= details::kOneShotMaxToken; } template cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams const& params, bool launch_with_pdl) { static constexpr int VEC_SIZE = details::kBytesPerAccess / sizeof(T); FLASHINFER_CHECK(params.size % params.hidden_dim == 0, "params.size % params.hidden_dim != 0"); FLASHINFER_CHECK(params.hidden_dim % VEC_SIZE == 0, "params.hidden_dim % VEC_SIZE != 0"); static int SM = utils::getSMVersion(); int token_num = params.size / params.hidden_dim; bool oneshot = params.use_oneshot; int cluster_num = token_num; std::array begin_tokens, token_num_per_ranks; if (!oneshot) { int remaining_token = token_num % NRanks; int token_num_per_rank = token_num / NRanks; cluster_num = token_num_per_rank; if (remaining_token) { cluster_num++; } for (int r = 0; r < NRanks; ++r) { begin_tokens[r] = r * token_num_per_rank + (remaining_token > r ? r : remaining_token); token_num_per_ranks[r] = token_num_per_rank + (remaining_token > r ? 
1 : 0); } } int threads_per_token = params.hidden_dim / VEC_SIZE; int cluster_size; if (SM >= 90) { cluster_size = 8; } else { cluster_size = 1; } while (threads_per_token % cluster_size != 0 && cluster_size > 1) { cluster_size /= 2; } int threads_per_block = threads_per_token / cluster_size; while (threads_per_block < 128 && cluster_size >= 2) { threads_per_block *= 2; cluster_size /= 2; } int sm_count = get_sm_count(); int registers_per_thread; if (oneshot) { if (params.trigger_completion_at_end) { registers_per_thread = get_registers_per_thread_oneshot(); } else { registers_per_thread = get_registers_per_thread_oneshot(); } } else { registers_per_thread = get_registers_per_thread_twoshot(); } static int max_registers = -1; if (max_registers < 0) { max_registers = utils::getSMRegisters(); } int max_threads_per_block = min(max_registers / registers_per_thread, 1024); while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= max_threads_per_block / 2) { threads_per_block *= 2; cluster_size /= 2; } FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks, "not oneshot, or threads_per_block < nranks"); int block_size = threads_per_block; FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0, "block_size > 1024 or cluster_size <= 0"); int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size; cudaLaunchConfig_t cfg; cudaLaunchAttribute attribute[2]; cfg.gridDim = grid_size; cfg.blockDim = block_size; cfg.dynamicSmemBytes = 0; cfg.stream = params.stream; attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = launch_with_pdl ? 1 : 0; attribute[1].id = cudaLaunchAttributeClusterDimension; attribute[1].val.clusterDim.x = cluster_size; attribute[1].val.clusterDim.y = 1; attribute[1].val.clusterDim.z = 1; cfg.attrs = attribute; cfg.numAttrs = SM >= 90 ? 
2 : 0; if (oneshot) { bool trigger_completion_at_end = params.trigger_completion_at_end; if (trigger_completion_at_end) { FLASHINFER_CUDA_CALL( (launch_oneshot_lamport(params, cfg))); } else { FLASHINFER_CUDA_CALL( (launch_oneshot_lamport(params, cfg))); } } else { FLASHINFER_CUDA_CALL((launch_twoshot_sync( params, cfg, begin_tokens, token_num_per_ranks))); } return cudaSuccess; } template cudaError_t allreduce_fusion_op(AllReduceFusionParams const& params, bool launch_with_pdl, bool fp32_acc) { #define DISPATCH_ACC_TYPE(T, Pattern, NRanks) \ if constexpr (std::is_same_v) { \ return allreduce_fusion_kernel_launcher(params, launch_with_pdl); \ } else { \ if (fp32_acc) { \ return allreduce_fusion_kernel_launcher(params, launch_with_pdl); \ } else { \ return allreduce_fusion_kernel_launcher(params, launch_with_pdl); \ } \ } #define DISPATCH_PATTERN(T, NRanks) \ switch (params.pattern) { \ case AllReduceFusionPattern::kAllReduce: \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kAllReduce, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNorm: \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNorm, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNormFP8Quant: \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP8Quant, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNormFP4Quant: \ if constexpr (!std::is_same_v && CUDA_VERSION >= 12080) { \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormFP4Quant, NRanks); \ } else { \ FLASHINFER_CHECK(CUDA_VERSION >= 12080, "FP4Quant requires CUDA 12.8 or higher"); \ FLASHINFER_CHECK(false, "FP4Quant pattern cannot work with DType=float"); \ } \ break; \ case AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant: \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP8Quant, NRanks); \ break; \ case AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant: \ if constexpr (!std::is_same_v && CUDA_VERSION >= 12080) { \ DISPATCH_ACC_TYPE(T, AllReduceFusionPattern::kARResidualRMSNormOutFP4Quant, NRanks); \ } else { \ FLASHINFER_CHECK(CUDA_VERSION >= 12080, "OutFP4Quant requires CUDA 12.8 or higher"); \ FLASHINFER_CHECK(false, "OutFP4Quant pattern cannot work with DType=float"); \ } \ break; \ default: \ FLASHINFER_CHECK(false, "Unsupported allreduce fusion pattern"); \ } switch (params.nranks) { case 2: DISPATCH_PATTERN(T, 2); break; case 4: DISPATCH_PATTERN(T, 4); break; case 8: DISPATCH_PATTERN(T, 8); break; case 16: DISPATCH_PATTERN(T, 16); break; default: FLASHINFER_ERROR( "allreduce_fusion_kernel: unsupported ranks number! Supported ranks: 2, 4, 8, 16."); } } } // namespace trtllm_allreduce_fusion } // namespace flashinfer