// sglang_v0.5.2/flashinfer_0.3.1/csrc/trtllm_alltoall_prepare.cu
//
// 619 lines
// 24 KiB
// Plaintext

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <stdio.h>
#include <cub/cub.cuh>
#include "flashinfer/comm/trtllm_alltoall_prepare.cuh"
#include "flashinfer/exception.h"
#include "flashinfer/utils.cuh"
// Local definition to avoid multiple definition issues from trtllm_alltoall.cuh
//
// Returns the number of streaming multiprocessors of the current CUDA device.
// Errors are checked explicitly and raised through FLASHINFER_CHECK instead of
// FLASHINFER_CUDA_CALL: this function returns an SM *count* as an int, so any macro
// that propagates a cudaError_t via `return` would silently hand the error code back
// to the caller as if it were a (bogus) SM count, which is then used as a grid size.
static int getMultiProcessorCount() {
  int device_id = 0;
  cudaError_t status = cudaGetDevice(&device_id);
  FLASHINFER_CHECK(status == cudaSuccess, "cudaGetDevice failed while querying the SM count.");
  int multi_processor_count = 0;
  status =
      cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id);
  FLASHINFER_CHECK(status == cudaSuccess,
                   "cudaDeviceGetAttribute(cudaDevAttrMultiProcessorCount) failed.");
  return multi_processor_count;
}
namespace cg = cooperative_groups;
namespace flashinfer::trtllm_alltoall {
namespace moe_prepare {
// Store `val` to the 64-bit global-memory word at `ptr` with release semantics at
// system scope (PTX `st.release.sys`). The "memory" clobber additionally stops the
// compiler from moving earlier memory accesses past this store.
__device__ __forceinline__ void st_release_sys_global(uint64_t volatile* ptr, uint64_t val) {
asm volatile("st.release.sys.global.u64 [%0], %1;" ::"l"(ptr), "l"(val) : "memory");
}
// Load the 64-bit global-memory word at `ptr` with acquire semantics at system scope
// (PTX `ld.acquire.sys`).
// The "memory" clobber makes this a compiler-level barrier as well. Without it the
// compiler is free to hoist later ordinary loads above the acquire load, because it
// cannot see that the asm touches memory. That would defeat the acquire ordering the
// spin-wait callers rely on. It also matches st_release_sys_global, which already
// clobbers memory.
__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t volatile* ptr) {
  uint64_t ret;
  asm volatile("ld.acquire.sys.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr) : "memory");
  return ret;
}
// 32-bit signed variant of ld_acquire_sys_global: system-scope acquire load of the int
// at `ptr` (PTX `ld.acquire.sys.global.s32`).
// The "memory" clobber keeps the compiler from hoisting subsequent memory accesses above
// the acquire load, so the acquire ordering also holds at the compiler level.
__device__ __forceinline__ int ld_acquire_sys_global_int(int volatile* ptr) {
  int ret;
  asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr) : "memory");
  return ret;
}
// Flow-control counters for one FIFO connection, stored in a MoeCommFifoConnInfo.
// The producer bumps `head` once per published step (releaseSendStep). The consumer bumps
// `tail` once per consumed step (releaseRecvStep). Both are written with system-scope
// release stores and read with system-scope acquire loads. localCachedHead and
// localCachedTail hold this thread's last written/observed values. Step indices returned
// to callers are taken modulo STEP_DEPTH.
class StepCommunicatorBase {
 public:
  // Size in bytes of the connection metadata that accompanies each FIFO.
  static constexpr int META_SIZE = sizeof(MoeCommFifoConnInfo);

  __device__ __inline__ StepCommunicatorBase(MoeCommFifoConnInfo* fifoConnInfo)
      : fifoConnInfo(fifoConnInfo), localCachedHead(0), localCachedTail(0) {}

  // Zero both shared counters (plain stores, no release ordering).
  __forceinline__ __device__ void reset() {
    fifoConnInfo->head = 0;
    fifoConnInfo->tail = 0;
  }

  // Producer side: advance and publish `head` by one step.
  __forceinline__ __device__ void releaseSendStep() {
    ++localCachedHead;
    st_release_sys_global(&fifoConnInfo->head, uint64_t(localCachedHead));
  }

  // Consumer side: advance and publish `tail` by one step.
  __forceinline__ __device__ void releaseRecvStep() {
    ++localCachedTail;
    st_release_sys_global(&fifoConnInfo->tail, uint64_t(localCachedTail));
  }

  // Re-read `tail` from memory, refresh the local cache and return it.
  __forceinline__ __device__ uint64_t acquireTail() {
    localCachedTail = ld_acquire_sys_global(&fifoConnInfo->tail);
    return localCachedTail;
  }

  // Re-read `head` from memory, refresh the local cache and return it.
  __forceinline__ __device__ uint64_t acquireHead() {
    localCachedHead = ld_acquire_sys_global(&fifoConnInfo->head);
    return localCachedHead;
  }

  // Producer side: spin until fewer than STEP_DEPTH steps are outstanding, then return
  // the FIFO slot for the next step. Example with STEP_DEPTH == 2 and tail == 0:
  // head == 1 may proceed, head == 2 has to wait.
  __forceinline__ __device__ int acquireNewSendStep() {
    for (;;) {
      const int64_t observedTail = acquireTail();
      if (localCachedHead < observedTail + STEP_DEPTH) {
        break;
      }
    }
    return localCachedHead % STEP_DEPTH;
  }

  // Consumer side: spin until the producer has published a step we have not consumed,
  // then return its FIFO slot.
  __forceinline__ __device__ int acquireNewRecvStep() {
    for (;;) {
      const int64_t observedHead = acquireHead();
      if (localCachedTail < observedHead) {
        break;
      }
    }
    return localCachedTail % STEP_DEPTH;
  }

  MoeCommFifoConnInfo* fifoConnInfo;
  uint64_t localCachedHead;
  uint64_t localCachedTail;
  int rank;
  int targetRank;
};
// Uses a MoeCommFifoConnInfo as a one-shot mailbox for a single counter value.
// Only the `count` field is used (not `head`/`tail`):
//   - the sender publishes value + 1 with a system-scope release store, so that the
//     reserved value 0 can mean "nothing published yet";
//   - the receiver spins with system-scope acquire loads until `count` is non-zero,
//     writes 0 back to `count` and returns the decoded value.
class CounterCommunicator {
 public:
  __device__ __inline__ CounterCommunicator(MoeCommFifoConnInfo* fifoConnInfo)
      : fifoConnInfo(fifoConnInfo) {}

  // Publish `value`. It is encoded as value + 1 so that publishing 0 does not leave the
  // receiver blocked on the "empty" sentinel.
  __forceinline__ __device__ void releaseValue(uint64_t value) {
    st_release_sys_global(&(fifoConnInfo->count), value + 1);
  }

  // Spin until a value has been published, reset the mailbox to 0 and return the value.
  __forceinline__ __device__ uint64_t acquireValue() {
    uint64_t localCount = 0;
    do {
      localCount = ld_acquire_sys_global(&(fifoConnInfo->count));
    } while (localCount == 0);
    fifoConnInfo->count = 0;  // reset the count (plain store)
    return localCount - 1;
  }

 protected:
  MoeCommFifoConnInfo* fifoConnInfo;
};
// Sender half of computeCountAndIndiceDevice. CTA `blockIdx.x` handles target rank
// blockIdx.x. Every kThreadsGroupSize-wide tile walks the local tokens. Lane j of the tile
// inspects the token's j-th expert (lanes >= topK never match), and the expert's owning
// rank is expertId / (expertCount / epSize). For each token with at least one expert on
// the target rank, lane 0 appends:
//   sendIndiceWorkspace[target][slot]     = token index
//   backwardIndiceWorkspace[target][slot] = token * topK + first matching lane
// Slots come from atomicAdd_block on the shared counter, so the order within a rank is
// not deterministic. Finally thread 0 publishes the count to the target rank through a
// CounterCommunicator and stores it in sendCounts[target].
template <int kThreadsGroupSize>
__device__ __forceinline__ void computeCountAndSend(
    int* experts, int tokenCount, int* sharedSendRecvRankCount, int* sendCounts,
    int* sendIndiceWorkspace, int* backwardIndiceWorkspace, MoeCommWorkspace workspace,
    int maxTokenCountPerRank, int expertCount, int topK, int epRank, int epSize) {
  auto group = cg::tiled_partition<kThreadsGroupSize>(cg::this_thread_block());
  const int lane = group.thread_rank();
  const int groupsPerBlock = blockDim.x / kThreadsGroupSize;
  const int expertsPerRank = expertCount / epSize;
  const int dstRank = blockIdx.x;

  // The shared counter must be zero before any tile increments it.
  if (threadIdx.x == 0) {
    *sharedSendRecvRankCount = 0;
  }
  __syncthreads();

  if (dstRank >= epSize) {
    return;
  }

  int* sendIdxOut = sendIndiceWorkspace + dstRank * maxTokenCountPerRank;
  int* backwardIdxOut = backwardIndiceWorkspace + dstRank * maxTokenCountPerRank;

  for (int token = threadIdx.x / kThreadsGroupSize; token < tokenCount;
       token += groupsPerBlock) {
    int laneRank = epSize;  // sentinel: never equal to a valid target rank
    if (lane < topK) {
      laneRank = experts[token * topK + lane] / expertsPerRank;
    }
    const bool matches = (laneRank == dstRank);
    // Both collectives are executed by every lane of the tile.
    const bool anyMatch = group.any(matches);
    const unsigned matchMask = group.ballot(matches);
    if (anyMatch && lane == 0) {
      const int slot = atomicAdd_block(sharedSendRecvRankCount, 1);
      sendIdxOut[slot] = token;
      backwardIdxOut[slot] = token * topK + (__ffs(static_cast<int>(matchMask)) - 1);
    }
    group.sync();
  }
  __syncthreads();

  if (threadIdx.x == 0) {
    const int total = *sharedSendRecvRankCount;
    CounterCommunicator counter(
        workspace.getFifoConnInfo(true, epRank, dstRank, 0, epSize, 1));
    counter.releaseValue(uint64_t(total));
    sendCounts[dstRank] = total;
  }
}
// Receiver half of computeCountAndIndiceDevice, run by CTAs with blockIdx.x >= rankCount.
// Each group of THREADS_PER_PIPELINE threads (up to PIPELINE_PER_CTA groups per CTA)
// serves one source rank. The group's first thread waits for that rank's token count via
// a CounterCommunicator and stores it to recvCounts[src] and to the group's shared slot.
// Then the whole group writes the identity indices recvIndiceWorkspace[src][t] = t for
// t < count.
__device__ __forceinline__ void recvCount(int* recvIndiceWorkspace, int* recvCounts,
                                          int* sharedCountsBase, MoeCommWorkspace workspace,
                                          int maxTokenCountPerRank, int rankId, int rankCount) {
  const int pipelineInCta = threadIdx.x / THREADS_PER_PIPELINE;
  if (pipelineInCta >= PIPELINE_PER_CTA) {
    return;
  }
  const int srcRank = (blockIdx.x - rankCount) * PIPELINE_PER_CTA + pipelineInCta;
  if (srcRank >= rankCount) {
    return;
  }

  auto pipelineTile = cg::tiled_partition<THREADS_PER_PIPELINE>(cg::this_thread_block());
  int* countSlot = sharedCountsBase + pipelineInCta;

  if (pipelineTile.thread_rank() == 0) {
    CounterCommunicator counter(
        workspace.getFifoConnInfo(false, rankId, srcRank, 0, rankCount, 1));
    const int fetched = int(counter.acquireValue());
    recvCounts[srcRank] = fetched;
    *countSlot = fetched;
  }
  pipelineTile.sync();

  const int received = *countSlot;
  int* recvIdxOut = recvIndiceWorkspace + srcRank * maxTokenCountPerRank;
  for (int t = threadIdx.x % UNIT_PER_PIPELINE; t < received; t += UNIT_PER_PIPELINE) {
    recvIdxOut[t] = t;
  }
}
// Combined count/exchange kernel. CTAs [0, rankCount) run computeCountAndSend for target
// rank blockIdx.x. The remaining CTAs run recvCount, each serving up to PIPELINE_PER_CTA
// source ranks. `sharedCounts` is per-CTA scratch for whichever role the CTA plays.
template <int kThreadsGroupSize>
__global__ void computeCountAndIndiceDevice(int* experts, int* sendCounts, int* recvCounts,
                                            int* sendIndiceWorkspace, int* backwardIndiceWorkspace,
                                            int* recvIndiceWorkspace, MoeCommWorkspace workspace,
                                            int tokenCount, int maxTokenCountPerRank, int topK,
                                            int expertCount, int rankId, int rankCount) {
  __shared__ int sharedCounts[PIPELINE_PER_CTA];
  if (blockIdx.x >= rankCount) {
    recvCount(recvIndiceWorkspace, recvCounts, sharedCounts, workspace, maxTokenCountPerRank,
              rankId, rankCount);
    return;
  }
  computeCountAndSend<kThreadsGroupSize>(experts, tokenCount, sharedCounts, sendCounts,
                                         sendIndiceWorkspace, backwardIndiceWorkspace, workspace,
                                         maxTokenCountPerRank, expertCount, topK, rankId,
                                         rankCount);
}
// Compacts per-rank index workspaces into contiguous arrays using inclusive cumsums.
// Grid: x = rank, y = 0 for send/backward indices, y = 1 for receive indices.
// Rank r's segment is [cumsum[r-1], cumsum[r]) (starting at 0 for r == 0).
// The y == 0 CTA copies the first `length` entries of rank r's sendIndice and
// backwardIndice rows (row stride maxTokenCountPerRank) into that segment. The y == 1 CTA
// fills gatherRecvIndice with the segment's own positions. `recvIndice` is not read.
__global__ void moveIndiceDevice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice,
                                 int* gatherSendIndice, int* backwardIndice,
                                 int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice,
                                 int maxTokenCountPerRank) {
  const int rank = blockIdx.x;
  const bool sendSide = (blockIdx.y == 0);
  const int* cumsum = sendSide ? sendCountsCumsum : recvCountsCumsum;
  const int begin = (rank == 0) ? 0 : cumsum[rank - 1];
  const int length = cumsum[rank] - begin;

  if (!sendSide) {
    for (int k = threadIdx.x; k < length; k += blockDim.x) {
      gatherRecvIndice[begin + k] = begin + k;
    }
    return;
  }

  const int* srcSend = sendIndice + rank * maxTokenCountPerRank;
  const int* srcBackward = backwardIndice + rank * maxTokenCountPerRank;
  for (int k = threadIdx.x; k < length; k += blockDim.x) {
    gatherSendIndice[begin + k] = srcSend[k];
    gatherBackwardIndice[begin + k] = srcBackward[k];
  }
}
// In-place inclusive prefix sum of per-rank counts using cub::BlockScan.
// CTA 0 scans sendCountsCumsum and every other CTA scans recvCountsCumsum (computeCumsum
// launches exactly 2). Thread t owns entry t, so only the first CUMSUM_THREADS_PER_BLOCK
// entries can be covered. `rankId` is unused.
__global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum, int rankId,
                                    int rankCount) {
  using BlockScan = cub::BlockScan<int, CUMSUM_THREADS_PER_BLOCK>;
  __shared__ typename BlockScan::TempStorage scanStorage;

  int* counts = (blockIdx.x == 0) ? sendCountsCumsum : recvCountsCumsum;
  const int tid = threadIdx.x;
  const bool ownsEntry = tid < rankCount;

  int value = 0;
  if (ownsEntry) {
    value = counts[tid];
  }
  __syncthreads();
  BlockScan(scanStorage).InclusiveSum(value, value);
  if (ownsEntry) {
    counts[tid] = value;
  }
}
// CTA-wide packet iterator over one FIFO buffer. The buffer is split into steps of
// PipelineConfig::PACKET_PER_STEP packets of PipelineConfig::PACKET_SIZE bytes each. The
// step index comes from the StepCommunicatorBase and is therefore taken modulo STEP_DEPTH.
// Steps are handed over one at a time: the sender releases a step once all of its
// packets are written, and the receiver acquires a step before reading it and releases it
// when moving on.
//
// Every method that can cross a step boundary contains __syncthreads(), so all threads
// of the CTA must call the methods in the same order with the same arguments. Only
// thread 0 talks to the StepCommunicatorBase and broadcasts the new step through
// *shared_new_step.
template <typename PipelineConfig>
class PacketPipeline {
 public:
  __device__ __inline__ PacketPipeline(void* bufferBase, StepCommunicatorBase* stepCommunicator,
                                       int* sharedNewStepPtr, bool isSender)
      : bufferBase(bufferBase),
        stepCommunicator(stepCommunicator),
        shared_new_step(sharedNewStepPtr) {
    step = 0;
    needRelease = false;
    // A receiver starts on the last packet of a (virtual) step so that its first
    // getNewRecvPacket() call acquires a real step from the sender.
    packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
  }

  // Sender: first packet (step 0, packet 0).
  __device__ __forceinline__ void* getFirstSendPacket() { return bufferBase; }

  // Sender: mark the current packet as written. Returns the next packet to fill, or
  // nullptr when `acquireNewStep` is false. When the current step is full it is released
  // to the receiver (after a CTA barrier so every thread's writes are included). If
  // `acquireNewStep`, a new step is acquired, which may spin until the receiver has freed
  // one.
  __device__ __inline__ void* finishSendPacket(bool acquireNewStep) {
    packetId++;
    if (packetId < PipelineConfig::PACKET_PER_STEP) {
      return acquireNewStep ? packetAddress(step, packetId) : nullptr;
    }
    __syncthreads();
    if (threadIdx.x == 0) {
      stepCommunicator->releaseSendStep();
      if (acquireNewStep) {
        step = stepCommunicator->acquireNewSendStep();
        *(shared_new_step) = step;
      }
    }
    __syncthreads();
    // The full step has just been released. Reset the packet counter even when no new step
    // is acquired, so that sendFinalize() does not release the same step a second time.
    packetId = 0;
    if (acquireNewStep) {
      step = *(shared_new_step);
      return packetAddress(step, 0);
    }
    return nullptr;
  }

  // Sender: release a trailing, partially filled step, if any. The barrier guarantees that
  // the packet writes of all threads precede thread 0's release store. packetId is
  // identical in every thread, so the barrier is reached uniformly.
  __device__ __forceinline__ void sendFinalize() {
    if (packetId > 0) {
      __syncthreads();
      if (threadIdx.x == 0) {
        stepCommunicator->releaseSendStep();
      }
    }
  }

  // Receiver: return the next packet to read. At a step boundary the previous step (if
  // any) is released back to the sender and a new step is acquired. Acquiring may spin
  // until the sender has published one.
  __device__ __inline__ void* getNewRecvPacket() {
    packetId++;
    if (packetId < PipelineConfig::PACKET_PER_STEP) {
      return packetAddress(step, packetId);
    }
    __syncthreads();
    if (threadIdx.x == 0) {
      if (needRelease) {
        stepCommunicator->releaseRecvStep();
      }
      step = stepCommunicator->acquireNewRecvStep();
      needRelease = true;
      *(shared_new_step) = step;
    }
    __syncthreads();
    packetId = 0;
    step = *(shared_new_step);
    return packetAddress(step, 0);
  }

  // Thread 0 zeroes the connection's head/tail counters.
  __device__ __forceinline__ void reset() {
    if (threadIdx.x == 0) {
      stepCommunicator->reset();
    }
  }

  void* bufferBase;
  StepCommunicatorBase* stepCommunicator;
  int step;
  int packetId;
  bool needRelease;
  int* shared_new_step;

 private:
  // Byte address of packet `packetIdx` of step `stepIdx`. Uses uint8_t* arithmetic; pointer
  // arithmetic on void* is a compiler extension.
  __device__ __forceinline__ void* packetAddress(int stepIdx, int packetIdx) const {
    return static_cast<uint8_t*>(bufferBase) +
           (stepIdx * PipelineConfig::PACKET_PER_STEP + packetIdx) * PipelineConfig::PACKET_SIZE;
  }
};
// Exchanges per-token expert ids, and optionally scales and expert statistics, with every
// peer rank through the FIFO pipelines in `workspace`.
//
// Launch layout (see allToAllMetadata): grid = (rankCount, 2). blockIdx.x is the peer
// rank. blockIdx.y == 0 is the sending CTA and blockIdx.y == 1 the receiving CTA.
// blockDim.x is UNIT_PER_ITER, plus STATIC_COPY_PER_ITER when localExpertStatics is
// non-null.
//
// Per packet:
//   - threads [0, UNIT_PER_ITER) move one "unit" each: PipelineConfig::UNIT_SIZE
//     consecutive expert ids (and scales) of one token, loaded as ExpertType/ScaleType.
//     On the sender, ids whose owner rank (id / (slotCount / rankCount)) is not the target
//     rank are replaced with slotCount.
//   - threads [UNIT_PER_ITER, ...) copy expert statistics in int4 chunks. expertCount is
//     therefore expected to be a multiple of 4 (checked on the host in
//     allToAllMetadata).
// The receiver writes tokens at recvExperts/recvScales rows starting at the source
// rank's receive cumsum, and writes statistics to gatheredExpertStatics[src * expertCount].
//
// The sender and receiver must iterate the same number of packets. The sender's loop
// condition therefore mirrors the receiver's, including the `localExpertStatics != nullptr`
// guard. Without that guard the sender pushed extra packets whenever
// expertCount > STATIC_COPY_PER_ITER * 4 with statistics disabled.
template <typename PipelineConfig, typename ExpertType, typename ScaleType>
__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales,
                                       float* recvScales, int* localExpertStatics,
                                       int* gatheredExpertStatics, MoeCommWorkspace workspace,
                                       int* sendCountsCumsum, int* localSendIndice,
                                       int* recvCountsCumsum, int* localRecvIndice, int tokenCount,
                                       int maxTokenCountPerRank, int topK, int expertCount,
                                       int slotCount, int rankId, int rankCount) {
  bool isSender = (blockIdx.y == 0);
  int targetRankId = blockIdx.x;
  int slotCountPerRank = slotCount / rankCount;
  int groupSize = topK / PipelineConfig::UNIT_SIZE;
  // Both sides derive the statistics part of the packet count from this flag.
  bool copyStatics = (localExpertStatics != nullptr);
  __shared__ int sharedNewStep;
  __align__(16) int experts[PipelineConfig::UNIT_SIZE];
  __align__(16) float scales[PipelineConfig::UNIT_SIZE];
  uint8_t* bufferBase = (uint8_t*)(workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
  StepCommunicatorBase stepCommunicator(
      workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
  PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
  if (isSender) {
    int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
    int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
    int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;
    void* packPtr = pipeline.getFirstSendPacket();
    int indexBase = 0;
    int staticCopyBase = 0;
    bool acquireNewStep = unitCount > 0 || (copyStatics && expertCount > 0);
    while (acquireNewStep) {
      uint8_t* packBytes = static_cast<uint8_t*>(packPtr);
      if (threadIdx.x < UNIT_PER_ITER) {
        int index = indexBase + threadIdx.x;
        int groupId = index % groupSize;
        if (index < unitCount) {
          int tokenId =
              *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
          *((ExpertType*)(experts)) =
              *(ExpertType*)(sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
#pragma unroll
          for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++) {
            int expertId = experts[j];
            // Experts owned by other ranks are masked out with the sentinel slotCount.
            if (expertId / slotCountPerRank != targetRankId) {
              experts[j] = slotCount;
            }
          }
          int* expertsPtr = (int*)(packBytes) + threadIdx.x * PipelineConfig::UNIT_SIZE;
          *((ExpertType*)(expertsPtr)) = *((ExpertType*)(experts));
          if (sendScales != nullptr) {
            *((ScaleType*)(scales)) =
                *(ScaleType*)(sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
            float* scaleBasePtr = (float*)(packBytes + PipelineConfig::SCALE_OFFSET);
            float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
            *((ScaleType*)(scalesPtr)) = *((ScaleType*)(scales));
          }
        }
      } else if (copyStatics) {
        int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
        if (staticCopyBase + staticCopyIdx * 4 < expertCount) {
          int4* staticBasePtr = (int4*)(packBytes + PipelineConfig::STATIC_COPY_OFFSET);
          int4 staticData = *(int4*)(localExpertStatics + staticCopyBase + staticCopyIdx * 4);
          *(staticBasePtr + staticCopyIdx) = staticData;
        }
      }
      indexBase += UNIT_PER_ITER;
      staticCopyBase += STATIC_COPY_PER_ITER * 4;
      // Must match the receiver's loop condition exactly (see header comment).
      acquireNewStep = indexBase < unitCount || (copyStatics && staticCopyBase < expertCount);
      packPtr = pipeline.finishSendPacket(acquireNewStep);
    }
    pipeline.sendFinalize();
  } else {
    int baseCumsum = targetRankId == 0 ? 0 : *(recvCountsCumsum + targetRankId - 1);
    int recvTokenCount = *(recvCountsCumsum + targetRankId) - baseCumsum;
    int recvUnitCount = recvTokenCount * groupSize;
    int unitIdBase = 0;
    int staticCopyBase = 0;
    while (unitIdBase < recvUnitCount || (copyStatics && staticCopyBase < expertCount)) {
      uint8_t* packetBytes = static_cast<uint8_t*>(pipeline.getNewRecvPacket());
      int packetUnitCount =
          unitIdBase + UNIT_PER_ITER < recvUnitCount ? UNIT_PER_ITER : recvUnitCount - unitIdBase;
      packetUnitCount = max(packetUnitCount, 0);
      if (threadIdx.x < UNIT_PER_ITER) {
        if (threadIdx.x < packetUnitCount) {
          int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
          int groupId = (unitIdBase + threadIdx.x) % groupSize;
          int* expertsPtr = (int*)(packetBytes) + threadIdx.x * PipelineConfig::UNIT_SIZE;
          *((ExpertType*)(experts)) = *((ExpertType*)(expertsPtr));
          ExpertType* dstExpertsPtr =
              (ExpertType*)(recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
          *dstExpertsPtr = *((ExpertType*)(experts));
          if (recvScales != nullptr) {
            float* scaleBasePtr = (float*)(packetBytes + PipelineConfig::SCALE_OFFSET);
            float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
            *((ScaleType*)(scales)) = *((ScaleType*)(scalesPtr));
            ScaleType* dstScalesPtr =
                (ScaleType*)(recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
            *dstScalesPtr = *((ScaleType*)(scales));
          }
        }
      } else if (copyStatics) {
        int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
        if (staticCopyBase + staticCopyIdx * 4 < expertCount) {
          int4* staticBasePtr = (int4*)(packetBytes + PipelineConfig::STATIC_COPY_OFFSET);
          int4 staticData = *(staticBasePtr + staticCopyIdx);
          *(int4*)(gatheredExpertStatics + targetRankId * expertCount + staticCopyBase +
                   staticCopyIdx * 4) = staticData;
        }
      }
      unitIdBase += packetUnitCount;
      staticCopyBase += STATIC_COPY_PER_ITER * 4;
    }
    pipeline.reset();
  }
}
// Fills the unused tail of `expertIds` with the sentinel `slotCount`. Entries
// [recvCountsCumsum[rankCount - 1] * topK, maxTokenCountPerRank * rankCount * topK) are
// written, i.e. everything past the tokens actually received. Grid-stride loop, so any
// launch size is valid.
__global__ void memsetExpertIdsDevice(int* expertIds, int* recvCountsCumsum,
                                      int maxTokenCountPerRank, int topK, int slotCount,
                                      int rankCount) {
  const int filledEntries = recvCountsCumsum[rankCount - 1] * topK;
  const int capacityEntries = maxTokenCountPerRank * rankCount * topK;
  int* unusedTail = expertIds + filledEntries;
  for (int offset = blockIdx.x * blockDim.x + threadIdx.x; offset + filledEntries < capacityEntries;
       offset += gridDim.x * blockDim.x) {
    unusedTail[offset] = slotCount;
  }
}
// Host launcher for computeCountAndIndiceDevice on `stream`.
// Grid: rankCount sender CTAs followed by ceil(rankCount / PIPELINE_PER_CTA) receiver
// CTAs, 1024 threads each. The tile width is the smallest power of two >= topK, so one
// tile covers all top-K experts of a token. Only 1 <= topK <= 32 is accepted.
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
                           int* backwardIndiceWorkspace, int* recvIndiceWorkspace,
                           MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank,
                           int topK, int expert_count, int rankId, int rankCount,
                           cudaStream_t stream) {
  FLASHINFER_CHECK(topK >= 1 && topK <= 32, "Only 1 <= topK <= 32 is supported now.");

  auto* kernel = &computeCountAndIndiceDevice<32>;
  if (topK <= 1) {
    kernel = &computeCountAndIndiceDevice<1>;
  } else if (topK <= 2) {
    kernel = &computeCountAndIndiceDevice<2>;
  } else if (topK <= 4) {
    kernel = &computeCountAndIndiceDevice<4>;
  } else if (topK <= 8) {
    kernel = &computeCountAndIndiceDevice<8>;
  } else if (topK <= 16) {
    kernel = &computeCountAndIndiceDevice<16>;
  }

  const int senderCtas = rankCount;
  const int receiverCtas = (rankCount + PIPELINE_PER_CTA - 1) / PIPELINE_PER_CTA;
  const dim3 gridShape(senderCtas + receiverCtas);
  const dim3 blockShape(1024);
  kernel<<<gridShape, blockShape, 0, stream>>>(experts, sendCounts, recvCounts, sendIndiceWorkspace,
                                               backwardIndiceWorkspace, recvIndiceWorkspace,
                                               workspace, tokenCount, maxTokenCountPerRank, topK,
                                               expert_count, rankId, rankCount);
}
// Host launcher for computeCumsumDevice on `stream`. It turns sendCountsCumsum and
// recvCountsCumsum, each holding rankCount per-rank counts, into inclusive prefix sums in
// place. One CTA of CUMSUM_THREADS_PER_BLOCK threads scans each array, one entry per
// thread. rankCount must therefore not exceed CUMSUM_THREADS_PER_BLOCK. Otherwise the
// entries past that point would silently be left un-scanned, so this is checked up
// front.
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount,
                   cudaStream_t stream) {
  FLASHINFER_CHECK(rankCount <= CUMSUM_THREADS_PER_BLOCK,
                   "computeCumsum: rankCount must not exceed CUMSUM_THREADS_PER_BLOCK.");
  int block_size = CUMSUM_THREADS_PER_BLOCK;
  dim3 block(block_size);
  dim3 grid(2);  // CTA 0: send counts, CTA 1: receive counts
  computeCumsumDevice<<<grid, block, 0, stream>>>(sendCountsCumsum, recvCountsCumsum, rankId,
                                                  rankCount);
}
// Host launcher for moveIndiceDevice on `stream`: one 512-thread CTA per (rank, side).
// gridDim.y == 2 selects the send/backward side (y == 0) or the receive side (y == 1).
// `rankId` is unused.
void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice,
                int* gatherSendIndice, int* backwardIndice, int* gatherBackwardIndice,
                int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
                int maxTokenCountPerRank, cudaStream_t stream) {
  const dim3 gridShape(rankCount, 2);
  const dim3 blockShape(512);
  moveIndiceDevice<<<gridShape, blockShape, 0, stream>>>(
      sendCountsCumsum, recvCountsCumsum, sendIndice, gatherSendIndice, backwardIndice,
      gatherBackwardIndice, recvIndice, gatherRecvIndice, maxTokenCountPerRank);
}
// Host launcher for allToAllMetadataDevice plus the trailing memsetExpertIdsDevice, both
// on `stream`.
// Grid = (rankCount, 2): one sending and one receiving CTA per peer rank. Block =
// UNIT_PER_ITER threads, plus STATIC_COPY_PER_ITER when expert statistics are exchanged.
// When topK is a multiple of 4, expert ids and scales move 4 at a time (int4/float4).
// Otherwise they move one at a time. Afterwards the unused tail of recvExperts is filled
// with slotCount.
//
// Precondition: when localExpertStatics is non-null, expertCount must be a multiple of 4.
// The kernel copies statistics in int4 chunks and stores rank r's block at
// gatheredExpertStatics + r * expertCount. Any other expertCount would over-read the
// source and issue misaligned int4 stores.
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
                      int* localExpertStatics, int* gatheredExpertStatics,
                      MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
                      int* recvCountsCumsum, int* localRecvIndice, int tokenCount,
                      int maxTokenCountPerRank, int topK, int expertCount, int slotCount,
                      int rankId, int rankCount, cudaStream_t stream) {
  if (localExpertStatics != nullptr) {
    FLASHINFER_CHECK(expertCount % 4 == 0,
                     "expertCount must be a multiple of 4 when localExpertStatics is provided.");
  }
  int block_size =
      localExpertStatics == nullptr ? UNIT_PER_ITER : UNIT_PER_ITER + STATIC_COPY_PER_ITER;
  dim3 block(block_size);
  dim3 grid(rankCount, 2);
  if (topK % 4 == 0) {
    using PipelineConfig = PipelineConfig<4, 16>;
    static_assert(
        PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <=
            FIFO_SIZE_IN_U64,
        "FIFO size is too small");
    allToAllMetadataDevice<PipelineConfig, int4, float4><<<grid, block, 0, stream>>>(
        sendExperts, recvExperts, sendScales, recvScales, localExpertStatics, gatheredExpertStatics,
        workspace, sendCountsCumsum, localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount,
        maxTokenCountPerRank, topK, expertCount, slotCount, rankId, rankCount);
  } else {
    using PipelineConfig = PipelineConfig<1, 64>;
    static_assert(
        PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <=
            FIFO_SIZE_IN_U64,
        "FIFO size is too small");
    allToAllMetadataDevice<PipelineConfig, int, float><<<grid, block, 0, stream>>>(
        sendExperts, recvExperts, sendScales, recvScales, localExpertStatics, gatheredExpertStatics,
        workspace, sendCountsCumsum, localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount,
        maxTokenCountPerRank, topK, expertCount, slotCount, rankId, rankCount);
  }
  int smCount = getMultiProcessorCount();
  memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
      recvExperts, recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
}
// Workspace bytes for the prepare step. Per peer (presumably one per EP rank — confirm
// against MoeCommWorkspace) there is one FIFO of FIFO_SIZE_IN_U64 64-bit words plus
// StepCommunicatorBase::META_SIZE bytes of connection metadata, times epSize.
size_t getMoePrepareWorkspaceSize(int epSize) {
  const auto bytesPerPeer = FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE;
  return bytesPerPeer * epSize;
}
} // namespace moe_prepare
} // namespace flashinfer::trtllm_alltoall