// sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/comm/trtllm_alltoall.cuh

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cooperative_groups.h>
#include <stdio.h>
#include <algorithm>  // std::max
#include <cassert>    // assert
#include <cstdint>    // uint64_t
#include <map>        // std::map
#include <cub/cub.cuh>
#include "../exception.h"
#include "../logging.h"
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
namespace cg = cooperative_groups;
namespace flashinfer {
namespace trtllm_alltoall {
#ifdef __CUDACC__
#define ALIGN_256 __align__(256)
#else
#define ALIGN_256 alignas(256)
#endif
struct ALIGN_256 MoeCommFifoConnInfo {
volatile uint64_t head; // write position
volatile uint64_t tail; // read position
};
constexpr int WARP_SIZE = 32;
constexpr uint32_t WARP_MASK = 0xffffffff;
constexpr int RECV_FIFO_DEPTH = 8;
constexpr int RECV_FIFO_ENTRY_BYTES = 256 * 1024;
constexpr int RECV_FIFO_ENTRY_U64 = RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t);
constexpr int RECV_FIFO_TOTAL_BYTES = RECV_FIFO_DEPTH * RECV_FIFO_ENTRY_BYTES;
constexpr int RECV_FIFO_TOTAL_U64 = RECV_FIFO_TOTAL_BYTES / sizeof(uint64_t);
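// FIFO sizing worked example (illustrative): each receive FIFO holds
// RECV_FIFO_DEPTH = 8 entries of 256 KiB, i.e. RECV_FIFO_TOTAL_BYTES = 2 MiB
// (262144 u64) per (peer rank, channel) pair. head/tail in MoeCommFifoConnInfo
// advance by one per entry, and a slot is reused only after the receiver
// advances tail.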
static int getMultiProcessorCount() {
int device_id;
int multi_processor_count;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device_id));
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id));
return multi_processor_count;
}
class AllToAllChannelCommunicatorBase {
public:
static constexpr int GROUP_COUNT_PER_BLOCK = 8;
static_assert(GROUP_COUNT_PER_BLOCK <= 8,
"GROUP_COUNT_PER_BLOCK must be less than or equal to 8");
static constexpr int WARP_PER_GROUP = 2;
static constexpr int U64_DATA_REG_PER_THREAD = 8;
  // A packet is a warp-sized chunk of data sent or received in one shot; each
  // thread in the warp carries U64_DATA_REG_PER_THREAD 64-bit registers of it.
static constexpr int PACKET_SIZE_IN_U64 = WARP_SIZE * U64_DATA_REG_PER_THREAD;
static constexpr int PACKET_SIZE_IN_BYTES = PACKET_SIZE_IN_U64 * sizeof(uint64_t);
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 =
(WARP_SIZE - 2) * U64_DATA_REG_PER_THREAD;
static constexpr int DATA_PAYLOAD_SIZE_PER_PACKET =
DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64 * sizeof(uint64_t);
static constexpr int U64_ELT_COUNT_PER_PACKET = PACKET_SIZE_IN_BYTES / sizeof(uint64_t);
static constexpr int PACKET_COUNT_PER_FIFO_ENTRY = RECV_FIFO_ENTRY_BYTES / PACKET_SIZE_IN_BYTES;
static constexpr int GROUP_MAX_INDICE_COUNT =
RECV_FIFO_ENTRY_BYTES / sizeof(uint64_t) / (WARP_SIZE * U64_DATA_REG_PER_THREAD);
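  // Packet layout worked example (illustrative): PACKET_SIZE_IN_U64 = 32 * 8 =
  // 256 u64 (2 KiB per packet), of which (32 - 2) * 8 = 240 u64 (1920 B) are
  // payload. The remaining 16 u64 are flag slots: every 8th lane (flagThread,
  // wid % 8 == 7) stores a flag instead of data in 4 of its 8 register slots,
  // i.e. one flag per 128-byte chunk. A 256 KiB FIFO entry therefore holds
  // PACKET_COUNT_PER_FIFO_ENTRY = 262144 / 2048 = 128 packets, and
  // GROUP_MAX_INDICE_COUNT = 32768 / 256 = 128 indices.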
struct GroupSharedBuffer {
int groupIndiceBuffer[GROUP_MAX_INDICE_COUNT];
int groupStartIndice;
int groupEndIndice;
};
static void setMaxUsableSmCount(int maxUsableSmCount) {
FLASHINFER_CHECK(AllToAllChannelCommunicatorBase::maxSmCountUsed == false,
"setMaxUsableSmCount can be called only before it is used");
int smCount = getMultiProcessorCount();
if (maxUsableSmCount > smCount) {
FLASHINFER_LOG_WARN(
"setMaxUsableSmCount, maxUsableSmCount=%d, larger than smCount=%d, using smCount instead",
maxUsableSmCount, smCount);
maxUsableSmCount = smCount;
}
AllToAllChannelCommunicatorBase::maxSmCount = maxUsableSmCount;
}
static int getMaxUsableSmCount() {
AllToAllChannelCommunicatorBase::maxSmCountUsed = true;
if (AllToAllChannelCommunicatorBase::maxSmCount == -1) {
int smCount = getMultiProcessorCount();
AllToAllChannelCommunicatorBase::maxSmCount = smCount;
}
return AllToAllChannelCommunicatorBase::maxSmCount;
}
static int computeMoeCommChannelCount(int epSize) {
int smCount = getMaxUsableSmCount();
int blockCountPerChannel = (epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK;
blockCountPerChannel *= 2; // for send and recv
    FLASHINFER_CHECK(blockCountPerChannel <= smCount,
                     "GPU should support at least one channel, usableSmCount=%d", smCount);
    int preferredChannel = smCount / 2 / blockCountPerChannel;  // use half the SMs for communication
    int channelCount = std::max(preferredChannel, 1);           // at least one channel
return channelCount;
}
static int getMoeCommChannelCount(int epSize) {
static std::map<int, int> channelCountMap{};
auto iter = channelCountMap.find(epSize);
if (iter == channelCountMap.end()) {
auto channelCount = AllToAllChannelCommunicatorBase::computeMoeCommChannelCount(epSize);
channelCountMap[epSize] = channelCount;
return channelCount;
}
return iter->second;
}
static dim3 getLaunchBlockDim() {
return dim3(WARP_SIZE * WARP_PER_GROUP, GROUP_COUNT_PER_BLOCK);
}
static dim3 getLaunchGridDim(int epSize) {
int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
return dim3((epSize + GROUP_COUNT_PER_BLOCK - 1) / GROUP_COUNT_PER_BLOCK, channelCount, 2);
}
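  // Launch geometry example (illustrative): for epSize = 12 and
  // GROUP_COUNT_PER_BLOCK = 8, gridDim.x = ceil(12 / 8) = 2 blocks of up to 8
  // peer-rank groups each (groups with peerRank >= epSize exit early),
  // gridDim.y = channelCount, and gridDim.z = 2 selects the role: z == 0 sends,
  // z == 1 receives. Each block runs 8 groups of WARP_PER_GROUP * 32 = 64 threads.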
protected:
static int maxSmCount;
static bool maxSmCountUsed;
};
inline size_t getMoeCommWorkspaceSize(int epSize) {
int channelCount = AllToAllChannelCommunicatorBase::getMoeCommChannelCount(epSize);
return RECV_FIFO_TOTAL_BYTES * epSize * channelCount +
sizeof(MoeCommFifoConnInfo) * epSize * channelCount;
}
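// Workspace size worked example (illustrative): supposing getMoeCommChannelCount(4)
// returns 2, a rank's workspace for epSize = 4 holds 4 * 2 = 8 receive FIFOs of
// RECV_FIFO_TOTAL_BYTES = 2 MiB each, followed by 8 MoeCommFifoConnInfo slots
// (256 B each due to the 256-byte alignment): 16 MiB + 2 KiB in total.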
struct MoeEpWorldInfo {
int epSize;
int epRank;
};
struct MoeExpertParallelInfo {
int expertCount = -1;
int topK = 1;
};
struct SendRecvDataInfo {
int vectorSizeInU64;
  // pre-computed on the host side for use by the GPU kernel
int dataPacketCountPerVector;
int vectorCountPerFifoEntry;
void ComputeDataPacketCountPerVector() {
dataPacketCountPerVector = (vectorSizeInU64 * sizeof(uint64_t) +
AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET - 1) /
AllToAllChannelCommunicatorBase::DATA_PAYLOAD_SIZE_PER_PACKET;
}
void ComputeVectorCountPerFifoEntry() {
ComputeDataPacketCountPerVector();
vectorCountPerFifoEntry =
AllToAllChannelCommunicatorBase::PACKET_COUNT_PER_FIFO_ENTRY / dataPacketCountPerVector;
}
void DoPreCompute() {
ComputeDataPacketCountPerVector();
ComputeVectorCountPerFifoEntry();
assert(vectorCountPerFifoEntry <= AllToAllChannelCommunicatorBase::GROUP_MAX_INDICE_COUNT);
}
};
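// Pre-compute worked example (illustrative): for vectorSizeInU64 = 512 (a 4 KiB
// vector), dataPacketCountPerVector = ceil(4096 / 1920) = 3 packets and
// vectorCountPerFifoEntry = 128 / 3 = 42 vectors per FIFO entry, comfortably
// below GROUP_MAX_INDICE_COUNT = 128, so DoPreCompute()'s assert holds.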
// Struct holding a send/recv data pointer and its displacement information.
struct SendRecvDispls {
uint64_t* dataPtr;
int const* rankCountCumSum; // length = epSize
int const* rankLocalIndices; // length = rankCountCumSum[epRank] - rankCountCumSum[epRank - 1] if
// epRank > 0 else rankCountCumSum[epRank]
int vectorStrideInU64;
#ifdef __CUDACC__
__inline__ __device__ int getCount(int rank) const {
return rank == 0 ? rankCountCumSum[rank] : rankCountCumSum[rank] - rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRankStart(int rank) const {
return rank == 0 ? 0 : rankCountCumSum[rank - 1];
}
__inline__ __device__ int getRealVectorIndice(int globalVectorIndex) const {
return rankLocalIndices[globalVectorIndex];
}
__inline__ __device__ uint64_t* getVectorDataPtr(int realVectorIndex) const {
return dataPtr + realVectorIndex * vectorStrideInU64;
}
#endif
};
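// Displacement example (illustrative): with epSize = 3 and
// rankCountCumSum = {4, 4, 9}, rank 0 owns global vector indices [0, 4), rank 1
// owns none, and rank 2 owns [4, 9). getRealVectorIndice() maps a global index
// to a caller-provided local row, and getVectorDataPtr() turns that row into a
// pointer using vectorStrideInU64.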
struct MoeCommWorkspace {
uint64_t* workspacePtr;
size_t rankStrideInU64;
#ifdef __CUDACC__
__inline__ __device__ uint64_t* getFifoBasePtr(bool isSender, int epRank, int peerRank,
int channel, int channelCount) const {
    // The FIFO buffer itself lives on the receiver's side.
if (isSender) {
return workspacePtr + peerRank * rankStrideInU64 +
(epRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
} else {
return workspacePtr + epRank * rankStrideInU64 +
(peerRank * channelCount + channel) * RECV_FIFO_TOTAL_U64;
}
}
__inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(bool isSender, int epRank,
int peerRank, int channel, int epSize,
int channelCount) const {
    // The FIFO connection info lives on the sender's side.
uint64_t* fifoInfoPtrU64 = workspacePtr + RECV_FIFO_TOTAL_U64 * channelCount * epSize;
int strideIndice = isSender ? epRank : peerRank;
int fifoInfoIndice = isSender ? peerRank : epRank;
fifoInfoPtrU64 += strideIndice * rankStrideInU64;
MoeCommFifoConnInfo* fifoInfoPtr = (MoeCommFifoConnInfo*)fifoInfoPtrU64;
return fifoInfoPtr + fifoInfoIndice * channelCount + channel;
}
#endif
};
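// A minimal local-allocation sketch (illustrative, not part of the original API):
// in a real EP deployment each rank's workspace must be visible to its peers
// (e.g. mapped via CUDA IPC or a symmetric heap), workspacePtr addresses that
// shared mapping, and rankStrideInU64 is the distance between consecutive ranks'
// regions. The helper name below is hypothetical.
inline cudaError_t allocateMoeCommWorkspaceForRank(int epSize, MoeCommWorkspace* workspace) {
  size_t workspaceBytes = getMoeCommWorkspaceSize(epSize);
  void* basePtr = nullptr;
  FLASHINFER_CUDA_CALL(cudaMalloc(&basePtr, workspaceBytes));
  // Zero the FIFO contents and head/tail counters before first use.
  FLASHINFER_CUDA_CALL(cudaMemset(basePtr, 0, workspaceBytes));
  workspace->workspacePtr = static_cast<uint64_t*>(basePtr);
  workspace->rankStrideInU64 = workspaceBytes / sizeof(uint64_t);
  return cudaSuccess;
}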
__device__ inline void barrier_sync(int name, int nThreads) {
asm volatile("barrier.sync.aligned %0, %1;" ::"r"(name), "r"(nThreads) : "memory");
}
inline __device__ void load128(uint64_t const* ptr, uint64_t& v0, uint64_t& v1) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];"
: "=l"(v0), "=l"(v1)
: "l"(ptr)
: "memory");
}
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" ::"l"(v0), "l"(v1), "l"(ptr) : "memory");
}
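// Note on the primitives above: barrier_sync() synchronizes nThreads threads on
// one of the 16 hardware named barriers; the communicator below uses barrier
// id 15 - group so each of the up-to-8 groups in a block waits independently.
// The volatile v2.u64 load/store move 16 bytes in one instruction, and the
// transfer loop appears to follow NCCL's LL128 scheme: assuming 128-byte write
// atomicity between peers, observing flag == step + 1 in the last 16 bytes of a
// 128-byte chunk implies the rest of that chunk is valid.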
template <bool isSender>
class AllToAllChannelCommunicator : public AllToAllChannelCommunicatorBase {
private:
int const tid; // thread index in primitives group
int const nthreads; // number of threads in primitives group
int const wid; // lane index in warp
int const warp; // warp index in primitives group
const MoeEpWorldInfo worldInfo;
const MoeCommWorkspace workspace;
const SendRecvDataInfo sendRecvDataInfo;
const SendRecvDispls dataDispls;
int peerRank; // peer rank index
bool const flagThread;
int const group; // primitives group index
int const channel; // channel index
int const channelCount; // count of channels
MoeCommFifoConnInfo* fifoConnInfoPtr;
uint64_t* fifoBasePtr; // pointer to fifo base address
uint64_t step;
uint64_t tailStepCache;
uint64_t regs[U64_DATA_REG_PER_THREAD];
GroupSharedBuffer* groupSharedBuffer;
int groupStartIndice;
int groupEndIndice;
int sliceStartIndice;
int sliceEndIndice;
uint64_t* stepFifoEntryPtr;
public:
__inline__ __device__ uint64_t getFlag() { return step + 1; }
__inline__ __device__ AllToAllChannelCommunicator(MoeEpWorldInfo const& worldInfo,
MoeCommWorkspace workspace,
SendRecvDataInfo sendRecvDataInfo,
SendRecvDispls dataDispls,
GroupSharedBuffer* groupSharedBuffer,
int channelCount)
: worldInfo(worldInfo),
nthreads(blockDim.x),
tid(threadIdx.x),
workspace(workspace),
sendRecvDataInfo(sendRecvDataInfo),
dataDispls(dataDispls),
wid(threadIdx.x % WARP_SIZE),
warp(threadIdx.x / WARP_SIZE),
peerRank(blockIdx.x * GROUP_COUNT_PER_BLOCK + threadIdx.y),
group(threadIdx.y),
channel(blockIdx.y),
flagThread(threadIdx.x % 8 == 7),
fifoConnInfoPtr(nullptr),
fifoBasePtr(nullptr),
step(0),
tailStepCache(0),
groupSharedBuffer(groupSharedBuffer),
channelCount(channelCount) {}
__inline__ __device__ void init() {
fifoBasePtr =
workspace.getFifoBasePtr(isSender, worldInfo.epRank, peerRank, channel, channelCount);
fifoConnInfoPtr = workspace.getFifoConnInfo(isSender, worldInfo.epRank, peerRank, channel,
worldInfo.epSize, channelCount);
step = isSender ? fifoConnInfoPtr->head : fifoConnInfoPtr->tail;
tailStepCache = isSender ? fifoConnInfoPtr->tail : 0;
}
__inline__ __device__ void computeGroupTransferRange() {
if (tid == 0) {
int rankCount = dataDispls.getCount(peerRank);
int rankStart = dataDispls.getRankStart(peerRank);
int countPerChannel = (rankCount + channelCount - 1) / channelCount;
int groupEnd = min(rankStart + (channel + 1) * countPerChannel, rankStart + rankCount);
int groupStart = min(rankStart + channel * countPerChannel, rankStart + rankCount);
groupSharedBuffer->groupStartIndice = groupStart;
groupSharedBuffer->groupEndIndice = groupEnd;
}
barrier();
groupStartIndice = groupSharedBuffer->groupStartIndice;
groupEndIndice = groupSharedBuffer->groupEndIndice;
}
__inline__ __device__ void loadTransferIndices() {
sliceStartIndice = groupStartIndice;
sliceEndIndice =
min(groupStartIndice + sendRecvDataInfo.vectorCountPerFifoEntry, groupEndIndice);
for (int i = groupStartIndice + tid; i < sliceEndIndice; i += WARP_SIZE * WARP_PER_GROUP) {
groupSharedBuffer->groupIndiceBuffer[i - groupStartIndice] =
dataDispls.getRealVectorIndice(i);
}
groupStartIndice = sliceEndIndice;
barrier();
}
__inline__ __device__ void computeSlicePtr() {
stepFifoEntryPtr = fifoBasePtr + RECV_FIFO_ENTRY_U64 * (step % RECV_FIFO_DEPTH);
}
__inline__ __device__ void sendSlice() {
waitSend();
    int EltPer16B = 2;  // uint64_t elements per 16-byte (128-bit) access
int eltN = sendRecvDataInfo.vectorSizeInU64;
for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP) {
int idxInSlice = vecId - sliceStartIndice;
int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];
uint64_t* src = dataDispls.getVectorDataPtr(vecRealIdx);
uint64_t* slicePtr =
stepFifoEntryPtr +
idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++) {
int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;
#pragma unroll
for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++) {
int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
__syncwarp();
if (!flagThread || g % 2 == 0) {
if (ix * EltPer16B + vecOff < eltN) {
load128((uint64_t*)(src + ix * EltPer16B + vecOff), regs[2 * g + 0], regs[2 * g + 1]);
}
}
__syncwarp();
}
#pragma unroll
for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2) {
if (flagThread) regs[2 * g] = regs[2 * g - 1];
}
uint64_t flag = getFlag();
uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
__syncwarp();
#pragma unroll
for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2) {
store128(packetPtr + u * WARP_SIZE, regs[u], flagThread ? flag : regs[u + 1]);
}
}
}
updateSend();
}
__inline__ __device__ void recvSlice() {
    // The receiver doesn't need to wait on head: the per-packet flags signal arrival.
    int EltPer16B = 2;  // uint64_t elements per 16-byte (128-bit) access
int eltN = sendRecvDataInfo.vectorSizeInU64;
for (int vecId = warp + sliceStartIndice; vecId < sliceEndIndice; vecId += WARP_PER_GROUP) {
int idxInSlice = vecId - sliceStartIndice;
int vecRealIdx = groupSharedBuffer->groupIndiceBuffer[idxInSlice];
uint64_t* dst = dataDispls.getVectorDataPtr(vecRealIdx);
uint64_t* slicePtr =
stepFifoEntryPtr +
idxInSlice * sendRecvDataInfo.dataPacketCountPerVector * PACKET_SIZE_IN_U64 + 2 * wid;
for (int packetId = 0; packetId < sendRecvDataInfo.dataPacketCountPerVector; packetId++) {
uint64_t* packetPtr = slicePtr + packetId * PACKET_SIZE_IN_U64;
int vecOff = packetId * DATA_PAYLOAD_SIZE_PER_PACKET_IN_U64;
bool needReload;
uint64_t flag = getFlag();
__syncwarp();
do {
needReload = false;
#pragma unroll
for (int u = 0; u < U64_DATA_REG_PER_THREAD; u += 2) {
load128(packetPtr + u * WARP_SIZE, regs[u], regs[u + 1]);
needReload |= flagThread && (regs[u + 1] != flag);
}
} while (__any_sync(WARP_MASK, needReload));
#pragma unroll
for (int g = 1; g < U64_DATA_REG_PER_THREAD / 2; g += 2) {
if (flagThread) regs[2 * g - 1] = regs[2 * g];
}
#pragma unroll
for (int g = 0; g < U64_DATA_REG_PER_THREAD / 2; g++) {
int ix = g * WARP_SIZE - 4 * (g / 2) + wid - (g % 2) * (wid / 8);
__syncwarp();
if (!flagThread || g % 2 == 0) {
if (ix * EltPer16B + vecOff < eltN) {
store128((uint64_t*)(dst + ix * EltPer16B + vecOff), regs[2 * g + 0],
regs[2 * g + 1]);
}
}
__syncwarp();
}
}
}
updateRecv();
}
__inline__ __device__ void run() {
if (peerRank >= worldInfo.epSize) {
return;
}
init();
computeGroupTransferRange();
while (groupStartIndice < groupEndIndice) {
loadTransferIndices();
computeSlicePtr();
if (isSender) {
sendSlice();
} else {
recvSlice();
}
}
}
__inline__ __device__ ~AllToAllChannelCommunicator() {}
__inline__ __device__ void barrier() { barrier_sync(15 - group, nthreads); }
__inline__ __device__ void waitSend() {
barrier();
while (tailStepCache + RECV_FIFO_DEPTH < step + 1) {
tailStepCache = fifoConnInfoPtr->tail;
}
barrier();
}
__inline__ __device__ void updateSend() {
barrier();
if (tid == 0) {
atomicAdd_system((unsigned long long*)&fifoConnInfoPtr->head, 1);
}
barrier();
step++;
}
__inline__ __device__ void updateRecv() {
barrier();
if (tid == 0) {
atomicAdd_system((unsigned long long*)&fifoConnInfoPtr->tail, 1);
}
barrier();
step++;
}
};
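// Handshake summary (illustrative walk-through): for each FIFO entry ("step"),
// the sender spins in waitSend() until tail + RECV_FIFO_DEPTH >= step + 1 (a
// free slot exists), writes the slice with flag = step + 1 embedded in every
// 128-byte chunk, then bumps head in updateSend(). The receiver never waits on
// head: recvSlice() polls the flags directly, copies the payload out, and bumps
// tail in updateRecv() to hand the slot back to the sender.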
__global__ void moeAllToAllKernel(MoeEpWorldInfo worldInfo, MoeCommWorkspace workspace,
SendRecvDataInfo sendRecvDataInfo, SendRecvDispls sendDispls,
SendRecvDispls recvDispls) {
__shared__ AllToAllChannelCommunicatorBase::GroupSharedBuffer
allGroupSharedBuffer[AllToAllChannelCommunicatorBase::GROUP_COUNT_PER_BLOCK];
bool isSender = blockIdx.z == 0;
int channelCount = gridDim.y;
int group = threadIdx.y;
SendRecvDispls dataDispls = isSender ? sendDispls : recvDispls;
AllToAllChannelCommunicatorBase::GroupSharedBuffer* groupSharedBuffer =
&allGroupSharedBuffer[group];
if (isSender) {
AllToAllChannelCommunicator<true> comm(worldInfo, workspace, sendRecvDataInfo, dataDispls,
groupSharedBuffer, channelCount);
comm.run();
} else {
AllToAllChannelCommunicator<false> comm(worldInfo, workspace, sendRecvDataInfo, dataDispls,
groupSharedBuffer, channelCount);
comm.run();
}
}
void moeAllToAll(MoeEpWorldInfo worldInfo, SendRecvDataInfo sendRecvDataInfo,
SendRecvDispls sendDispls, SendRecvDispls recvDispls, MoeCommWorkspace workspace,
cudaStream_t stream) {
sendRecvDataInfo.DoPreCompute();
FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(sendDispls.dataPtr) % 16 == 0,
"sendDispls.dataPtr must be 16-byte aligned");
FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(recvDispls.dataPtr) % 16 == 0,
"recvDispls.dataPtr must be 16-byte aligned");
dim3 block = AllToAllChannelCommunicatorBase::getLaunchBlockDim();
dim3 grid = AllToAllChannelCommunicatorBase::getLaunchGridDim(worldInfo.epSize);
moeAllToAllKernel<<<grid, block, 0, stream>>>(worldInfo, workspace, sendRecvDataInfo, sendDispls,
recvDispls);
}
template <bool isSend, int kThreadsGroupSize>
__inline__ __device__ void computeSendRecvRankCountDevice(
MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int* sharedSendRecvRankCount, int* sendRecvRankCount) {
cg::thread_block_tile<kThreadsGroupSize> tile =
cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int topK = expertParallelInfo.topK;
int epRank = worldInfo.epRank;
int epSize = worldInfo.epSize;
if (threadIdx.x == 0) {
*sharedSendRecvRankCount = 0;
}
__syncthreads();
int readRank = isSend ? epRank : blockIdx.x;
int compareRankId = isSend ? blockIdx.x : epRank;
int const* readRankTargetRankIds = gatheredTargetRankIds + readRank * maxTokenCountPerRank * topK;
int readRankTokenCount = maxTokenCountPerRank;
if (realRankTokenCountCumSum != nullptr) {
int readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
}
for (int i = tileId + blockIdx.z * tileCountPerBlock; i < readRankTokenCount;
i += tileCountPerBlock * gridDim.z) {
int targetRankId = laneInTile < topK ? readRankTargetRankIds[i * topK + laneInTile] : epSize;
bool rankMatched = (targetRankId == compareRankId);
bool hasRankMatched = tile.any(rankMatched);
if (hasRankMatched && laneInTile == 0) {
atomicAdd_block(sharedSendRecvRankCount, 1);
}
tile.sync();
}
__syncthreads();
if (threadIdx.x == 0) {
atomicAdd_system(sendRecvRankCount + blockIdx.x, *sharedSendRecvRankCount);
}
}
template <int kThreadsGroupSize>
__global__ void computeSendRecvRankCountKernel(MoeEpWorldInfo worldInfo,
MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank,
int const* realRankTokenCountCumSum,
int const* gatheredTargetRankIds, int* sendRankCount,
int* recvRankCount) {
static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 ||
kThreadsGroupSize == 8 || kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
"Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
__shared__ int sharedSendRecvRankCount;
if (blockIdx.y == 0) {
// compute send rank count
computeSendRecvRankCountDevice<true, kThreadsGroupSize>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, &sharedSendRecvRankCount, sendRankCount);
} else {
// compute recv rank count
computeSendRecvRankCountDevice<false, kThreadsGroupSize>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, &sharedSendRecvRankCount, recvRankCount);
}
}
void computeSendRecvRankCount(const MoeEpWorldInfo& worldInfo,
MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int* sendRankCount, int* recvRankCount, cudaStream_t stream) {
  FLASHINFER_CHECK(expertParallelInfo.topK <= 32,
                   "Only topK less than or equal to 32 is supported now.");
int threadsPerBlock = 1024;
auto* kernelPtr = computeSendRecvRankCountKernel<32>;
if (expertParallelInfo.topK <= 1) {
kernelPtr = computeSendRecvRankCountKernel<1>;
} else if (expertParallelInfo.topK <= 2) {
kernelPtr = computeSendRecvRankCountKernel<2>;
} else if (expertParallelInfo.topK <= 4) {
kernelPtr = computeSendRecvRankCountKernel<4>;
} else if (expertParallelInfo.topK <= 8) {
kernelPtr = computeSendRecvRankCountKernel<8>;
} else if (expertParallelInfo.topK <= 16) {
kernelPtr = computeSendRecvRankCountKernel<16>;
}
  dim3 grid(worldInfo.epSize, 2, 1);
  kernelPtr<<<grid, threadsPerBlock, 0, stream>>>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, sendRankCount, recvRankCount);
}
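// Counting example (illustrative): gatheredTargetRankIds is a
// [epSize * maxTokenCountPerRank, topK] table of destination ranks. With
// epSize = 2, topK = 2, and the local rank's two tokens targeting {0, 1} and
// {1, 1}, sendRankCount becomes {1, 2}: each token is counted once per
// destination rank that appears anywhere in its topK list (tile.any), even if
// that rank appears multiple times.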
template <int kThreadsPerBlock>
__global__ void inplaceSendRecvRankCumSumKernel(MoeEpWorldInfo worldInfo, int* sendRankCount,
int* recvRankCount) {
int* inputOutputPtr = blockIdx.x == 0 ? sendRankCount : recvRankCount;
typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
int tid = threadIdx.x;
int threadData = tid < worldInfo.epSize ? inputOutputPtr[tid] : 0;
BlockScan(temp_storage).InclusiveSum(threadData, threadData);
if (tid < worldInfo.epSize) {
inputOutputPtr[tid] = threadData;
}
}
void inplaceSendRecvRankCumSum(const MoeEpWorldInfo& worldInfo, int* sendRankCount,
int* recvRankCount, cudaStream_t stream) {
  int epSize = worldInfo.epSize;
  FLASHINFER_CHECK(epSize <= 1024, "Only epSize less than or equal to 1024 is supported now.");
auto* kernelPtr = inplaceSendRecvRankCumSumKernel<1024>;
int blockSize = 1024;
if (epSize <= 32) {
kernelPtr = inplaceSendRecvRankCumSumKernel<32>;
blockSize = 32;
} else if (epSize <= 64) {
kernelPtr = inplaceSendRecvRankCumSumKernel<64>;
blockSize = 64;
} else if (epSize <= 128) {
kernelPtr = inplaceSendRecvRankCumSumKernel<128>;
blockSize = 128;
} else if (epSize <= 256) {
kernelPtr = inplaceSendRecvRankCumSumKernel<256>;
blockSize = 256;
} else if (epSize <= 512) {
kernelPtr = inplaceSendRecvRankCumSumKernel<512>;
blockSize = 512;
}
kernelPtr<<<2, blockSize, 0, stream>>>(worldInfo, sendRankCount, recvRankCount);
}
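// Cumulative-sum example (illustrative): sendRankCount = {3, 0, 5, 1} becomes
// {3, 3, 8, 9} in place; block 0 scans the send counts and block 1 the recv
// counts, each with one cub::BlockScan inclusive sum over up to 1024 ranks.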
template <bool isSend, int kThreadsGroupSize, int kThreadsPerBlock>
__inline__ __device__ void computeSendRecvIndicesDevice(
MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int const* sendRecvCumSum,
int* sendRecvIndices, // send or receive
int* localGatherIndices, // receive only
int* backwardRecvRankLocalIndices, // send only
int* sharedSendRecvRankStart,
typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage& tempStorage) {
cg::thread_block_tile<kThreadsGroupSize> tile =
cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int topK = expertParallelInfo.topK;
int epRank = worldInfo.epRank;
int epSize = worldInfo.epSize;
if (threadIdx.x == 0) {
*sharedSendRecvRankStart = blockIdx.x == 0 ? 0 : sendRecvCumSum[blockIdx.x - 1];
}
__syncthreads();
int readRank = isSend ? epRank : blockIdx.x;
int compareRankId = isSend ? blockIdx.x : epRank;
int readRankStart = readRank * maxTokenCountPerRank;
int const* readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
int readRankTokenCount = maxTokenCountPerRank;
if (realRankTokenCountCumSum != nullptr) {
readRankStart = readRank == 0 ? 0 : realRankTokenCountCumSum[readRank - 1];
readRankTargetRankIds = gatheredTargetRankIds + readRankStart * topK;
readRankTokenCount = realRankTokenCountCumSum[readRank] - readRankStart;
}
for (int blockStartId = blockIdx.z * tileCountPerBlock; blockStartId < readRankTokenCount;
blockStartId += tileCountPerBlock * gridDim.z) {
int stepStartIndice = *sharedSendRecvRankStart;
int i = blockStartId + tileId;
int targetRankId = (laneInTile < topK && i < readRankTokenCount)
? readRankTargetRankIds[i * topK + laneInTile]
: epSize;
bool rankMatched = (targetRankId == compareRankId);
bool hasRankMatched = tile.any(rankMatched);
unsigned int laneMask = tile.ballot(rankMatched);
int lowestLane = __ffs(laneMask) - 1;
int isMatchedLane = (hasRankMatched && laneInTile == lowestLane) ? 1 : 0;
int indice;
typedef cub::BlockScan<int, kThreadsPerBlock> BlockScan;
BlockScan(tempStorage).ExclusiveSum(isMatchedLane, indice);
indice += stepStartIndice;
__syncthreads();
if (isMatchedLane == 1) {
atomicAdd_block(sharedSendRecvRankStart, 1);
if (isSend) {
sendRecvIndices[indice] = i;
backwardRecvRankLocalIndices[indice] = i * topK + lowestLane;
} else {
sendRecvIndices[indice] = indice;
localGatherIndices[indice] = readRankStart + i;
}
}
__syncthreads();
}
}
template <int kThreadsGroupSize, int kThreadsPerBlock>
__global__ void computeSendRecvIndicesKernel(
MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int const* sendRankCountCumSum, int const* recvRankCountCumSum, int* localGatherIndices,
int* sendRankLocalIndices, int* recvRankLocalIndices, int* backwardRecvRankLocalIndices) {
static_assert(kThreadsGroupSize == 1 || kThreadsGroupSize == 2 || kThreadsGroupSize == 4 ||
kThreadsGroupSize == 8 || kThreadsGroupSize == 16 || kThreadsGroupSize == 32,
"Only 1, 2, 4, 8, 16, 32 threads group size supported now.");
__shared__ int sharedSendRecvRankStart;
__shared__ typename cub::BlockScan<int, kThreadsPerBlock>::TempStorage tempStorage;
if (blockIdx.y == 0) {
    // compute send indices
computeSendRecvIndicesDevice<true, kThreadsGroupSize, kThreadsPerBlock>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, sendRankCountCumSum, sendRankLocalIndices, localGatherIndices,
backwardRecvRankLocalIndices, &sharedSendRecvRankStart, tempStorage);
} else {
    // compute recv indices
computeSendRecvIndicesDevice<false, kThreadsGroupSize, kThreadsPerBlock>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, recvRankCountCumSum, recvRankLocalIndices, localGatherIndices,
backwardRecvRankLocalIndices, &sharedSendRecvRankStart, tempStorage);
}
}
void computeSendRecvIndices(const MoeEpWorldInfo& worldInfo,
MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* realRankTokenCountCumSum, int const* gatheredTargetRankIds,
int const* sendRankCountCumSum, int const* recvRankCountCumSum,
int* localGatherIndices, int* sendRankLocalIndices,
int* recvRankLocalIndices, int* backwardRecvRankLocalIndices,
cudaStream_t stream) {
  FLASHINFER_CHECK(expertParallelInfo.topK <= 32,
                   "Only topK less than or equal to 32 is supported now.");
int threadsPerBlock = 1024;
auto* kernelPtr = computeSendRecvIndicesKernel<32, 1024>;
if (expertParallelInfo.topK <= 1) {
kernelPtr = computeSendRecvIndicesKernel<1, 1024>;
} else if (expertParallelInfo.topK <= 2) {
kernelPtr = computeSendRecvIndicesKernel<2, 1024>;
} else if (expertParallelInfo.topK <= 4) {
kernelPtr = computeSendRecvIndicesKernel<4, 1024>;
} else if (expertParallelInfo.topK <= 8) {
kernelPtr = computeSendRecvIndicesKernel<8, 1024>;
} else if (expertParallelInfo.topK <= 16) {
kernelPtr = computeSendRecvIndicesKernel<16, 1024>;
  }
  dim3 grid(worldInfo.epSize, 2, 1);
  kernelPtr<<<grid, threadsPerBlock, 0, stream>>>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, realRankTokenCountCumSum,
gatheredTargetRankIds, sendRankCountCumSum, recvRankCountCumSum, localGatherIndices,
sendRankLocalIndices, recvRankLocalIndices, backwardRecvRankLocalIndices);
}
cudaError_t moeAllToAllPrepareIndices(
MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo, int maxTokenCountPerRank,
int const* gatheredTargetRankIds, int const* realRankTokenCountCumSum,
    // indices into gatheredTargetRankIds whose topK targets include the local rank
    int* localGatherIndices,  // max length = maxTokenCountPerRank * worldInfo.epSize, reached when
                              // all ranks send to the current rank
    int* sendRankCountCumSum,   // length = worldInfo.epSize
    int* sendRankLocalIndices,  // max length = maxTokenCountPerRank *
                                // expertParallelInfo.expertCount, reached when the current rank has
                                // maxTokenCountPerRank tokens to send and every token targets all
                                // expertCount destinations
    int* recvRankCountCumSum,   // length = worldInfo.epSize
    int* recvRankLocalIndices,  // max length = maxTokenCountPerRank * worldInfo.epSize, reached when
                                // all ranks send to the current rank
    // the rankCountCumSum of combineRecv should be the same as sendRankCountCumSum
    int* backwardRecvRankLocalIndices,  // max length = maxTokenCountPerRank *
                                        // expertParallelInfo.expertCount, same bound as
                                        // sendRankLocalIndices
cudaStream_t stream) {
  FLASHINFER_CHECK(worldInfo.epSize <= 1024,
                   "Only worldInfo.epSize less than or equal to 1024 is supported now.");
FLASHINFER_CUDA_CALL(
cudaMemsetAsync(sendRankCountCumSum, 0, sizeof(int) * worldInfo.epSize, stream));
FLASHINFER_CUDA_CALL(
cudaMemsetAsync(recvRankCountCumSum, 0, sizeof(int) * worldInfo.epSize, stream));
int maxSendRanksPerToken = std::max(worldInfo.epSize, expertParallelInfo.topK);
FLASHINFER_CUDA_CALL(cudaMemsetAsync(
localGatherIndices, -1, maxTokenCountPerRank * worldInfo.epSize * sizeof(int), stream));
FLASHINFER_CUDA_CALL(cudaMemsetAsync(
sendRankLocalIndices, -1, maxTokenCountPerRank * maxSendRanksPerToken * sizeof(int), stream));
FLASHINFER_CUDA_CALL(cudaMemsetAsync(
recvRankLocalIndices, -1, maxTokenCountPerRank * worldInfo.epSize * sizeof(int), stream));
FLASHINFER_CUDA_CALL(cudaMemsetAsync(backwardRecvRankLocalIndices, -1,
maxTokenCountPerRank * maxSendRanksPerToken * sizeof(int),
stream));
computeSendRecvRankCount(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum,
recvRankCountCumSum, stream);
inplaceSendRecvRankCumSum(worldInfo, sendRankCountCumSum, recvRankCountCumSum, stream);
computeSendRecvIndices(worldInfo, expertParallelInfo, maxTokenCountPerRank,
realRankTokenCountCumSum, gatheredTargetRankIds, sendRankCountCumSum,
recvRankCountCumSum, localGatherIndices, sendRankLocalIndices,
recvRankLocalIndices, backwardRecvRankLocalIndices, stream);
return cudaSuccess;
}
template <int kThreadsGroupSize>
__global__ void moeLocalGatherDevice(MoeEpWorldInfo worldInfo,
MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int localMaxTokenCount,
int const* recvRankCountCumSum, int const* localGatherIndices,
int const* gatheredExpertIds, float const* gatheredScales,
int* localExpertIds, float* localScales) {
cg::thread_block_tile<kThreadsGroupSize> tile =
cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
int laneInTile = tile.thread_rank();
int tileId = threadIdx.x / kThreadsGroupSize;
int tileCountPerBlock = blockDim.x / kThreadsGroupSize;
int epSize = worldInfo.epSize;
int rankTokenCount = recvRankCountCumSum[epSize - 1];
bool needLoad = laneInTile < expertParallelInfo.topK;
for (int index = tileId + blockIdx.x * tileCountPerBlock; index < localMaxTokenCount;
index += tileCountPerBlock * gridDim.x) {
int localTokenIndice = localGatherIndices[index];
int expertId = needLoad && (index < rankTokenCount)
? gatheredExpertIds[localTokenIndice * expertParallelInfo.topK + laneInTile]
: expertParallelInfo.expertCount;
float scale = needLoad && (index < rankTokenCount)
? gatheredScales[localTokenIndice * expertParallelInfo.topK + laneInTile]
: 0.0f;
if (needLoad) {
localExpertIds[index * expertParallelInfo.topK + laneInTile] = expertId;
localScales[index * expertParallelInfo.topK + laneInTile] = scale;
}
}
}
void moeLocalGather(MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
int maxTokenCountPerRank, int localMaxTokenCount,
int const* recvRankCountCumSum, int const* localGatherIndices,
int const* gatheredExpertIds, float const* gatheredScales, int* localExpertIds,
float* localScales, cudaStream_t stream) {
  FLASHINFER_CHECK(expertParallelInfo.topK <= 32,
                   "Only topK less than or equal to 32 is supported now.");
auto* kernelPtr = moeLocalGatherDevice<32>;
int paddedTopK = 32;
if (expertParallelInfo.topK <= 1) {
paddedTopK = 1;
kernelPtr = moeLocalGatherDevice<1>;
} else if (expertParallelInfo.topK <= 2) {
paddedTopK = 2;
kernelPtr = moeLocalGatherDevice<2>;
} else if (expertParallelInfo.topK <= 4) {
paddedTopK = 4;
kernelPtr = moeLocalGatherDevice<4>;
} else if (expertParallelInfo.topK <= 8) {
paddedTopK = 8;
kernelPtr = moeLocalGatherDevice<8>;
} else if (expertParallelInfo.topK <= 16) {
paddedTopK = 16;
kernelPtr = moeLocalGatherDevice<16>;
}
int threadsPerBlock = 512;
int tokenPerBlock = threadsPerBlock / paddedTopK;
int blockCount = (localMaxTokenCount + tokenPerBlock - 1) / tokenPerBlock * 2;
kernelPtr<<<blockCount, threadsPerBlock, 0, stream>>>(
worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount, recvRankCountCumSum,
localGatherIndices, gatheredExpertIds, gatheredScales, localExpertIds, localScales);
}
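// End-to-end usage sketch (illustrative; the function name is hypothetical and
// all device buffers are assumed to be pre-allocated by the caller). One
// possible ordering of a MoE dispatch step: (1) derive per-rank counts and
// gather/scatter indices, (2) gather expert ids and scales for locally received
// tokens, (3) move the token payload through the FIFO all-to-all.
inline cudaError_t exampleMoeDispatchStep(
    MoeEpWorldInfo worldInfo, MoeExpertParallelInfo expertParallelInfo,
    int maxTokenCountPerRank, int localMaxTokenCount, int const* gatheredTargetRankIds,
    int const* gatheredExpertIds, float const* gatheredScales, int* localExpertIds,
    float* localScales, int* localGatherIndices, int* sendRankCountCumSum,
    int* sendRankLocalIndices, int* recvRankCountCumSum, int* recvRankLocalIndices,
    int* backwardRecvRankLocalIndices, SendRecvDataInfo sendRecvDataInfo,
    SendRecvDispls sendDispls, SendRecvDispls recvDispls, MoeCommWorkspace workspace,
    cudaStream_t stream) {
  FLASHINFER_CUDA_CALL(moeAllToAllPrepareIndices(
      worldInfo, expertParallelInfo, maxTokenCountPerRank, gatheredTargetRankIds,
      /*realRankTokenCountCumSum=*/nullptr, localGatherIndices, sendRankCountCumSum,
      sendRankLocalIndices, recvRankCountCumSum, recvRankLocalIndices,
      backwardRecvRankLocalIndices, stream));
  moeLocalGather(worldInfo, expertParallelInfo, maxTokenCountPerRank, localMaxTokenCount,
                 recvRankCountCumSum, localGatherIndices, gatheredExpertIds, gatheredScales,
                 localExpertIds, localScales, stream);
  moeAllToAll(worldInfo, sendRecvDataInfo, sendDispls, recvDispls, workspace, stream);
  return cudaSuccess;
}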
int AllToAllChannelCommunicatorBase::maxSmCount = -1;
bool AllToAllChannelCommunicatorBase::maxSmCountUsed = false;
void setMaxUsableSmCount(int smCount) {
AllToAllChannelCommunicatorBase::setMaxUsableSmCount(smCount);
}
} // namespace trtllm_alltoall
} // namespace flashinfer