// Source path: sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/comm/trtllm_alltoall_prepare.cuh
// (File-viewer metadata captured with this copy: 129 lines, 5.0 KiB, shown as plaintext.)

/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <map>
// Compile-time debug switch, currently set to 0. Nothing in this header reads it.
// NOTE(review): presumably consumed by the sources that include this header — confirm.
// As a preprocessor macro it is not scoped by the namespaces opened below.
#define DEBUG_PIPELINE 0
namespace flashinfer::trtllm_alltoall {
namespace moe_prepare {
// Tuning constants for the MoE prepare code. These are preprocessor macros, so the
// enclosing namespaces do not scope them.
// NOTE(review): the grouping below follows the macro names only; confirm each
// constant's exact meaning against the kernels that consume it.

// Pipeline shape.
#define STEP_DEPTH 2
#define PIPELINE_PER_CTA 4
#define UNIT_PER_PIPELINE 128
#define THREADS_PER_UNIT 1

// Byte sizes and per-packet / per-iteration counts.
#define EXPERT_BYTES_PER_UNIT 32
#define SCALE_BYTES_PER_UNIT 32
#define BYTES_COUNTER 8
#define UNIT_COUNT_PER_PACKET 1024
#define UNIT_PER_ITER 256
#define STATIC_COPY_PER_ITER 128

// NOTE(review): presumably the block size of the cumsum kernel — confirm.
#define CUMSUM_THREADS_PER_BLOCK 128

// Derived thread counts: 128 * 1 = 128 per pipeline, 4 * 128 = 512 per CTA.
static constexpr int THREADS_PER_PIPELINE = UNIT_PER_PIPELINE * THREADS_PER_UNIT;
static constexpr int THREADS_PER_CTA = PIPELINE_PER_CTA * THREADS_PER_PIPELINE;
// Compile-time byte sizes and offsets of one packet, parameterized by the unit size
// and by the number of packets per step.
//
// With N = UNIT_SIZE * UNIT_PER_ITER, the constants below work out to:
//   SCALE_OFFSET       = N * sizeof(int)
//   UNIT_BYTES_SIZE    = N * (sizeof(int) + sizeof(float))
//   STATIC_COPY_OFFSET = UNIT_BYTES_SIZE
//   PACKET_SIZE        = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int)
// NOTE(review): the names suggest an int block, then a float (scale) block, then a
// "static copy" tail — confirm against the code that fills the packet.
template <int UNIT_SIZE_INPUT, int PACKET_PER_STEP_INPUT>
struct PipelineConfig {
  // Template arguments re-exported as named constants.
  static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT;
  static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT;

  // Byte offsets inside a packet.
  static constexpr int SCALE_OFFSET = (UNIT_SIZE * UNIT_PER_ITER) * sizeof(int);
  static constexpr int UNIT_BYTES_SIZE =
      (UNIT_SIZE * UNIT_PER_ITER) * (sizeof(int) + sizeof(float));
  static constexpr int STATIC_COPY_OFFSET = UNIT_BYTES_SIZE;

  // Whole-packet size, in bytes and in 8-byte words (integer division).
  static constexpr int PACKET_SIZE = STATIC_COPY_OFFSET + STATIC_COPY_PER_ITER * 4 * sizeof(int);
  static constexpr int PACKET_SIZE_IN_U64 = PACKET_SIZE / 8;
};
// Size of one FIFO: 1 MiB, expressed as a count of 8-byte (uint64_t) words.
static constexpr int FIFO_SIZE_IN_U64 = (1 << 20) / 8;
// ALIGN_256 requests 256-byte alignment: standard alignas for a host-only compiler,
// the __align__ spelling when __CUDACC__ is defined (i.e. nvcc is compiling).
#ifndef __CUDACC__
#define ALIGN_256 alignas(256)
#else
#define ALIGN_256 __align__(256)
#endif

// Bookkeeping for one FIFO connection. The 256-byte alignment also rounds the
// struct size up to 256 bytes, so array elements start on 256-byte boundaries.
// Every field is volatile, so each read or write is a real memory access.
// NOTE(review): presumably because the peer side updates these fields — confirm.
struct ALIGN_256 MoeCommFifoConnInfo {
  volatile uint64_t head;   // write position
  volatile uint64_t tail;   // read position
  volatile uint64_t count;  // for counter
};
// Handle to the communication workspace, addressed in 64-bit words.
//
// As the accessors below compute it, each rank owns a region of rankStrideInU64
// words. A region starts with epSize * channelCount data FIFOs of FIFO_SIZE_IN_U64
// words each, followed by an array of MoeCommFifoConnInfo records.
//
// All slot arithmetic is done in size_t. FIFO_SIZE_IN_U64 is 2^17, so the previous
// 32-bit int products overflowed (undefined behaviour) once a slot index, or
// epSize * channelCount, reached 16384.
struct MoeCommWorkspace {
  uint64_t* workspacePtr;  // base address of the whole workspace
  size_t rankStrideInU64;  // distance between two ranks' regions, in 64-bit words

#ifdef __CUDACC__
  // Returns the first word of the data FIFO for one connection.
  //   isSender      true when the caller (epRank) is the sending side
  //   epRank        the caller's rank
  //   peerRank      the rank at the other end of the connection
  //   channel       channel index within the connection
  //   channelCount  number of channels per rank pair
  __inline__ __device__ uint64_t* getFifoBasePtr(bool isSender, int epRank, int peerRank,
                                                 int channel, int channelCount) const {
    // fifo itself is in receiver's side.
    const int receiverRank = isSender ? peerRank : epRank;
    const int senderRank = isSender ? epRank : peerRank;
    // Widen before multiplying so the slot offset cannot wrap in 32 bits.
    const size_t fifoSlot = static_cast<size_t>(senderRank) * channelCount + channel;
    return workspacePtr + receiverRank * rankStrideInU64 + fifoSlot * FIFO_SIZE_IN_U64;
  }

  // Returns the MoeCommFifoConnInfo record for one connection.
  //   epSize  number of ranks; together with channelCount it fixes how many FIFOs
  //           precede the record array inside a rank's region
  // Other parameters have the same meaning as in getFifoBasePtr.
  // NOTE(review): the pointer cast relies on the computed address being 256-byte
  // aligned (MoeCommFifoConnInfo is ALIGN_256). That needs workspacePtr and
  // rankStrideInU64 * 8 to be multiples of 256 — confirm at the allocation site.
  __inline__ __device__ MoeCommFifoConnInfo* getFifoConnInfo(bool isSender, int epRank,
                                                             int peerRank, int channel, int epSize,
                                                             int channelCount) const {
    // fifoInfo is in sender's side.
    const int senderRank = isSender ? epRank : peerRank;
    const int receiverRank = isSender ? peerRank : epRank;
    // Words occupied by the data FIFOs at the start of every rank's region.
    const size_t fifoRegionInU64 = static_cast<size_t>(FIFO_SIZE_IN_U64) * channelCount * epSize;
    uint64_t* infoBaseU64 = workspacePtr + fifoRegionInU64 + senderRank * rankStrideInU64;
    MoeCommFifoConnInfo* infoBase = reinterpret_cast<MoeCommFifoConnInfo*>(infoBaseU64);
    const size_t infoSlot = static_cast<size_t>(receiverRank) * channelCount + channel;
    return infoBase + infoSlot;
  }
#endif
};
// Declaration only; the definition is not part of this header.
// NOTE(review): judging by the parameter names, this derives per-rank send/receive
// counts and the send/backward/receive index workspaces from the `experts` routing
// of `tokenCount` tokens with `topK` choices each — confirm against the implementation.
// NOTE(review): pointers are presumably device memory and the work is presumably
// enqueued on `stream` — confirm.
void computeCountAndIndice(int* experts, int* sendCounts, int* recvCounts, int* sendIndiceWorkspace,
int* backwardIndiceWorkspace, int* recvIndiceWorkspace,
MoeCommWorkspace workspace, int tokenCount, int maxTokenCountPerRank,
int topK, int expert_count, int rankId, int rankCount,
cudaStream_t stream);
// Declaration only; the definition is not part of this header.
// NOTE(review): the name and the lack of separate input pointers suggest the two
// count arrays are converted to cumulative sums in place — confirm.
// NOTE(review): presumably runs asynchronously on `stream` — confirm.
void computeCumsum(int* sendCountsCumsum, int* recvCountsCumsum, int rankId, int rankCount,
cudaStream_t stream);
// Declaration only; the definition is not part of this header.
// NOTE(review): each index array is paired with a `gather*` counterpart, which
// suggests entries are moved into the gathered arrays at positions given by the
// cumulative counts — confirm against the implementation.
// NOTE(review): presumably runs asynchronously on `stream` — confirm.
void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice,
int* gatherSendIndice, int* backwardIndice, int* gatherBackwardIndice,
int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
int maxTokenCountPerRank, cudaStream_t stream);
// Declaration only; the definition is not part of this header.
// NOTE(review): the send*/recv* and local*/gathered* pairs suggest an all-to-all
// exchange of expert ids, scales and expert statistics between ranks through
// `workspace` — confirm against the implementation.
// NOTE(review): which pointers may be null (e.g. scales or statics) is not visible
// here — confirm before relying on it.
// NOTE(review): presumably runs asynchronously on `stream` — confirm.
void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
int* localExpertStatics, int* gatheredExpertStatics,
MoeCommWorkspace workspace, int* sendCountsCumsum, int* localSendIndice,
int* recvCountsCumsum, int* localRecvIndice, int tokenCount,
int maxTokenCountPerRank, int topK, int expertCount, int slotCount,
int rankId, int rankCount, cudaStream_t stream);
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably the workspace size needed for `epSize` ranks; the unit
// (bytes vs. 64-bit words) is not visible here — confirm.
size_t getMoePrepareWorkspaceSize(int epSize);
} // namespace moe_prepare
} // namespace flashinfer::trtllm_alltoall