commit cc76bab27ed083524a4bbb010a6a0242a0603c25 Author: hailin Date: Mon Sep 15 10:32:17 2025 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..9725fabd9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,240 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ + +# Tokenizer cache for tests +.tokenizer_cache/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# MacOS +.DS_Store + +# Vim +*.swp + +# Documentation +docs/_build + +# SGL +benchmark/mmlu/data +benchmark/mmlu/data.tar +benchmark/llava_bench/images +benchmark/llava_bench/mme_pack +*.jsonl +tmp*.txt + +# Plots +*.png +*.pdf + +# personnal +work_dirs/ +*.csv + +!logo.png + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +compile_commands.json + +*.iml + +# VSCode +.vscode + +1 + +# Autoenv +.env.leave + +# Rust lib +Cargo.lock + +lmms-eval diff --git a/DeepEP/.gitignore b/DeepEP/.gitignore new file mode 100644 index 000000000..fd2a10383 --- /dev/null +++ b/DeepEP/.gitignore @@ -0,0 +1,8 @@ +compile_commands.json +.idea +.DS_Store +*.pyc +build/ +.cache/ +.vscode/ +*/cmake-build-*/ diff --git a/DeepEP/LICENSE b/DeepEP/LICENSE new file mode 100644 index 000000000..5c48bdc9f --- /dev/null +++ b/DeepEP/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/DeepEP/README.md b/DeepEP/README.md new file mode 100644 index 000000000..2cb8657cb --- /dev/null +++ b/DeepEP/README.md @@ -0,0 +1,344 @@ +# DeepEP + +DeepEP is a communication library tailored for Mixture-of-Experts (MoE) and expert parallelism (EP). It provides high-throughput and low-latency all-to-all GPU kernels, which are also known as MoE dispatch and combine. The library also supports low-precision operations, including FP8. + +To align with the group-limited gating algorithm proposed in the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper, DeepEP offers a set of kernels optimized for asymmetric-domain bandwidth forwarding, such as forwarding data from NVLink domain to RDMA domain. These kernels deliver high throughput, making them suitable for both training and inference prefilling tasks. Additionally, they support SM (Streaming Multiprocessors) number control. + +For latency-sensitive inference decoding, DeepEP includes a set of low-latency kernels with pure RDMA to minimize delays. The library also introduces a hook-based communication-computation overlapping method that does not occupy any SM resource. + +Notice: the implementation in this library may have some slight differences from the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) paper. + +## Performance + +### Normal kernels with NVLink and RDMA forwarding + +We test normal kernels on H800 (~160 GB/s NVLink maximum bandwidth), with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow the DeepSeek-V3/R1 pretraining setting (4096 tokens per batch, 7168 hidden, top-4 groups, top-8 experts, FP8 dispatching and BF16 combining). + +| Type | Dispatch #EP | Bottleneck bandwidth | Combine #EP | Bottleneck bandwidth | +|:---------:|:------------:|:--------------------:|:-----------:|:--------------------:| +| Intranode | 8 | 153 GB/s (NVLink) | 8 | 158 GB/s (NVLink) | +| Internode | 16 | 43 GB/s (RDMA) | 16 | 43 GB/s (RDMA) | +| Internode | 32 | 58 GB/s (RDMA) | 32 | 57 GB/s (RDMA) | +| Internode | 64 | 51 GB/s (RDMA) | 64 | 50 GB/s (RDMA) | + +**News (2025.04.22)**: with optimizations from Tencent Network Platform Department, performance was enhanced by up to 30%, see [#130](https://github.com/deepseek-ai/DeepEP/pull/130) for more details. Thanks for the contribution! + +### Low-latency kernels with pure RDMA + +We test low-latency kernels on H800 with each connected to a CX7 InfiniBand 400 Gb/s RDMA network card (~50 GB/s maximum bandwidth). And we follow a typical DeepSeek-V3/R1 production setting (128 tokens per batch, 7168 hidden, top-8 experts, FP8 dispatching and BF16 combining). + +| Dispatch #EP | Latency | RDMA bandwidth | Combine #EP | Latency | RDMA bandwidth | +|:------------:|:-------:|:--------------:|:-----------:|:-------:|:--------------:| +| 8 | 77 us | 98 GB/s | 8 | 114 us | 127 GB/s | +| 16 | 118 us | 63 GB/s | 16 | 195 us | 74 GB/s | +| 32 | 155 us | 48 GB/s | 32 | 273 us | 53 GB/s | +| 64 | 173 us | 43 GB/s | 64 | 314 us | 46 GB/s | +| 128 | 192 us | 39 GB/s | 128 | 369 us | 39 GB/s | +| 256 | 194 us | 39 GB/s | 256 | 360 us | 40 GB/s | + +**News (2025.06.05)**: low-latency kernels now leverage NVLink as much as possible, see [#173](https://github.com/deepseek-ai/DeepEP/pull/173) for more details. Thanks for the contribution! + +## Quick start + +### Requirements + +- Ampere (SM80), Hopper (SM90) GPUs, or other architectures with SM90 PTX ISA support +- Python 3.8 and above +- CUDA version + - CUDA 11.0 and above for SM80 GPUs + - CUDA 12.3 and above for SM90 GPUs +- PyTorch 2.1 and above +- NVLink for intranode communication +- RDMA network for internode communication + +### Download and install NVSHMEM dependency + +DeepEP also depends on our modified NVSHMEM. Please refer to our [NVSHMEM Installation Guide](third-party/README.md) for instructions. + +### Development + +```bash +# Build and make symbolic links for SO files +NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build +# You may modify the specific SO names according to your own platform +ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so + +# Run test cases +# NOTES: you may modify the `init_dist` function in `tests/utils.py` +# according to your own cluster settings, and launch into multiple nodes +python tests/test_intranode.py +python tests/test_internode.py +python tests/test_low_latency.py +``` + +### Installation + +```bash +NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install +``` + +#### Installation environment variables + +- `NVSHMEM_DIR`: the path to the NVSHMEM directory, disable all internode and low-latency features if not specified +- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features, it is required for SM90 devices or CUDA 11 +- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"` +- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details + +Then, import `deep_ep` in your Python project, and enjoy! + +## Network configurations + +DeepEP is fully tested with InfiniBand networks. However, it is theoretically compatible with RDMA over Converged Ethernet (RoCE) as well. + +### Traffic isolation + +Traffic isolation is supported by InfiniBand through Virtual Lanes (VL). + +To prevent interference between different types of traffic, we recommend segregating workloads across different virtual lanes as follows: + +- workloads using normal kernels +- workloads using low-latency kernels +- other workloads + +For DeepEP, you can control the virtual lane assignment by setting the `NVSHMEM_IB_SL` environment variable. + +### Adaptive routing + +Adaptive routing is an advanced routing feature provided by InfiniBand switches that can evenly distribute traffic across multiple paths. Enabling adaptive routing can completely eliminate network congestion caused by routing conflicts, but it also introduces additional latency. We recommend the following configuration for optimal performance: + +- enable adaptive routing in environments with heavy network loads +- use static routing in environments with light network loads + +### Congestion control + +Congestion control is disabled as we have not observed significant congestion in our production environment. + +## Interfaces and examples + +### Example use in model training or inference prefilling + +The normal kernels can be used in model training or the inference prefilling phase (without the backward part) as the below example code shows. + +```python +import torch +import torch.distributed as dist +from typing import List, Tuple, Optional, Union + +from deep_ep import Buffer, EventOverlap + +# Communication buffer (will allocate at runtime) +_buffer: Optional[Buffer] = None + +# Set the number of SMs to use +# NOTES: this is a static variable +Buffer.set_num_sms(24) + + +# You may call this function at the framework initialization +def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer: + global _buffer + + # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())): + num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) + num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) + + # Allocate a buffer if not existed or not enough buffer size + if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes: + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) + return _buffer + + +def get_hidden_bytes(x: torch.Tensor) -> int: + t = x[0] if isinstance(x, tuple) else x + return t.size(1) * max(t.element_size(), 2) + + +def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + topk_idx: torch.Tensor, topk_weights: torch.Tensor, + num_experts: int, previous_event: Optional[EventOverlap] = None) -> \ + Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]: + # NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency + # of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please + # refer to the docs of `Buffer.dispatch` + global _buffer + + # Calculate layout before actual dispatch + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \ + _buffer.get_dispatch_layout(topk_idx, num_experts, + previous_event=previous_event, async_finish=True, + allocate_on_comm_stream=previous_event is not None) + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, so this is not compatible with CUDA graph + # Unless you specify `num_worst_tokens`, but this flag is for intranode only + # For more advanced usages, please refer to the docs of the `dispatch` function + recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \ + _buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, async_finish=True, + allocate_on_comm_stream=True) + # For event management, please refer to the docs of the `EventOverlap` class + return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event + + +def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \ + Tuple[torch.Tensor, torch.Tensor, EventOverlap]: + global _buffer + + # The backward process of MoE dispatch is actually a combine + # For more advanced usages, please refer to the docs of the `combine` function + combined_grad_x, combined_grad_recv_topk_weights, event = \ + _buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True) + + # For event management, please refer to the docs of the `EventOverlap` class + return combined_grad_x, combined_grad_recv_topk_weights, event + + +def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \ + Tuple[torch.Tensor, EventOverlap]: + global _buffer + + # Do MoE combine + # For more advanced usages, please refer to the docs of the `combine` function + combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event, + allocate_on_comm_stream=previous_event is not None) + + # For event management, please refer to the docs of the `EventOverlap` class + return combined_x, event + + +def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \ + Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]: + global _buffer + + # The backward process of MoE combine is actually a dispatch + # For more advanced usages, please refer to the docs of the `dispatch` function + grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True, + previous_event=previous_event, + allocate_on_comm_stream=previous_event is not None) + + # For event management, please refer to the docs of the `EventOverlap` class + return grad_x, event +``` + +Moreover, inside the dispatch function, we may not know how many tokens to receive for the current rank. So an implicit CPU wait for GPU received count signal will be involved, as the following figure shows. + +![normal](figures/normal.png) + +### Example use in inference decoding + +The low latency kernels can be used in the inference decoding phase as the below example code shows. + +```python +import torch +import torch.distributed as dist +from typing import Tuple, Optional + +from deep_ep import Buffer + +# Communication buffer (will allocate at runtime) +# NOTES: there is no SM control API for the low-latency kernels +_buffer: Optional[Buffer] = None + + +# You may call this function at the framework initialization +def get_buffer(group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> Buffer: + # NOTES: the low-latency mode will consume much more space than the normal mode + # So we recommend that `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256 + global _buffer + num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts) + + # Allocate a buffer if not existed or not enough buffer size + if _buffer is None or _buffer.group != group or not _buffer.low_latency_mode or _buffer.num_rdma_bytes < num_rdma_bytes: + # NOTES: for the best performance, the QP number **must** be equal to the number of the local experts + assert num_experts % group.size() == 0 + _buffer = Buffer(group, 0, num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size()) + return _buffer + + +def low_latency_dispatch(hidden_states: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int): + global _buffer + + # Do MoE dispatch, compatible with CUDA graph (but you may restore some buffer status once you replay) + recv_hidden_states, recv_expert_count, handle, event, hook = \ + _buffer.low_latency_dispatch(hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, + async_finish=False, return_recv_hook=True) + + # NOTES: the actual tensor will not be received only if you call `hook()`, + # it is useful for double-batch overlapping, but **without any SM occupation** + # If you don't want to overlap, please set `return_recv_hook=False` + # Later, you can use our GEMM library to do the computation with this specific format + return recv_hidden_states, recv_expert_count, handle, event, hook + + +def low_latency_combine(hidden_states: torch.Tensor, + topk_idx: torch.Tensor, topk_weights: torch.Tensor, handle: Tuple): + global _buffer + + # Do MoE combine, compatible with CUDA graph (but you may restore some buffer status once you replay) + combined_hidden_states, event_overlap, hook = \ + _buffer.low_latency_combine(hidden_states, topk_idx, topk_weights, handle, + async_finish=False, return_recv_hook=True) + + # NOTES: the same behavior as described in the dispatch kernel + return combined_hidden_states, event_overlap, hook +``` + +For two-micro-batch overlapping, you can refer to the following figure. With our receiving hook interface, the RDMA network traffic is happening in the background, without costing any GPU SMs from the computation part. But notice, the overlapped parts can be adjusted, i.e., the 4 parts of attention/dispatch/MoE/combine may not have the exact same execution time. You may adjust the stage settings according to your workload. + +![low-latency](figures/low-latency.png) + +## Roadmap + +- [x] AR support +- [x] Refactor low-latency mode AR code +- [x] A100 support (intranode only) +- [x] Support BF16 for the low-latency dispatch kernel +- [x] Support NVLink protocol for intranode low-latency kernels +- [ ] TMA copy instead of LD/ST + - [x] Intranode kernels + - [ ] Internode kernels + - [ ] Low-latency kernels +- [ ] SM-free kernels and refactors +- [ ] Fully remove undefined-behavior PTX instructions + +## Notices + +#### Easier potential overall design + +The current DeepEP implementation uses queues for communication buffers which save memory but introduce complexity and potential deadlocks. If you're implementing your own version based on DeepEP, consider using fixed-size buffers allocated to maximum capacity for simplicity and better performance. For a detailed discussion of this alternative approach, see https://github.com/deepseek-ai/DeepEP/issues/39. + +#### Undefined-behavior PTX usage + +- For extreme performance, we discover and use an undefined-behavior PTX usage: using read-only PTX `ld.global.nc.L1::no_allocate.L2::256B` to **read volatile data**. The PTX modifier `.nc` indicates that a non-coherent cache is used. But the correctness is tested to be guaranteed with `.L1::no_allocate` on Hopper architectures, and performance will be much better. The reason we guess may be: the non-coherent cache is unified with L1, and the L1 modifier is not just a hint but a strong option, so that the correctness can be guaranteed by no dirty data in L1. +- Initially, because NVCC could not automatically unroll volatile read PTX, we tried using `__ldg` (i.e., `ld.nc`). Even compared to manually unrolled volatile reads, it was significantly faster (likely due to additional compiler optimizations). However, the results could be incorrect or dirty. After consulting the PTX documentation, we discovered that L1 and non-coherent cache are unified on Hopper architectures. We speculated that `.L1::no_allocate` might resolve the issue, leading to this discovery. +- If you find kernels not working on some other platforms, you may add `DISABLE_AGGRESSIVE_PTX_INSTRS=1` to `setup.py` and disable this, or file an issue. + +#### Auto-tuning on your cluster + +For better performance on your cluster, we recommend to run all the tests and use the best auto-tuned configuration. The default configurations are optimized on the DeepSeek's internal cluster. + +## License + +This code repository is released under [the MIT License](LICENSE), except for codes that reference NVSHMEM (including `csrc/kernels/ibgda_device.cuh` and `third-party/nvshmem.patch`), which are subject to [NVSHMEM SLA](https://docs.nvidia.com/nvshmem/api/sla.html). + +## Community Forks + +- [Infrawaves/DeepEP_ibrc_dual-ports_multiQP](https://github.com/Infrawaves/DeepEP_ibrc_dual-ports_multiQP) - Adds multi-QP solution and dual-port NIC support in IBRC transport + +## Citation + +If you use this codebase or otherwise find our work valuable, please cite: + +```bibtex +@misc{deepep2025, + title={DeepEP: an efficient expert-parallel communication library}, + author={Chenggang Zhao and Shangyan Zhou and Liyue Zhang and Chengqi Deng and Zhean Xu and Yuxuan Liu and Kuai Yu and Jiashi Li and Liang Zhao}, + year={2025}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/deepseek-ai/DeepEP}}, +} +``` diff --git a/DeepEP/csrc/CMakeLists.txt b/DeepEP/csrc/CMakeLists.txt new file mode 100644 index 000000000..3f51c2713 --- /dev/null +++ b/DeepEP/csrc/CMakeLists.txt @@ -0,0 +1,36 @@ +# NOTES: this CMake is only for debugging; for setup, please use Torch extension +cmake_minimum_required(VERSION 3.10) +project(deep_ep LANGUAGES CUDA CXX) +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC") +set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") +list(APPEND CUDA_NVCC_FLAGS "-O3") +list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") + +set(USE_SYSTEM_NVTX on) +set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") +set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") + +find_package(CUDAToolkit REQUIRED) +find_package(pybind11 REQUIRED) +find_package(Torch REQUIRED) +find_package(NVSHMEM REQUIRED HINTS ${NVSHMEM_ROOT_DIR}/lib/cmake/nvshmem) + +add_library(nvshmem ALIAS nvshmem::nvshmem) +add_library(nvshmem_host ALIAS nvshmem::nvshmem_host) +add_library(nvshmem_device ALIAS nvshmem::nvshmem_device) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) + +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS} ${NVSHMEM_INCLUDE_DIR}) +link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib ${NVSHMEM_LIB_DIR}) + +add_subdirectory(kernels) + +# Link CPP and CUDA together +pybind11_add_module(deep_ep_cpp deep_ep.cpp) +target_link_libraries(deep_ep_cpp PRIVATE ${EP_CUDA_LIBRARIES} ${TORCH_LIBRARIES} torch_python) diff --git a/DeepEP/csrc/config.hpp b/DeepEP/csrc/config.hpp new file mode 100644 index 000000000..8a1a8ba4c --- /dev/null +++ b/DeepEP/csrc/config.hpp @@ -0,0 +1,188 @@ +#pragma once + +#include "kernels/api.cuh" +#include "kernels/exception.cuh" + +namespace deep_ep { + +template +dtype_t ceil_div(dtype_t a, dtype_t b) { + return (a + b - 1) / b; +} + +template +dtype_t align(dtype_t a, dtype_t b) { + return ceil_div(a, b) * b; +} + +struct Config { + int num_sms; + int num_max_nvl_chunked_send_tokens; + int num_max_nvl_chunked_recv_tokens; + int num_max_rdma_chunked_send_tokens; + int num_max_rdma_chunked_recv_tokens; + + Config(int num_sms, + int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens) : + num_sms(num_sms), + num_max_nvl_chunked_send_tokens(num_max_nvl_chunked_send_tokens), + num_max_nvl_chunked_recv_tokens(num_max_nvl_chunked_recv_tokens), + num_max_rdma_chunked_send_tokens(num_max_rdma_chunked_send_tokens), + num_max_rdma_chunked_recv_tokens(num_max_rdma_chunked_recv_tokens) { + EP_HOST_ASSERT(num_sms >= 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens > 0 and num_max_nvl_chunked_recv_tokens > 0); + EP_HOST_ASSERT(num_max_nvl_chunked_send_tokens < num_max_nvl_chunked_recv_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens > 0 and num_max_rdma_chunked_recv_tokens > 0); + + // Ceil up RDMA buffer size + this->num_max_rdma_chunked_recv_tokens = align(num_max_rdma_chunked_recv_tokens, num_max_rdma_chunked_send_tokens); + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens < num_max_rdma_chunked_recv_tokens); + // NOTES: this assertion is related to RDMA lazy head update, we must ensure senders always have space to push + EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2); + } + + size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const { + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0); + const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1); + const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_nvl_ranks * (2 * num_rdma_ranks + 3) * sizeof(int); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * hidden_bytes; +#ifndef DISABLE_NVSHMEM + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * internode::get_source_meta_bytes(); +#endif + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxTopK * sizeof(float); + num_bytes += num_channels * num_nvl_ranks * num_max_nvl_chunked_recv_tokens * kNumMaxScales * sizeof(float); + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; + } + + size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const { +#ifndef DISABLE_NVSHMEM + // Legacy mode + if (num_ranks <= NUM_MAX_NVL_PEERS) + return 0; + + // Below are some assumptions + // TODO: add assertions + constexpr int kNumMaxTopK = 128; + constexpr int kNumMaxScales = 128; + EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); + EP_HOST_ASSERT(num_sms % 2 == 0); + const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + const int num_channels = num_sms / 2; + + size_t num_bytes = 0; + num_bytes += num_channels * num_rdma_ranks * (NUM_MAX_NVL_PEERS * 2 + 2) * 2 * sizeof(int); + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * hidden_bytes * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * internode::get_source_meta_bytes() * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(int64_t) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxTopK * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * kNumMaxScales * sizeof(float) * 2; + num_bytes += num_channels * num_rdma_ranks * num_max_rdma_chunked_recv_tokens * sizeof(int4) * 2; + num_bytes = ((num_bytes + 127) / 128) * 128; + return num_bytes; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disable during compilation"); +#endif + } +}; + +struct LowLatencyBuffer { + int num_clean_int = 0; + + void* dispatch_rdma_send_buffer = nullptr; + void* dispatch_rdma_recv_data_buffer = nullptr; + int* dispatch_rdma_recv_count_buffer = nullptr; + + void* combine_rdma_send_buffer = nullptr; + void* combine_rdma_recv_data_buffer = nullptr; + int* combine_rdma_recv_flag_buffer = nullptr; + + void* combine_rdma_send_buffer_data_start = nullptr; + size_t num_bytes_per_combine_msg = 0; + + std::pair clean_meta() { + EP_HOST_ASSERT(dispatch_rdma_recv_count_buffer == combine_rdma_recv_flag_buffer); + return {dispatch_rdma_recv_count_buffer, num_clean_int}; + } +}; + +struct LowLatencyLayout { + size_t total_bytes = 0; + LowLatencyBuffer buffers[2]; + + template + out_ptr_t advance(const in_ptr_t& ptr, size_t count) { + return reinterpret_cast(reinterpret_cast(ptr) + count); + } + + LowLatencyLayout(void* rdma_buffer, int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { + const int num_scales = hidden / 128; + + // Dispatch and combine layout: + // - 2 symmetric odd/even send buffer + // - 2 symmetric odd/even receive buffers + // - 2 symmetric odd/even signaling buffers + + // Message sizes + // NOTES: you should add a control `int4` for combine messages if you want to do data transformation + EP_HOST_ASSERT(num_scales * sizeof(float) <= hidden); + size_t num_bytes_per_dispatch_msg = sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16), hidden + num_scales * sizeof(float)); + size_t num_bytes_per_combine_msg = hidden * sizeof(nv_bfloat16); + + // Send buffer + size_t dispatch_send_buffer_bytes = num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_send_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t send_buffer_bytes = std::max(dispatch_send_buffer_bytes, combine_send_buffer_bytes); + EP_HOST_ASSERT(send_buffer_bytes % sizeof(int4) == 0); + total_bytes += send_buffer_bytes * 2; + + // Symmetric receive buffers + // TODO: optimize memory usages + size_t dispatch_recv_data_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_dispatch_msg; + size_t combine_recv_buffer_bytes = num_experts * num_max_dispatch_tokens_per_rank * num_bytes_per_combine_msg; + size_t recv_buffer_bytes = std::max(dispatch_recv_data_buffer_bytes, combine_recv_buffer_bytes); + EP_HOST_ASSERT(recv_buffer_bytes % sizeof(int4) == 0); + total_bytes += recv_buffer_bytes * 2; + + // Symmetric signaling buffers + size_t dispatch_recv_count_buffer_bytes = num_experts * sizeof(int); + size_t combine_recv_flag_buffer_bytes = dispatch_recv_count_buffer_bytes; + size_t signaling_buffer_bytes = std::max(dispatch_recv_count_buffer_bytes, combine_recv_flag_buffer_bytes); + total_bytes += signaling_buffer_bytes * 2; + + // Assign pointers + // NOTES: we still leave some space for distinguishing dispatch/combine buffer, + // so you may see some parameters are duplicated + for (int i = 0; i < 2; ++ i) { + buffers[i] = { + static_cast(signaling_buffer_bytes / sizeof(int)), + advance(rdma_buffer, send_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * 2 + recv_buffer_bytes * 2 + signaling_buffer_bytes * i), + advance(rdma_buffer, send_buffer_bytes * i), + num_bytes_per_combine_msg + }; + } + } +}; + +size_t get_low_latency_rdma_size_hint(int num_max_dispatch_tokens_per_rank, int hidden, int num_ranks, int num_experts) { + auto num_bytes = LowLatencyLayout(nullptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts).total_bytes; + return ((num_bytes + NUM_BUFFER_ALIGNMENT_BYTES) / NUM_BUFFER_ALIGNMENT_BYTES) * NUM_BUFFER_ALIGNMENT_BYTES; +} + +} // namespace deep_ep diff --git a/DeepEP/csrc/deep_ep.cpp b/DeepEP/csrc/deep_ep.cpp new file mode 100644 index 000000000..e918adc1b --- /dev/null +++ b/DeepEP/csrc/deep_ep.cpp @@ -0,0 +1,1347 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "deep_ep.hpp" +#include "kernels/api.cuh" +#include "kernels/configs.cuh" + +namespace deep_ep { + +Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode): + rank(rank), num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), + low_latency_mode(low_latency_mode), + comm_stream(at::cuda::getStreamFromPool(true)) { + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_NVL_PEERS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); + + // Common checks + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); + EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); + if (num_rdma_bytes > 0) + EP_HOST_ASSERT(num_ranks > NUM_MAX_NVL_PEERS or low_latency_mode); + + // Get ranks + CUDA_CHECK(cudaGetDevice(&device_id)); + rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); +#ifdef DISABLE_NVSHMEM + EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation"); +#endif + + // Get device info + cudaDeviceProp device_prop = {}; + CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); + num_device_sms = device_prop.multiProcessorCount; + + if (num_nvl_bytes > 0) { + // Local IPC: alloc local memory and set local IPC handles + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); + CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); + buffer_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); + + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = reinterpret_cast(static_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); + + // No need to synchronize, will do a full device sync during `sync` + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); + } + + // Create 32 MiB workspace + CUDA_CHECK(cudaMalloc(&workspace, NUM_WORKSPACE_BYTES)); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, NUM_WORKSPACE_BYTES, comm_stream)); + + // MoE counter + CUDA_CHECK(cudaMallocHost(&moe_recv_counter, sizeof(int64_t), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_counter_mapped, const_cast(moe_recv_counter), 0)); + *moe_recv_counter = -1; + + // MoE expert-level counter + CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_expert_counter_mapped, const_cast(moe_recv_expert_counter), 0)); + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i) + moe_recv_expert_counter[i] = -1; + + // MoE RDMA-level counter + if (num_rdma_ranks > 0) { + CUDA_CHECK(cudaMallocHost(&moe_recv_rdma_counter, sizeof(int), cudaHostAllocMapped)); + CUDA_CHECK(cudaHostGetDevicePointer(&moe_recv_rdma_counter_mapped, const_cast(moe_recv_rdma_counter), 0)); + *moe_recv_rdma_counter = -1; + } +} + +Buffer::~Buffer() noexcept(false) { + // Synchronize + CUDA_CHECK(cudaDeviceSynchronize()); + + if (num_nvl_bytes > 0) { + // Barrier + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Close remote IPC + if (is_available()) { + for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank) + CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); + } + + // Free local buffer and error flag + CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + } + + // Free NVSHMEM +#ifndef DISABLE_NVSHMEM + if (num_rdma_bytes > 0) { + CUDA_CHECK(cudaDeviceSynchronize()); + internode::barrier(); + internode::free(rdma_buffer_ptr); + internode::finalize(); + } +#endif + + // Free workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); +} + +bool Buffer::is_available() const { + return available; +} + +bool Buffer::is_internode_available() const { + return is_available() and num_ranks > NUM_MAX_NVL_PEERS; +} + +int Buffer::get_num_rdma_ranks() const { + return num_rdma_ranks; +} + +int Buffer::get_rdma_rank() const { + return rdma_rank; +} + +int Buffer::get_root_rdma_rank(bool global) const { + return global ? nvl_rank : 0; +} + +int Buffer::get_local_device_id() const { + return device_id; +} + +pybind11::bytearray Buffer::get_local_ipc_handle() const { + return {ipc_handles[nvl_rank].reserved, CUDA_IPC_HANDLE_SIZE}; +} + +pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { +#ifndef DISABLE_NVSHMEM + EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); + auto unique_id = internode::get_unique_id(); + return {reinterpret_cast(unique_id.data()), unique_id.size()}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif +} + +torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { + torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); + auto element_bytes = static_cast(elementSize(casted_dtype)); + auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; + auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; + return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); +} + +torch::Stream Buffer::get_comm_stream() const { + return comm_stream; +} + +void Buffer::sync(const std::vector &device_ids, + const std::vector> &all_gathered_handles, + const std::optional& root_unique_id_opt) { + EP_HOST_ASSERT(not is_available()); + + // Sync IPC handles + if (num_nvl_bytes > 0) { + EP_HOST_ASSERT(num_ranks == device_ids.size()); + EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); + for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) { + EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); + auto handle_str = std::string(all_gathered_handles[offset + i].value()); + EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); + if (offset + i != rank) { + std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); + CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); + } else { + EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); + } + } + + // Copy all buffer and barrier signal pointers to GPU + CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + // Sync NVSHMEM handles and allocate memory +#ifndef DISABLE_NVSHMEM + if (num_rdma_bytes > 0) { + // Initialize NVSHMEM + EP_HOST_ASSERT(root_unique_id_opt.has_value()); + std::vector root_unique_id(root_unique_id_opt->size()); + auto root_unique_id_str = root_unique_id_opt->cast(); + std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size()); + auto nvshmem_rank = low_latency_mode ? rank : rdma_rank; + auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks; + EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode)); + internode::barrier(); + + // Allocate + rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); + + // Clean buffer (mainly for low-latency mode) + CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); + + // Barrier + internode::barrier(); + CUDA_CHECK(cudaDeviceSynchronize()); + } +#endif + + // Ready to use + available = true; +} + +std::tuple, torch::Tensor, torch::Tensor, std::optional> +Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, + std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + EP_HOST_ASSERT(topk_idx.dim() == 2); + EP_HOST_ASSERT(topk_idx.is_contiguous()); + EP_HOST_ASSERT(num_experts > 0); + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + auto num_tokens = static_cast(topk_idx.size(0)), num_topk = static_cast(topk_idx.size(1)); + auto num_tokens_per_rank = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + auto num_tokens_per_rdma_rank = std::optional(); + auto num_tokens_per_expert = torch::empty({num_experts}, dtype(torch::kInt32).device(torch::kCUDA)); + auto is_token_in_rank = torch::empty({num_tokens, num_ranks}, dtype(torch::kBool).device(torch::kCUDA)); + if (is_internode_available()) + num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + + layout::get_dispatch_layout(topk_idx.data_ptr(), + num_tokens_per_rank.data_ptr(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, + num_tokens_per_expert.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, num_topk, num_ranks, num_experts, + comm_stream); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) + t.record_stream(compute_stream); + } + for (auto& to: {num_tokens_per_rdma_rank}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) + at::cuda::setCurrentCUDAStream(compute_stream); + + return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; +} + +std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> +Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, + int expert_alignment, int num_worst_tokens, const Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + bool cached_mode = cached_rank_prefix_matrix.has_value(); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_channel_prefix_matrix.has_value()); + } else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + + // Type checks + EP_HOST_ASSERT(is_token_in_rank.scalar_type() == torch::kBool); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_channel_prefix_matrix->scalar_type() == torch::kInt32); + } else { + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(is_token_in_rank.dim() == 2 and is_token_in_rank.is_contiguous()); + EP_HOST_ASSERT(is_token_in_rank.size(0) == x.size(0) and is_token_in_rank.size(1) == num_ranks); + if (cached_mode) { + EP_HOST_ASSERT(cached_rank_prefix_matrix->dim() == 2 and cached_rank_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rank_prefix_matrix->size(0) == num_ranks and cached_rank_prefix_matrix->size(1) == num_ranks); + EP_HOST_ASSERT(cached_channel_prefix_matrix->dim() == 2 and cached_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_channel_prefix_matrix->size(0) == num_ranks and cached_channel_prefix_matrix->size(1) == num_channels); + } else { + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); + } + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); + EP_HOST_ASSERT(num_topk == topk_weights->size(1)); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + topk_idx_ptr = topk_idx->data_ptr(); + topk_weights_ptr = topk_weights->data_ptr(); + } + + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); + } + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1; + auto rank_prefix_matrix = torch::Tensor(); + auto channel_prefix_matrix = torch::Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // Barrier or send sizes + // To clean: channel start/end offset, head and tail + int num_memset_int = num_channels * num_ranks * 4; + if (cached_mode) { + num_recv_tokens = cached_num_recv_tokens; + rank_prefix_matrix = cached_rank_prefix_matrix.value(); + channel_prefix_matrix = cached_channel_prefix_matrix.value(); + + // Copy rank prefix matrix and clean flags + intranode::cached_notify_dispatch(rank_prefix_matrix.data_ptr(), num_memset_int, + buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, + comm_stream); + } else { + rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + + // Send sizes + // Meta information: + // - Size prefix by ranks, shaped as `[num_ranks, num_ranks]` + // - Size prefix by experts (not used later), shaped as `[num_ranks, num_local_experts]` + // NOTES: no more token dropping in this version + *moe_recv_counter = -1; + for (int i = 0; i < num_local_experts; ++ i) + moe_recv_expert_counter[i] = -1; + EP_HOST_ASSERT(num_ranks * (num_ranks + num_local_experts) * sizeof(int) <= num_nvl_bytes); + intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, + num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, + num_tokens, is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), + rank_prefix_matrix.data_ptr(), + num_memset_int, expert_alignment, + buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, + comm_stream, num_channels); + + if (num_worst_tokens > 0) { + // No CPU sync, just allocate the worst case + num_recv_tokens = num_worst_tokens; + + // Must be forward with top-k stuffs + EP_HOST_ASSERT(topk_idx.has_value()); + EP_HOST_ASSERT(topk_weights.has_value()); + } else { + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) + break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); + } + num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } + } + + // Allocate new tensors + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + + // Assign pointers + int64_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->dim() == 1 ? + torch::empty({num_recv_tokens}, x_scales->options()) : + torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); + } + + // Dispatch + EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer + <= num_nvl_bytes); + intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), + send_head.data_ptr(), + x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, + is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), + num_tokens, num_worst_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), + num_topk, num_experts, num_scales, + scale_token_stride, scale_hidden_stride, + buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, + config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) + t.record_stream(compute_stream); + } + for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) + at::cuda::setCurrentCUDAStream(compute_stream); + + // Return values + return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event}; +} + +std::tuple, std::optional> +Buffer::intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, + const std::optional& bias_0, const std::optional& bias_1, + const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); + + // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. + EP_HOST_ASSERT(config.num_sms % 2 == 0); + int num_channels = config.num_sms / 2; + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_recv_tokens = static_cast(send_head.size(0)); + EP_HOST_ASSERT(src_idx.size(0) == num_tokens); + EP_HOST_ASSERT(send_head.size(1) == num_ranks); + EP_HOST_ASSERT(rank_prefix_matrix.size(0) == num_ranks and rank_prefix_matrix.size(1) == num_ranks); + EP_HOST_ASSERT(channel_prefix_matrix.size(0) == num_ranks and channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + int num_topk = 0; + auto recv_topk_weights = std::optional(); + float* topk_weights_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + num_topk = static_cast(topk_weights->size(1)); + topk_weights_ptr = topk_weights->data_ptr(); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + + // Launch barrier and reset queue head and tail + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 <= num_nvl_bytes); + intranode::cached_notify_combine(buffer_ptrs_gpu, send_head.data_ptr(), + num_channels, num_recv_tokens, num_channels * num_ranks * 2, + barrier_signal_ptrs_gpu, rank, num_ranks, + comm_stream); + + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + + // Combine data + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer + <= num_nvl_bytes); + intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), + recv_x.data_ptr(), recv_topk_weights_ptr, + x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], + src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), + send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, + buffer_ptrs_gpu, rank, num_ranks, + comm_stream, config.num_sms, + config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) + t.record_stream(compute_stream); + } + for (auto& to: {topk_weights, recv_topk_weights, bias_0, bias_1}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) + at::cuda::setCurrentCUDAStream(compute_stream); + + return {recv_x, recv_topk_weights, event}; +} + +std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> +Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM + // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. + // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL + // unless we release GIL here. + pybind11::gil_scoped_release release; + + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); + + bool cached_mode = cached_rdma_channel_prefix_matrix.has_value(); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum.has_value()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix.has_value()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum.has_value()); + } else { + EP_HOST_ASSERT(num_tokens_per_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank.has_value()); + EP_HOST_ASSERT(num_tokens_per_expert.has_value()); + } + + // Type checks + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->scalar_type() == torch::kInt32); + } else { + EP_HOST_ASSERT(num_tokens_per_rank->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->scalar_type() == torch::kInt32); + EP_HOST_ASSERT(num_tokens_per_expert->scalar_type() == torch::kInt32); + } + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT((x.size(1) * x.element_size()) % sizeof(int4) == 0); + if (cached_mode) { + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->dim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); + EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); + } else { + EP_HOST_ASSERT(num_tokens_per_rank->dim() == 1 and num_tokens_per_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->dim() == 1 and num_tokens_per_rdma_rank->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_expert->dim() == 1 and num_tokens_per_expert->is_contiguous()); + EP_HOST_ASSERT(num_tokens_per_rank->size(0) == num_ranks); + EP_HOST_ASSERT(num_tokens_per_rdma_rank->size(0) == num_rdma_ranks); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) % num_ranks == 0); + EP_HOST_ASSERT(num_tokens_per_expert->size(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); + } + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; + + // Top-k checks + int num_topk = 0; + int64_t* topk_idx_ptr = nullptr; + float* topk_weights_ptr = nullptr; + EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); + if (topk_idx.has_value()) { + num_topk = static_cast(topk_idx->size(1)); + EP_HOST_ASSERT(num_experts > 0); + EP_HOST_ASSERT(topk_idx->dim() == 2 and topk_idx->is_contiguous()); + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(num_tokens == topk_idx->size(0) and num_tokens == topk_weights->size(0)); + EP_HOST_ASSERT(num_topk == topk_weights->size(1)); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + topk_idx_ptr = topk_idx->data_ptr(); + topk_weights_ptr = topk_weights->data_ptr(); + } + + // FP8 scales checks + float* x_scales_ptr = nullptr; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; + if (x_scales.has_value()) { + EP_HOST_ASSERT(x.element_size() == 1); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); + EP_HOST_ASSERT(x_scales->size(0) == num_tokens); + num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); + } + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + // Create handles (only return for non-cached mode) + int num_recv_tokens = -1, num_rdma_recv_tokens = -1; + auto rdma_channel_prefix_matrix = torch::Tensor(); + auto recv_rdma_rank_prefix_sum = torch::Tensor(); + auto gbl_channel_prefix_matrix = torch::Tensor(); + auto recv_gbl_rank_prefix_sum = torch::Tensor(); + std::vector num_recv_tokens_per_expert_list; + + // Barrier or send sizes + if (cached_mode) { + num_recv_tokens = cached_num_recv_tokens; + num_rdma_recv_tokens = cached_num_rdma_recv_tokens; + rdma_channel_prefix_matrix = cached_rdma_channel_prefix_matrix.value(); + recv_rdma_rank_prefix_sum = cached_recv_rdma_rank_prefix_sum.value(); + gbl_channel_prefix_matrix = cached_gbl_channel_prefix_matrix.value(); + recv_gbl_rank_prefix_sum = cached_recv_gbl_rank_prefix_sum.value(); + + // Just a barrier and clean flags + internode::cached_notify(hidden_int4, num_scales, num_topk, num_topk, + num_ranks, num_channels, 0, nullptr, + nullptr, nullptr, nullptr, + rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, rank, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, true, low_latency_mode); + } else { + rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_rdma_rank_prefix_sum = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_gbl_rank_prefix_sum = torch::empty({num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + + // Send sizes + *moe_recv_counter = -1, *moe_recv_rdma_counter = -1; + for (int i = 0; i < num_local_experts; ++ i) + moe_recv_expert_counter[i] = -1; + internode::notify_dispatch(num_tokens_per_rank->data_ptr(), moe_recv_counter_mapped, num_ranks, + num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, + num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, + is_token_in_rank.data_ptr(), num_tokens, num_channels, + hidden_int4, num_scales, num_topk, expert_alignment, + rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), + rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, rank, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, low_latency_mode); + + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + num_rdma_recv_tokens = static_cast(*moe_recv_rdma_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0) and (num_rdma_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++ i) + ready &= moe_recv_expert_counter[i] >= 0; + + if (ready) + break; + + // Timeout check + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) { + printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); + for (int i = 0; i < num_local_experts; ++ i) + printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); + throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); + } + } + num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); + } + + // Allocate new tensors + auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_src_meta = std::optional(); + auto recv_rdma_channel_prefix_matrix = std::optional(); + auto recv_gbl_channel_prefix_matrix = std::optional(); + auto send_rdma_head = std::optional(); + auto send_nvl_head = std::optional(); + if (not cached_mode) { + recv_src_meta = torch::empty({num_recv_tokens, internode::get_source_meta_bytes()}, dtype(torch::kByte).device(torch::kCUDA)); + recv_rdma_channel_prefix_matrix = torch::empty({num_rdma_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + recv_gbl_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); + send_rdma_head = torch::empty({num_tokens, num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); + send_nvl_head = torch::empty({num_rdma_recv_tokens, NUM_MAX_NVL_PEERS}, dtype(torch::kInt32).device(torch::kCUDA)); + } + + // Assign pointers + int64_t* recv_topk_idx_ptr = nullptr; + float* recv_topk_weights_ptr = nullptr; + float* recv_x_scales_ptr = nullptr; + if (topk_idx.has_value()) { + recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); + recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_weights_ptr = recv_topk_weights->data_ptr(); + } + if (x_scales.has_value()) { + recv_x_scales = x_scales->dim() == 1 ? + torch::empty({num_recv_tokens}, x_scales->options()) : + torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); + } + + // Launch data dispatch + // NOTES: the buffer size checks are moved into the `.cu` file + internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + cached_mode ? nullptr : recv_src_meta->data_ptr(), + x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, + cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), + cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), + cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, hidden_int4, num_scales, num_topk, num_experts, + scale_token_stride, scale_hidden_stride, + rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, + rank, num_ranks, cached_mode, + comm_stream, num_channels, low_latency_mode); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t: {x, is_token_in_rank, recv_x, + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) + t.record_stream(compute_stream); + } + for (auto& to: {x_scales, topk_idx, topk_weights, + num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, + cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, + cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, + recv_topk_idx, recv_topk_weights, recv_x_scales, + recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, + recv_src_meta}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) + at::cuda::setCurrentCUDAStream(compute_stream); + + // Return values + return {recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, + rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, + recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, + recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, + recv_src_meta, send_rdma_head, send_nvl_head, event}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +std::tuple, std::optional> +Buffer::internode_combine(const torch::Tensor& x, const std::optional& topk_weights, + const std::optional& bias_0, const std::optional& bias_1, + const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, + const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM + const int num_channels = config.num_sms / 2; + EP_HOST_ASSERT(config.num_sms % 2 == 0); + + // Shape and contiguous checks + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); + EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); + EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and rdma_rank_prefix_sum.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); + EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); + EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); + EP_HOST_ASSERT(is_combined_token_in_rank.size(1) == num_ranks); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.size(0) == num_rdma_ranks and rdma_channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT(rdma_rank_prefix_sum.size(0) == num_rdma_ranks); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.size(0) == num_ranks and gbl_channel_prefix_matrix.size(1) == num_channels); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.size(0) == num_combined_tokens and combined_rdma_head.size(1) == num_rdma_ranks); + EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.size(1) == NUM_MAX_NVL_PEERS); + + // Allocate all tensors on comm stream if set + // NOTES: do not allocate tensors upfront! + auto compute_stream = at::cuda::getCurrentCUDAStream(); + if (allocate_on_comm_stream) { + EP_HOST_ASSERT(previous_event.has_value() and async); + at::cuda::setCurrentCUDAStream(comm_stream); + } + + // Wait previous tasks to be finished + if (previous_event.has_value()) { + stream_wait(comm_stream, previous_event.value()); + } else { + stream_wait(comm_stream, compute_stream); + } + + // Top-k checks + int num_topk = 0; + auto combined_topk_weights = std::optional(); + float* topk_weights_ptr = nullptr; + float* combined_topk_weights_ptr = nullptr; + if (topk_weights.has_value()) { + EP_HOST_ASSERT(topk_weights->dim() == 2 and topk_weights->is_contiguous()); + EP_HOST_ASSERT(topk_weights->size(0) == num_tokens); + EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); + num_topk = static_cast(topk_weights->size(1)); + topk_weights_ptr = topk_weights->data_ptr(); + combined_topk_weights = torch::empty({num_combined_tokens, num_topk}, topk_weights->options()); + combined_topk_weights_ptr = combined_topk_weights->data_ptr(); + } + + // Extra check for avoid-dead-lock design + EP_HOST_ASSERT(config.num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(config.num_max_nvl_chunked_send_tokens <= config.num_max_nvl_chunked_recv_tokens / num_rdma_ranks); + + // Launch barrier and reset queue head and tail + internode::cached_notify(hidden_int4, 0, 0, num_topk, + num_ranks, num_channels, + num_combined_tokens, combined_rdma_head.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), + rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, rank, comm_stream, + config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), + num_nvl_bytes, false, low_latency_mode); + + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++ i) if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + + // Launch data combine + auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), + combined_x.data_ptr(), combined_topk_weights_ptr, + is_combined_token_in_rank.data_ptr(), + x.data_ptr(), topk_weights_ptr, bias_ptrs[0], bias_ptrs[1], + combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), + src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), + num_tokens, num_combined_tokens, hidden, num_topk, + rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, + rank, num_ranks, comm_stream, num_channels, low_latency_mode); + + // Wait streams + std::optional event; + if (async) { + event = EventHandle(comm_stream); + for (auto& t: {x, src_meta, + is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, + combined_x, combined_rdma_head, combined_nvl_head}) { + t.record_stream(comm_stream); + if (allocate_on_comm_stream) + t.record_stream(compute_stream); + } + for (auto& to: {topk_weights, combined_topk_weights, bias_0, bias_1}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + } + } else { + stream_wait(compute_stream, comm_stream); + } + + // Switch back compute stream + if (allocate_on_comm_stream) + at::cuda::setCurrentCUDAStream(compute_stream); + + // Return values + return {combined_x, combined_topk_weights, event}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { +#ifndef DISABLE_NVSHMEM + EP_HOST_ASSERT(low_latency_mode); + + auto layout = LowLatencyLayout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto clean_meta_0 = layout.buffers[0].clean_meta(); + auto clean_meta_1 = layout.buffers[1].clean_meta(); + + auto check_boundary = [=](void* ptr, size_t num_bytes) { + auto offset = reinterpret_cast(ptr) - reinterpret_cast(rdma_buffer_ptr); + EP_HOST_ASSERT(0 <= offset and offset + num_bytes <= num_rdma_bytes); + }; + check_boundary(clean_meta_0.first, clean_meta_0.second * sizeof(int)); + check_boundary(clean_meta_1.first, clean_meta_1.second * sizeof(int)); + + internode_ll::clean_low_latency_buffer(clean_meta_0.first, clean_meta_0.second, + clean_meta_1.first, clean_meta_1.second, + at::cuda::getCurrentCUDAStream()); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif +} + +std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> +Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool use_fp8, bool round_scale, bool use_ue8m0, + bool async, bool return_recv_hook) { +#ifndef DISABLE_NVSHMEM + EP_HOST_ASSERT(low_latency_mode); + + // Tensor checks + // By default using `ptp128c` FP8 cast + EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(num_experts % num_ranks == 0); + if (cumulative_local_expert_recv_stats.has_value()) { + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous()); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks); + } + + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); + auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); + auto num_local_experts = num_experts / num_ranks; + + // Buffer control + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Wait previous tasks to be finished + // NOTES: the hook mode will always use the default stream + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(not (async and return_recv_hook)); + if (not return_recv_hook) + stream_wait(launch_stream, compute_stream); + + // Allocate packed tensors + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); + auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); + auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + // Allocate column-majored scales + auto packed_recv_x_scales = std::optional(); + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); + + if (use_fp8) { + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } else { + EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kInt).device(torch::kCUDA)); + } + packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + auto launcher = [=](int phases) { + internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, + packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), + packed_recv_count.data_ptr(), + cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr() : nullptr, + buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, + buffer.dispatch_rdma_send_buffer, + x.data_ptr(), topk_idx.data_ptr(), + next_clean_meta.first, next_clean_meta.second, + num_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + use_fp8, round_scale, use_ue8m0, + workspace, num_device_sms, + launch_stream, phases); + }; + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + // Wait streams + std::optional event; + if (async) { + // NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, + // so in Python API, we must wrap all tensors into the event handle. + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } + + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + + // Return values + return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +std::tuple, std::optional>> +Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + const torch::Tensor& src_info, const torch::Tensor& layout_range, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool zero_copy, bool async, bool return_recv_hook, + const std::optional& out) { +#ifndef DISABLE_NVSHMEM + EP_HOST_ASSERT(low_latency_mode); + + // Tensor checks + EP_HOST_ASSERT(x.dim() == 3 and x.is_contiguous() and x.scalar_type() == torch::kBFloat16); + EP_HOST_ASSERT(x.size(0) == num_experts / num_ranks); + EP_HOST_ASSERT(x.size(1) == num_ranks * num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); + EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); + EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); + EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); + EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); + EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); + EP_HOST_ASSERT(src_info.dim() == 2 and src_info.is_contiguous()); + EP_HOST_ASSERT(src_info.scalar_type() == torch::kInt32 and x.size(0) == src_info.size(0)); + EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); + EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + auto hidden = static_cast(x.size(2)); + auto num_topk = static_cast(topk_weights.size(1)); + auto num_combined_tokens = static_cast(topk_weights.size(0)); + + // Buffer control + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; + + // Wait previous tasks to be finished + // NOTES: the hook mode will always use the default stream + auto compute_stream = at::cuda::getCurrentCUDAStream(); + auto launch_stream = return_recv_hook ? compute_stream : comm_stream; + EP_HOST_ASSERT(not (async and return_recv_hook)); + if (not return_recv_hook) + stream_wait(launch_stream, compute_stream); + + // Allocate output tensor + torch::Tensor combined_x; + if (out.has_value()) { + EP_HOST_ASSERT(out->dim() == 2 and out->is_contiguous()); + EP_HOST_ASSERT(out->size(0) == num_combined_tokens and out->size(1) == hidden); + EP_HOST_ASSERT(out->scalar_type() == x.scalar_type()); + combined_x = out.value(); + } else { + combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); + } + + // Kernel launch + auto next_clean_meta = next_buffer.clean_meta(); + auto launcher = [=](int phases) { + internode_ll::combine(combined_x.data_ptr(), + buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_send_buffer, + x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), + src_info.data_ptr(), layout_range.data_ptr(), + next_clean_meta.first, next_clean_meta.second, + num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, + num_topk, num_experts, rank, num_ranks, + workspace, num_device_sms, + launch_stream, phases, zero_copy); + }; + launcher(return_recv_hook ? LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); + + // Wait streams + std::optional event; + if (async) { + // NOTES: we must ensure the all tensors will not be deallocated before the stream-wait happens, + // so in Python API, we must wrap all tensors into the event handle. + event = EventHandle(launch_stream); + } else if (not return_recv_hook) { + stream_wait(compute_stream, launch_stream); + } + + // Receiver callback + std::optional> recv_hook = std::nullopt; + if (return_recv_hook) + recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); }; + + // Return values + return {combined_x, event, recv_hook}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +torch::Tensor +Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { +#ifndef DISABLE_NVSHMEM + LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + + auto buffer = layout.buffers[low_latency_buffer_idx]; + auto dtype = torch::kBFloat16; + auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); + + EP_HOST_ASSERT(buffer.num_bytes_per_combine_msg % elementSize(torch::kBFloat16) == 0); + return torch::from_blob(buffer.combine_rdma_send_buffer_data_start, + {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, + torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} + +bool is_sm90_compiled() { +#ifndef DISABLE_SM90_FEATURES + return true; +#else + return false; +#endif +} + +} // namespace deep_ep + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "DeepEP: an efficient expert-parallel communication library"; + + pybind11::class_(m, "Config") + .def(pybind11::init(), + py::arg("num_sms") = 20, + py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, + py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) + .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) + .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); + m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); + + pybind11::class_(m, "EventHandle") + .def(pybind11::init<>()) + .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); + + pybind11::class_(m, "Buffer") + .def(pybind11::init()) + .def("is_available", &deep_ep::Buffer::is_available) + .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) + .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) + .def("get_root_rdma_rank", &deep_ep::Buffer::get_root_rdma_rank) + .def("get_local_device_id", &deep_ep::Buffer::get_local_device_id) + .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) + .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) + .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) + .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) + .def("sync", &deep_ep::Buffer::sync) + .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) + .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) + .def("intranode_combine", &deep_ep::Buffer::intranode_combine) + .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch) + .def("internode_combine", &deep_ep::Buffer::internode_combine) + .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) + .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) + .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) + .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); + + m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); +} diff --git a/DeepEP/csrc/deep_ep.hpp b/DeepEP/csrc/deep_ep.hpp new file mode 100644 index 000000000..00f8d0c49 --- /dev/null +++ b/DeepEP/csrc/deep_ep.hpp @@ -0,0 +1,157 @@ +#pragma once + +// Forcibly disable NDEBUG +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "event.hpp" +#include "kernels/configs.cuh" +#include "kernels/exception.cuh" + +#ifndef TORCH_EXTENSION_NAME +#define TORCH_EXTENSION_NAME deep_ep_cpp +#endif + +namespace deep_ep { + +struct Buffer { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "The number of maximum NVLink peers must be 8"); + +private: + // Low-latency mode buffer + int low_latency_buffer_idx = 0; + bool low_latency_mode = false; + + // NVLink Buffer + int64_t num_nvl_bytes; + void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + void** buffer_ptrs_gpu = nullptr; + + // NVSHMEM Buffer + int64_t num_rdma_bytes; + void* rdma_buffer_ptr = nullptr; + + // Device info and communication + int device_id; + int num_device_sms; + int rank, rdma_rank, nvl_rank; + int num_ranks, num_rdma_ranks, num_nvl_ranks; + cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; + + // Stream for communication + at::cuda::CUDAStream comm_stream; + + // After IPC/NVSHMEM synchronization, this flag will be true + bool available = false; + + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; + + // Workspace + void* workspace = nullptr; + + // Host-side MoE info + volatile int* moe_recv_counter = nullptr; + int* moe_recv_counter_mapped = nullptr; + + // Host-side expert-level MoE info + volatile int* moe_recv_expert_counter = nullptr; + int* moe_recv_expert_counter_mapped = nullptr; + + // Host-side RDMA-level MoE info + volatile int* moe_recv_rdma_counter = nullptr; + int* moe_recv_rdma_counter_mapped = nullptr; + +public: + Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); + + ~Buffer() noexcept(false); + + bool is_available() const; + + bool is_internode_available() const; + + int get_num_rdma_ranks() const; + + int get_rdma_rank() const; + + int get_root_rdma_rank(bool global) const; + + int get_local_device_id() const; + + pybind11::bytearray get_local_ipc_handle() const; + + pybind11::bytearray get_local_nvshmem_unique_id() const; + + torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; + + torch::Stream get_comm_stream() const; + + void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); + + std::tuple, torch::Tensor, torch::Tensor, std::optional> + get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, + bool async, bool allocate_on_comm_stream); + + std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> + intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, + int expert_alignment, int num_worst_tokens, const Config& config, + std::optional& previous_event, bool async, bool allocate_on_comm_stream); + + std::tuple, std::optional> + intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, + const std::optional& bias_0, const std::optional& bias_1, + const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + + std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> + internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, + const std::optional& topk_idx, const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + + std::tuple, std::optional> + internode_combine(const torch::Tensor& x, const std::optional& topk_weights, + const std::optional& bias_0, const std::optional& bias_1, + const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, + const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + + void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); + + std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> + low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool use_fp8, bool round_scale, bool use_ue8m0, + bool async, bool return_recv_hook); + + std::tuple, std::optional>> + low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, + const torch::Tensor& src_info, const torch::Tensor& layout_range, + int num_max_dispatch_tokens_per_rank, int num_experts, + bool zero_copy, bool async, bool return_recv_hook, + const std::optional& out = std::nullopt); + + torch::Tensor + get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; +}; + +} // namespace deep_ep diff --git a/DeepEP/csrc/event.hpp b/DeepEP/csrc/event.hpp new file mode 100644 index 000000000..a93444d16 --- /dev/null +++ b/DeepEP/csrc/event.hpp @@ -0,0 +1,43 @@ +#include +#include + +#include "kernels/exception.cuh" + +namespace deep_ep { + +struct EventHandle { + std::shared_ptr event; + + EventHandle() { + event = std::make_shared(torch::kCUDA); + event->record(at::cuda::getCurrentCUDAStream()); + } + + explicit EventHandle(const at::cuda::CUDAStream& stream) { + event = std::make_shared(torch::kCUDA); + event->record(stream); + } + + EventHandle(const EventHandle& other) = default; + + void current_stream_wait() const { + at::cuda::getCurrentCUDAStream().unwrap().wait(*event); + } +}; + +torch::Event create_event(const at::cuda::CUDAStream &s) { + auto event = torch::Event(torch::kCUDA); + event.record(s); + return event; +} + +void stream_wait(const at::cuda::CUDAStream& s_0, const at::cuda::CUDAStream& s_1) { + EP_HOST_ASSERT(s_0.id() != s_1.id()); + s_0.unwrap().wait(create_event(s_1)); +} + +void stream_wait(const at::cuda::CUDAStream& s, const EventHandle& event) { + s.unwrap().wait(*event.event); +} + +} // namespace deep_ep diff --git a/DeepEP/csrc/kernels/CMakeLists.txt b/DeepEP/csrc/kernels/CMakeLists.txt new file mode 100644 index 000000000..22e34a38c --- /dev/null +++ b/DeepEP/csrc/kernels/CMakeLists.txt @@ -0,0 +1,21 @@ +function(add_deep_ep_library target_name source_file) + add_library(${target_name} STATIC ${source_file}) + set_target_properties(${target_name} PROPERTIES + POSITION_INDEPENDENT_CODE ON + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD_REQUIRED ON + CXX_STANDARD 17 + CUDA_STANDARD 17 + CUDA_SEPARABLE_COMPILATION ON + ) + target_link_libraries(${target_name} PUBLIC nvshmem cudart cudadevrt mlx5) +endfunction() + +add_deep_ep_library(runtime_cuda runtime.cu) +add_deep_ep_library(layout_cuda layout.cu) +add_deep_ep_library(intranode_cuda intranode.cu) +add_deep_ep_library(internode_cuda internode.cu) +add_deep_ep_library(internode_ll_cuda internode_ll.cu) + +# Later, we should link all libraries in `EP_CUDA_LIBRARIES` +set(EP_CUDA_LIBRARIES runtime_cuda layout_cuda intranode_cuda internode_cuda internode_ll_cuda PARENT_SCOPE) diff --git a/DeepEP/csrc/kernels/api.cuh b/DeepEP/csrc/kernels/api.cuh new file mode 100644 index 000000000..84703c95a --- /dev/null +++ b/DeepEP/csrc/kernels/api.cuh @@ -0,0 +1,167 @@ +#pragma once + +#include + +namespace deep_ep { + +// Intranode runtime +namespace intranode { + +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); + +} // namespace intranode + +// Internode runtime +namespace internode { + +std::vector get_unique_id(); + +int init(const std::vector &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode); + +void *alloc(size_t size, size_t alignment); + +void free(void *ptr); + +void barrier(); + +void finalize(); + +} // namespace internode + +// Layout kernels +namespace layout { + +void get_dispatch_layout(const int64_t* topk_idx, + int* num_tokens_per_rank, int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, bool* is_token_in_rank, + int num_tokens, int num_topk, int num_ranks, int num_experts, + cudaStream_t stream); + +} // namespace layout + +// Intranode kernels +namespace intranode { + +void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, + int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, + cudaStream_t stream, int num_sms); + +void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, + cudaStream_t stream); + +void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, + int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + const bool* is_token_in_rank, const int* channel_prefix_matrix, + int num_tokens, int num_worst_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, + int scale_token_stride, int scale_hidden_stride, + void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, + int num_max_send_tokens, int num_recv_buffer_tokens); + +void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, + int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream); + +void combine(cudaDataType_t type, + void* recv_x, float* recv_topk_weights, + const void* x, const float* topk_weights, + const void* bias_0, const void* bias_1, + const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, + int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, + void** buffer_ptrs, int rank, int num_ranks, + cudaStream_t stream, int num_sms, + int num_max_send_tokens, int num_recv_buffer_tokens); + +} // namespace intranode + +// Internode kernels +namespace internode { + +int get_source_meta_bytes(); + +void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + const bool* is_token_in_rank, int num_tokens, int num_channels, + int hidden_int4, int num_scales, int num_topk, int expert_alignment, + int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, int rank, + cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool low_latency_mode); + +void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, + const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool is_cached_dispatch, + cudaStream_t stream, int num_channels, bool low_latency_mode); + +void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, + int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool is_cached_dispatch, bool low_latency_mode); + +void combine(cudaDataType_t type, + void* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, + const void* bias_0, const void* bias_1, + const int* combined_rdma_head, const int* combined_nvl_head, + const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); + +} // namespace internode + +// Internode low-latency kernels +namespace internode_ll { + +void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1, + cudaStream_t stream); + +void dispatch(void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + bool use_fp8, bool round_scale, bool use_ue8m0, + void* workspace, int num_device_sms, + cudaStream_t stream, int phases); + +void combine(void* combined_x, + void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x, + const void* x, const int64_t* topk_idx, const float* topk_weights, + const int* src_info, const int64_t* layout_range, + int* next_clean, int num_next_clean_int, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, int num_device_sms, + cudaStream_t stream, int phases, bool zero_copy); + +} // namespace internode_ll + +} // namespace deep_ep diff --git a/DeepEP/csrc/kernels/buffer.cuh b/DeepEP/csrc/kernels/buffer.cuh new file mode 100644 index 000000000..7c243d3c2 --- /dev/null +++ b/DeepEP/csrc/kernels/buffer.cuh @@ -0,0 +1,138 @@ +#pragma once + +#include "configs.cuh" +#include "exception.cuh" + +namespace deep_ep { + +template +struct Buffer { +private: + uint8_t* ptr; + +public: + int total_bytes; + + __device__ __forceinline__ Buffer() : ptr(nullptr), total_bytes(0) {} + + __device__ __forceinline__ Buffer(void* &gbl_ptr, int num_elems, int offset = 0) { + total_bytes = num_elems * sizeof(dtype_t); + ptr = reinterpret_cast(gbl_ptr) + offset * sizeof(dtype_t); + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ Buffer advance_also(void* &gbl_ptr) { + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer() { + return reinterpret_cast(ptr); + } + + __device__ __forceinline__ dtype_t& operator[](int idx) { + return buffer()[idx]; + } +}; + +template +struct AsymBuffer { +private: + uint8_t* ptrs[kNumRanks]; + int num_bytes; + +public: + int total_bytes; + + __device__ __forceinline__ AsymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, + int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + ptrs[0] = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ AsymBuffer(void** gbl_ptrs, int num_elems, int num_ranks, + int sm_id = 0, int num_sms = 1, int offset = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, ""); + num_bytes = num_elems * sizeof(dtype_t); + + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms; + for (int i = 0; i < kNumRanks; ++ i) { + ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + per_channel_bytes * sm_id + num_bytes * offset; + gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; + } + } + + __device__ __forceinline__ void advance(int shift) { + #pragma unroll + for (int i = 0; i < kNumRanks; ++ i) + ptrs[i] = ptrs[i] + shift * sizeof(dtype_t); + } + + __device__ __forceinline__ AsymBuffer advance_also(void* &gbl_ptr) { + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + return *this; + } + + template + __device__ __forceinline__ AsymBuffer advance_also(void** gbl_ptrs) { + for (int i = 0; i < kNumAlsoRanks; ++ i) + gbl_ptrs[i] = reinterpret_cast(gbl_ptrs[i]) + total_bytes; + return *this; + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(kNumRanks == 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[0] + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer_by(int rank_idx, int idx = 0) { + EP_STATIC_ASSERT(kNumRanks > 1, "`buffer` is only available for single rank case"); + return reinterpret_cast(ptrs[rank_idx] + num_bytes * idx); + } +}; + +template +struct SymBuffer { +private: + // NOTES: for non-decoupled case, `recv_ptr` is not used + uint8_t* send_ptr; + uint8_t* recv_ptr; + int num_bytes; + +public: + int total_bytes; + + __device__ __forceinline__ SymBuffer(void* &gbl_ptr, int num_elems, int num_ranks, + int sm_id = 0, int num_sms = 1) { + num_bytes = num_elems * sizeof(dtype_t); + + int per_channel_bytes = num_bytes * num_ranks; + total_bytes = per_channel_bytes * num_sms * (static_cast(kDecoupled) + 1); + send_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * sm_id; + recv_ptr = reinterpret_cast(gbl_ptr) + per_channel_bytes * (sm_id + num_sms); + gbl_ptr = reinterpret_cast(gbl_ptr) + total_bytes; + } + + __device__ __forceinline__ dtype_t* send_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`send_buffer` is only available for non-decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* recv_buffer(int idx = 0) { + EP_STATIC_ASSERT(kDecoupled, "`recv_buffer` is only available for non-decoupled case"); + return reinterpret_cast(recv_ptr + num_bytes * idx); + } + + __device__ __forceinline__ dtype_t* buffer(int idx = 0) { + EP_STATIC_ASSERT(not kDecoupled, "`buffer` is only available for decoupled case"); + return reinterpret_cast(send_ptr + num_bytes * idx); + } +}; + +} // namespace deep_ep diff --git a/DeepEP/csrc/kernels/configs.cuh b/DeepEP/csrc/kernels/configs.cuh new file mode 100644 index 000000000..a7b960bdc --- /dev/null +++ b/DeepEP/csrc/kernels/configs.cuh @@ -0,0 +1,67 @@ +#pragma once + +#define NUM_MAX_NVL_PEERS 8 +#define NUM_MAX_RDMA_PEERS 20 +#define NUM_WORKSPACE_BYTES (32 * 1024 * 1024) +#define NUM_MAX_LOCAL_EXPERTS 1024 +#define NUM_BUFFER_ALIGNMENT_BYTES 128 + +#define FINISHED_SUM_TAG 1024 +#define NUM_WAIT_NANOSECONDS 500 + +#ifndef ENABLE_FAST_DEBUG +#define NUM_CPU_TIMEOUT_SECS 100 +#define NUM_TIMEOUT_CYCLES 200000000000ull // 200G cycles ~= 100s +#else +#define NUM_CPU_TIMEOUT_SECS 10 +#define NUM_TIMEOUT_CYCLES 20000000000ull // 20G cycles ~= 10s +#endif + +#define LOW_LATENCY_SEND_PHASE 1 +#define LOW_LATENCY_RECV_PHASE 2 + +// Make CLion CUDA indexing work +#ifdef __CLION_IDE__ +#define __CUDA_ARCH__ 900 // NOLINT(*-reserved-identifier) +#define __CUDACC_RDC__ // NOLINT(*-reserved-identifier) +#endif + +// Remove Torch restrictions +#ifdef __CUDA_NO_HALF_CONVERSIONS__ +#undef __CUDA_NO_HALF_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_HALF_OPERATORS__ +#undef __CUDA_NO_HALF_OPERATORS__ +#endif +#ifdef __CUDA_NO_HALF2_OPERATORS__ +#undef __CUDA_NO_HALF2_OPERATORS__ +#endif +#ifdef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#undef __CUDA_NO_BFLOAT16_CONVERSIONS__ +#endif +#ifdef __CUDA_NO_BFLOAT162_OPERATORS__ +#undef __CUDA_NO_BFLOAT162_OPERATORS__ +#endif + +#include +#include +#include + +#ifndef DISABLE_SM90_FEATURES +#include +#else +// Ampere does not support FP8 features +#define __NV_E4M3 0 +#define __NV_E5M2 1 +typedef int __nv_fp8_interpretation_t; +typedef int __nv_fp8x4_e4m3; +typedef uint8_t __nv_fp8_storage_t; +#endif + +#ifndef DISABLE_NVSHMEM +#include +#include +#include +#include +#include +#endif diff --git a/DeepEP/csrc/kernels/exception.cuh b/DeepEP/csrc/kernels/exception.cuh new file mode 100644 index 000000000..7db0ddb7f --- /dev/null +++ b/DeepEP/csrc/kernels/exception.cuh @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include "configs.cuh" + +#ifndef EP_STATIC_ASSERT +#define EP_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +class EPException: public std::exception { +private: + std::string message = {}; + +public: + explicit EPException(const char *name, const char* file, const int line, const std::string& error) { + message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'"; + } + + const char *what() const noexcept override { return message.c_str(); } +}; + +#ifndef CUDA_CHECK +#define CUDA_CHECK(cmd) \ +do { \ + cudaError_t e = (cmd); \ + if (e != cudaSuccess) { \ + throw EPException("CUDA", __FILE__, __LINE__, cudaGetErrorString(e)); \ + } \ +} while (0) +#endif + +#ifndef EP_HOST_ASSERT +#define EP_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + throw EPException("Assertion", __FILE__, __LINE__, #cond); \ + } \ +} while (0) +#endif + +#ifndef EP_DEVICE_ASSERT +#define EP_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif diff --git a/DeepEP/csrc/kernels/ibgda_device.cuh b/DeepEP/csrc/kernels/ibgda_device.cuh new file mode 100644 index 000000000..c9a10a5ab --- /dev/null +++ b/DeepEP/csrc/kernels/ibgda_device.cuh @@ -0,0 +1,482 @@ +// Portions derived from NVSHMEM (https://developer.nvidia.com/nvshmem) +// Copyright (c) NVIDIA Corporation. +// Licensed under the NVSHMEM Software License Agreement (version: September 3, 2019). +// See full license at: https://docs.nvidia.com/nvshmem/api/sla.html +// +// Modified from original source: +// - nvshmem/src/include/non_abi/device/pt-to-pt/ibgda_device.cuh +#pragma once + +#include "configs.cuh" +#include "exception.cuh" +#include "utils.cuh" + +namespace deep_ep { + +EP_STATIC_ASSERT(NVSHMEMI_IBGDA_MIN_QP_DEPTH >= 64, "Invalid QP minimum depth"); + +__device__ static __forceinline__ +uint64_t HtoBE64(uint64_t x) { + uint64_t ret; + asm("{\n\t" + ".reg .b32 ign;\n\t" + ".reg .b32 lo;\n\t" + ".reg .b32 hi;\n\t" + ".reg .b32 new_lo;\n\t" + ".reg .b32 new_hi;\n\t" + "mov.b64 {lo,hi}, %1;\n\t" + "prmt.b32 new_hi, lo, ign, 0x0123;\n\t" + "prmt.b32 new_lo, hi, ign, 0x0123;\n\t" + "mov.b64 %0, {new_lo,new_hi};\n\t" + "}" : "=l"(ret) : "l"(x)); + return ret; +} + +__device__ static __forceinline__ +uint32_t HtoBE32(uint32_t x) { + uint32_t ret; + asm("{\n\t" + ".reg .b32 ign;\n\t" + "prmt.b32 %0, %1, ign, 0x0123;\n\t" + "}" : "=r"(ret) : "r"(x)); + return ret; +} + +__device__ static __forceinline__ +uint16_t HtoBE16(uint16_t x) { + // TODO: simplify PTX using 16-bit instructions + auto a = static_cast(x); + uint32_t d; + asm volatile( + "{\n\t" + ".reg .b32 mask;\n\t" + ".reg .b32 ign;\n\t" + "mov.b32 mask, 0x4401;\n\t" + "mov.b32 ign, 0x0;\n\t" + "prmt.b32 %0, %1, ign, mask;\n\t" + "}" + : "=r"(d) + : "r"(a)); + return static_cast(d); +} + +typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t; + +typedef struct { + uint32_t add_data; + uint32_t field_boundary; + uint64_t reserved; +} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t; + +__device__ static __forceinline__ +nvshmemi_ibgda_device_state_t* ibgda_get_state() { + return &nvshmemi_ibgda_device_state_d; +} + +__device__ static __forceinline__ +nvshmemi_ibgda_device_qp_t* ibgda_get_rc(int pe, int id) { + auto state = ibgda_get_state(); + const auto num_rc_per_pe = ibgda_get_state()->num_rc_per_pe; + return &state->globalmem.rcs[pe * num_rc_per_pe + id % num_rc_per_pe]; +} + +__device__ static __forceinline__ +void ibgda_lock_acquire(int *lock) { + while (atomicCAS(lock, 0, 1) == 1); + + // Prevent reordering before the lock is acquired + memory_fence_cta(); +} + +__device__ static __forceinline__ +void ibgda_lock_release(int *lock) { + memory_fence_cta(); + + // Prevent reordering before lock is released + st_na_relaxed(lock, 0); +} + +__device__ static __forceinline__ +void ibgda_update_dbr(nvshmemi_ibgda_device_qp_t *qp, uint32_t dbrec_head) { + // `DBREC` contains the index of the next empty `WQEBB` + __be32 dbrec_val; + __be32 *dbrec_ptr = qp->tx_wq.dbrec; + + // This is equivalent to `WRITE_ONCE(dbrec_ptr, HtoBE32(dbrec_head & 0xffff))` + asm("{\n\t" + ".reg .b32 dbrec_head_16b;\n\t" + ".reg .b32 ign;\n\t" + "and.b32 dbrec_head_16b, %1, 0xffff;\n\t" + "prmt.b32 %0, dbrec_head_16b, ign, 0x123;\n\t" + "}" + : "=r"(dbrec_val) + : "r"(dbrec_head)); + st_na_release(dbrec_ptr, dbrec_val); +} + +__device__ static __forceinline__ +void ibgda_ring_db(nvshmemi_ibgda_device_qp_t *qp, uint16_t prod_idx) { + auto bf_ptr = reinterpret_cast(qp->tx_wq.bf); + ibgda_ctrl_seg_t ctrl_seg = { + .opmod_idx_opcode = HtoBE32(prod_idx << 8), + .qpn_ds = HtoBE32(qp->qpn << 8) + }; + + EP_STATIC_ASSERT(sizeof(decltype(&ctrl_seg)) == sizeof(uint64_t), ""); + st_na_release(bf_ptr, *(reinterpret_cast(&ctrl_seg))); +} + +__device__ static __forceinline__ +void ibgda_post_send(nvshmemi_ibgda_device_qp_t *qp, uint64_t new_prod_idx) { + nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars; + uint64_t old_prod_idx; + + // Update `prod_idx` before ringing the doorbell, so that we know which index is needed in quiet/fence + ibgda_lock_acquire(&mvars->post_send_lock); + + old_prod_idx = atomicMax(reinterpret_cast(&mvars->tx_wq.prod_idx), new_prod_idx); + if (new_prod_idx > old_prod_idx) { + ibgda_update_dbr(qp, new_prod_idx); + ibgda_ring_db(qp, new_prod_idx); + } + ibgda_lock_release(&mvars->post_send_lock); +} + +template +__device__ static __forceinline__ +void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx, + uint32_t num_wqes, int message_idx = 0) { + nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars; + uint64_t new_wqe_idx = base_wqe_idx + num_wqes; + + // WQE writes must be finished first + __threadfence(); + + // Wait for prior WQE slots to be filled first + auto *ready_idx = reinterpret_cast(&mvars->tx_wq.ready_head); + while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx); + + // Always post, not in batch + constexpr int kNumRequestInBatch = 4; + if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0) + ibgda_post_send(qp, new_wqe_idx); +} + +__device__ static __forceinline__ void +ibgda_write_rdma_write_inl_wqe(nvshmemi_ibgda_device_qp_t *qp, const uint32_t *val, uint64_t raddr, + __be32 rkey, uint16_t wqe_idx, void** out_wqes, uint32_t imm) { + ibgda_ctrl_seg_t ctrl_seg; + struct mlx5_wqe_raddr_seg raddr_seg; + struct mlx5_wqe_inl_data_seg inl_seg; + + auto *ctrl_seg_ptr = reinterpret_cast(out_wqes[0]); + auto *raddr_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); + auto *inl_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); + auto *wqe_data_ptr = reinterpret_cast(reinterpret_cast(inl_seg_ptr) + sizeof(*inl_seg_ptr)); + + raddr_seg.raddr = HtoBE64(raddr); + raddr_seg.rkey = rkey; + raddr_seg.reserved = 0; + + inl_seg.byte_count = HtoBE32(4 | MLX5_INLINE_SEG); + + // `imm == std::numeric_limits::max()` means no imm writes + ctrl_seg = {0}; + ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); + ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | (imm != std::numeric_limits::max() ? MLX5_OPCODE_RDMA_WRITE_IMM : MLX5_OPCODE_RDMA_WRITE)); + if (imm != std::numeric_limits::max()) + ctrl_seg.imm = HtoBE32(imm); + + EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); + EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); + EP_STATIC_ASSERT(sizeof(*inl_seg_ptr) == 4, "sizeof(*inl_seg_ptr) == 4"); + st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); + st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); + st_na_relaxed(reinterpret_cast(inl_seg_ptr), *reinterpret_cast(&inl_seg)); + st_na_relaxed(reinterpret_cast(wqe_data_ptr), *reinterpret_cast(val)); +} + +__device__ static __forceinline__ +uint64_t ibgda_get_lkey_and_rkey(uint64_t laddr, __be32 *lkey, + uint64_t raddr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + auto log2_cumem_granularity = state->log2_cumem_granularity; + + // Local key + uint64_t idx = (laddr - heap_start) >> log2_cumem_granularity; + auto device_key = state->constmem.lkeys[idx]; + auto lchunk_size = device_key.next_addr - laddr; + *lkey = device_key.key; + + // Remote key + uint64_t roffset = raddr - heap_start; + idx = ((roffset >> log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe; + if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) { + device_key = state->constmem.rkeys[idx]; + } else { + device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; + } + *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; + *out_rkey = device_key.key; + + // Return the minimum of local and remote chunk sizes + auto rchunk_size = device_key.next_addr - roffset; + return min(lchunk_size, rchunk_size); +} + +__device__ static __forceinline__ void +ibgda_get_rkey(uint64_t addr, int dst_pe, uint64_t *out_raddr, __be32 *out_rkey) { + auto state = ibgda_get_state(); + auto heap_start = reinterpret_cast(nvshmemi_device_state_d.heap_base); + + uint64_t roffset = addr - heap_start; + uint64_t idx = ((roffset >> state->log2_cumem_granularity) * nvshmemi_device_state_d.npes) + dst_pe; + nvshmemi_ibgda_device_key_t device_key; + if (idx < NVSHMEMI_IBGDA_MAX_CONST_RKEYS) + device_key = state->constmem.rkeys[idx]; + else + device_key = state->globalmem.rkeys[idx - NVSHMEMI_IBGDA_MAX_CONST_RKEYS]; + *out_raddr = reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_remote[dst_pe]) + roffset; + *out_rkey = device_key.key; +} + +__device__ static __forceinline__ uint64_t +ibgda_reserve_wqe_slots(nvshmemi_ibgda_device_qp_t *qp, uint32_t num_wqes) { + auto mvars = &qp->mvars; + return atomicAdd(reinterpret_cast(&mvars->tx_wq.resv_head), static_cast(num_wqes)); +} + +__device__ static __forceinline__ void* +ibgda_get_wqe_ptr(nvshmemi_ibgda_device_qp_t* qp, uint16_t wqe_idx) { + uint16_t cnt = qp->tx_wq.nwqes; + uint16_t idx = wqe_idx & (cnt - 1); + return reinterpret_cast(reinterpret_cast(qp->tx_wq.wqe) + (idx << MLX5_SEND_WQE_SHIFT)); +} + +__device__ static __forceinline__ void +nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t imm = std::numeric_limits::max()) { + // Get rkey + // NOTES: the `p` operation will not cross multiple remote chunks + __be32 rkey; + uint64_t raddr; + ibgda_get_rkey(reinterpret_cast(rptr), dst_pe, &raddr, &rkey); + + // Write WQEs + auto qp = ibgda_get_rc(dst_pe, qp_id); + uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + void *wqe_ptrs; + wqe_ptrs = ibgda_get_wqe_ptr(qp, base_wqe_idx); + ibgda_write_rdma_write_inl_wqe(qp, reinterpret_cast(&value), raddr, rkey, base_wqe_idx, &wqe_ptrs, imm); + + // Submit requests + ibgda_submit_requests(qp, base_wqe_idx, 1); +} + +__device__ static __forceinline__ void +ibgda_write_rdma_write_wqe(nvshmemi_ibgda_device_qp_t *qp, uint64_t laddr, __be32 lkey, + uint64_t raddr, __be32 rkey, uint32_t bytes, uint16_t wqe_idx, + void** out_wqes) { + ibgda_ctrl_seg_t ctrl_seg; + struct mlx5_wqe_raddr_seg raddr_seg; + struct mlx5_wqe_data_seg data_seg; + + auto *ctrl_seg_ptr = reinterpret_cast(out_wqes[0]); + void *av_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); + struct mlx5_wqe_raddr_seg *raddr_seg_ptr; + struct mlx5_wqe_data_seg *data_seg_ptr; + + raddr_seg_ptr = reinterpret_cast(reinterpret_cast(av_seg_ptr)); + data_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); + + raddr_seg.raddr = HtoBE64(raddr); + raddr_seg.rkey = rkey; + raddr_seg.reserved = 0; + + data_seg.byte_count = HtoBE32(bytes); + data_seg.lkey = lkey; + data_seg.addr = HtoBE64(laddr); + + ctrl_seg = {0}; + ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 3); + ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + ctrl_seg.opmod_idx_opcode = HtoBE32((wqe_idx << 8) | MLX5_OPCODE_RDMA_WRITE); + + EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == 16, "sizeof(*ctrl_seg_ptr) == 16"); + EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == 16, "sizeof(*raddr_seg_ptr) == 16"); + EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == 16, "sizeof(*data_seg_ptr) == 16"); + st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); + st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); + st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); +} + +__device__ static __forceinline__ void +ibgda_write_empty_recv_wqe(void *out_wqe) { + auto *data_seg_ptr = reinterpret_cast(out_wqe); + struct mlx5_wqe_data_seg data_seg; + + // Make the first segment in the WQE invalid, then the entire list will be invalid + data_seg.byte_count = 0; + data_seg.lkey = HtoBE64(MLX5_INVALID_LKEY); + data_seg.addr = 0; + + EP_STATIC_ASSERT(sizeof(mlx5_wqe_data_seg) == sizeof(int4), "Invalid data type length"); + st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); +} + +template +__device__ static __forceinline__ void +nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes, int dst_pe, int qp_id, int lane_id, int message_idx) { + // Get lkey and rkey, store them into lanes + uint32_t num_wqes = 0; + __be32 my_lkey = 0; + uint64_t my_laddr = 0; + __be32 my_rkey = 0; + uint64_t my_raddr = 0; + uint64_t my_chunk_size = 0; + + // Decide how many messages (theoretically 3 for maximum) + auto remaining_bytes = bytes; + while (remaining_bytes > 0) { + if (lane_id == num_wqes) + my_chunk_size = min(remaining_bytes, ibgda_get_lkey_and_rkey(my_laddr = req_lptr, &my_lkey, req_rptr, dst_pe, &my_raddr, &my_rkey)); + + // Move one more message + auto chunk_size = __shfl_sync(0xffffffff, my_chunk_size, static_cast(num_wqes)); + remaining_bytes -= chunk_size; + req_lptr += chunk_size; + req_rptr += chunk_size; + ++ num_wqes; + } + EP_DEVICE_ASSERT(num_wqes <= 32); + + // Process WQE + auto qp = ibgda_get_rc(dst_pe, qp_id); + uint64_t base_wqe_idx = 0; + if (lane_id == 0) + base_wqe_idx = ibgda_reserve_wqe_slots(qp, num_wqes); + base_wqe_idx = __shfl_sync(0xffffffff, base_wqe_idx, 0); + if (lane_id < num_wqes) { + auto wqe_ptr = ibgda_get_wqe_ptr(qp, base_wqe_idx + lane_id); + ibgda_write_rdma_write_wqe(qp, my_laddr, my_lkey, my_raddr, my_rkey, my_chunk_size, + base_wqe_idx, &wqe_ptr); + } + __syncwarp(); + + // Submit + if (lane_id == 0) + ibgda_submit_requests(qp, base_wqe_idx, num_wqes, message_idx); + __syncwarp(); +} + +__device__ static __forceinline__ void ibgda_write_amo_add_wqe( + nvshmemi_ibgda_device_qp_t *qp, const int &value, + uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey, + uint16_t wqe_idx, void** out_wqes) { + ibgda_ctrl_seg_t ctrl_seg = {0}; + struct mlx5_wqe_raddr_seg raddr_seg; + struct mlx5_wqe_atomic_seg atomic_seg_1; + struct mlx5_wqe_data_seg data_seg; + + auto ctrl_seg_ptr = reinterpret_cast(out_wqes[0]); + auto raddr_seg_ptr = reinterpret_cast(reinterpret_cast(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr)); + auto atomic_seg_ptr = reinterpret_cast(reinterpret_cast(raddr_seg_ptr) + sizeof(*raddr_seg_ptr)); + auto data_seg_ptr = reinterpret_cast(reinterpret_cast(atomic_seg_ptr) + sizeof(*atomic_seg_ptr)); + + raddr_seg.raddr = HtoBE64(raddr); + raddr_seg.rkey = rkey; + raddr_seg.reserved = 0; + + // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD` + ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000); + auto atomic_32_masked_fa_seg = reinterpret_cast(&atomic_seg_1); + atomic_32_masked_fa_seg->add_data = HtoBE32(value); + atomic_32_masked_fa_seg->field_boundary = 0; + + ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4); + ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + + data_seg.byte_count = HtoBE32(sizeof(int)); + data_seg.lkey = lkey; + data_seg.addr = HtoBE64(laddr); + + EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization"); + EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization"); + EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization"); + EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization"); + st_na_relaxed(reinterpret_cast(ctrl_seg_ptr), *reinterpret_cast(&ctrl_seg)); + st_na_relaxed(reinterpret_cast(raddr_seg_ptr), *reinterpret_cast(&raddr_seg)); + st_na_relaxed(reinterpret_cast(atomic_seg_ptr), *reinterpret_cast(&atomic_seg_1)); + st_na_relaxed(reinterpret_cast(data_seg_ptr), *reinterpret_cast(&data_seg)); +} + +__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id, bool is_local_copy = false) { + if (is_local_copy) { + atomicAdd(static_cast(rptr), value); + } else { + nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id); + + __be32 rkey; + uint64_t raddr; + ibgda_get_rkey(reinterpret_cast(rptr), pe, &raddr, &rkey); + + uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1); + void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx); + + ibgda_write_amo_add_wqe(qp, value, reinterpret_cast(qp->ibuf.buf), + qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs); + + ibgda_submit_requests(qp, my_wqe_idx, 1); + } +} + +__device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, const int& rank, const int& dst_rank) { + // Local rank, no need for mapping + if (rank == dst_rank) + return ptr; + auto peer_base = __ldg(reinterpret_cast(nvshmemi_device_state_d.peer_heap_base_p2p) + dst_rank); + + // RDMA connected + if (peer_base == 0) + return 0; + + // NVLink P2P is enabled + return peer_base + (ptr - reinterpret_cast(nvshmemi_device_state_d.heap_base)); +} + +// This is a simplified version of NVSHMEM's `ibgda_poll_cq`. +// Note that this implementation does not guarantee thread safety, +// so we must ensure that no other threads are concurrently using the same QP. +__device__ static __forceinline__ void +ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) { + const auto cqe64 = static_cast(cq->cqe); + const uint32_t ncqes = cq->ncqes; + memory_fence_cta(); + + // NOTES: this while loop is part of do-while below. + // `wqe_counter` is the HW consumer index. However, we always maintain `index + 1`. + // To be able to compare with the index, we need to use `wqe_counter + 1`. + // Because `wqe_counter` is `uint16_t`, it may be overflow. Still, we know for + // sure that if `idx - wqe_counter - 1 < ncqes`, `wqe_counter + 1 is less than + // idx, and thus we need to wait. We don't need to wait when `idx == wqe_counter + 1` + // That's why we use `- 2` here to make this case overflow. + uint16_t wqe_counter; + do { + wqe_counter = HtoBE16(ld_na_relaxed(&cqe64->wqe_counter)); + } while ((static_cast(static_cast(idx) - wqe_counter - static_cast(2)) < ncqes)); + *cq->cons_idx = idx; + + // Prevent reordering of this function and later instructions + memory_fence_cta(); +} + +// Wait until wqe `idx - 1` is completed. +__device__ static __forceinline__ void +nvshmemi_ibgda_quiet(int dst_pe, int qp_id) { + auto qp = ibgda_get_rc(dst_pe, qp_id); + uint64_t prod_idx = ld_na_relaxed(qp->tx_wq.prod_idx); + ibgda_poll_cq(qp->tx_wq.cq, prod_idx); +} + +} // namespace deep_ep diff --git a/DeepEP/csrc/kernels/internode.cu b/DeepEP/csrc/kernels/internode.cu new file mode 100644 index 000000000..4dacaa2fb --- /dev/null +++ b/DeepEP/csrc/kernels/internode.cu @@ -0,0 +1,1673 @@ +#include "configs.cuh" +#include "buffer.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "utils.cuh" +#include "ibgda_device.cuh" + +namespace deep_ep { + +namespace internode { + +extern nvshmem_team_t cpu_rdma_team; + +struct SourceMeta { + int src_rdma_rank, is_token_in_nvl_rank_bits; + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS == 8, "Invalid number of maximum NVL peers"); + + __forceinline__ SourceMeta() = default; + + // TODO: faster encoding + __device__ __forceinline__ SourceMeta(int rdma_rank, const bool* is_token_in_nvl_ranks) { + src_rdma_rank = rdma_rank; + is_token_in_nvl_rank_bits = is_token_in_nvl_ranks[0]; + #pragma unroll + for (int i = 1; i < NUM_MAX_NVL_PEERS; ++ i) + is_token_in_nvl_rank_bits |= is_token_in_nvl_ranks[i] << i; + } + + __device__ __forceinline__ bool is_token_in_nvl_rank(int nvl_rank) const { + return (is_token_in_nvl_rank_bits >> nvl_rank) & 1; + } +}; + +EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + +int get_source_meta_bytes() { + return sizeof(SourceMeta); +} + +__host__ __device__ __forceinline__ +int get_num_bytes_per_rdma_token(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights) { + return static_cast(align(hidden_int4 * sizeof(int4) + sizeof(SourceMeta) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float), sizeof(int4))); +} + +__host__ __device__ __forceinline__ +std::pair get_rdma_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_rdma_recv_buffer_tokens, int num_sms) { + // Return `int32_t` offset and count to clean + return { + (get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_rdma_recv_buffer_tokens * num_rdma_ranks * 2 * num_sms) / sizeof(int), + (NUM_MAX_NVL_PEERS * 2 + 4) * num_rdma_ranks * 2 * num_sms + }; +} + +__host__ __device__ __forceinline__ +std::pair get_nvl_clean_meta(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, int num_rdma_ranks, int num_nvl_ranks, int num_nvl_recv_buffer_tokens, int num_sms) { + // Return `int32_t` offset and to clean + EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`"); + return { + (num_nvl_recv_buffer_tokens * (hidden_int4 * sizeof(int4) + num_scales * sizeof(float) + num_topk_idx * sizeof(int) + num_topk_weights * sizeof(float) + sizeof(SourceMeta)) * num_nvl_ranks * num_sms) / sizeof(int), + num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_sms, + }; +} + +template +__forceinline__ __device__ int translate_dst_rdma_rank(const int dst_rdma_rank, const int nvl_rank) { + return kLowLatencyMode ? (dst_rdma_rank * NUM_MAX_NVL_PEERS + nvl_rank) : dst_rdma_rank; +} + +template +__forceinline__ __device__ void nvshmem_sync_with_same_gpu_idx(const nvshmem_team_t& rdma_team) { + kLowLatencyMode ? void(nvshmem_sync(rdma_team)) : nvshmem_sync_all(); +} + +template +__global__ void +notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + const bool* is_token_in_rank, int num_tokens, int num_channels, int expert_alignment, + const int rdma_clean_offset, const int rdma_num_int_clean, + const int nvl_clean_offset, const int nvl_num_int_clean, + int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, + const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + + auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_experts = num_experts / kNumRDMARanks, num_nvl_experts = num_rdma_experts / NUM_MAX_NVL_PEERS; + + if (sm_id == 0) { + // Communication with others + // Global barrier: the first warp does intra-node sync, the second warp does internode sync + EP_DEVICE_ASSERT(num_warps > 1); + EP_DEVICE_ASSERT(kNumRDMARanks <= num_threads); + if (thread_id == 32) + nvshmem_sync_with_same_gpu_idx(rdma_team); + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Send numbers of tokens per rank/expert to RDMA ranks + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + auto rdma_recv_num_tokens_mixed = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, kNumRDMARanks); + + // Clean up for later data dispatch + EP_DEVICE_ASSERT(rdma_recv_num_tokens_mixed.total_bytes <= rdma_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + + // Copy to send buffer + #pragma unroll + for (int i = thread_id; i < num_ranks; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / NUM_MAX_NVL_PEERS)[i % NUM_MAX_NVL_PEERS] = num_tokens_per_rank[i]; + #pragma unroll + for (int i = thread_id; i < num_experts; i += num_threads) + rdma_recv_num_tokens_mixed.send_buffer(i / num_rdma_experts)[NUM_MAX_NVL_PEERS + i % num_rdma_experts] = num_tokens_per_expert[i]; + if (thread_id < kNumRDMARanks) + rdma_recv_num_tokens_mixed.send_buffer(thread_id)[NUM_MAX_NVL_PEERS + num_rdma_experts] = num_tokens_per_rdma_rank[thread_id]; + __syncthreads(); + + // Issue send + // TODO: more light fence or barrier or signaling + // TODO: overlap EP barrier and NVL cleaning + for (int i = 0; i < kNumRDMARanks; ++i) { + if (i != rdma_rank) { + if (warp_id == 0) { + nvshmemi_ibgda_put_nbi_warp(reinterpret_cast(rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank)), + reinterpret_cast(rdma_recv_num_tokens_mixed.send_buffer(i)), + (NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int), + translate_dst_rdma_rank(i, nvl_rank), 0, lane_id, 0); + } + } else { + UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1, + rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank), + rdma_recv_num_tokens_mixed.send_buffer(i), + ld_volatile_global, st_na_global); + } + } + if (thread_id < kNumRDMARanks and thread_id != rdma_rank) + nvshmemi_ibgda_quiet(translate_dst_rdma_rank(thread_id, nvl_rank), 0); + + __syncthreads(); + if (thread_id == 0) + nvshmem_sync_with_same_gpu_idx(rdma_team); + __syncthreads(); + + // NVL buffers + auto nvl_send_buffer = thread_id < NUM_MAX_NVL_PEERS ? buffer_ptrs[thread_id] : nullptr; + auto nvl_recv_buffer = buffer_ptrs[nvl_rank]; + auto nvl_reduced_num_tokens_per_expert = Buffer(nvl_recv_buffer, num_rdma_experts).advance_also(nvl_send_buffer); + auto nvl_send_num_tokens_per_rank = AsymBuffer(nvl_send_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_send_num_tokens_per_expert = AsymBuffer(nvl_send_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_rank = AsymBuffer(nvl_recv_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS); + auto nvl_recv_num_tokens_per_expert = AsymBuffer(nvl_recv_buffer, num_nvl_experts, NUM_MAX_NVL_PEERS); + + // Clean up for later data dispatch + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + EP_DEVICE_ASSERT(nvl_reduced_num_tokens_per_expert.total_bytes + nvl_send_num_tokens_per_rank.total_bytes + + nvl_send_num_tokens_per_expert.total_bytes <= nvl_clean_offset * sizeof(int)); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Reduce number of tokens per expert into the NVL send buffer + // TODO: may use NVSHMEM reduction + EP_DEVICE_ASSERT(num_rdma_experts <= num_threads); + if (thread_id < num_rdma_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + thread_id]; + nvl_reduced_num_tokens_per_expert[thread_id] = sum; + } + __syncthreads(); + + // Reduce RDMA received tokens + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) { + sum += rdma_recv_num_tokens_mixed.recv_buffer(i)[NUM_MAX_NVL_PEERS + num_rdma_experts]; + recv_rdma_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_rdma_counter_mapped) != -1); + *moe_recv_rdma_counter_mapped = sum; + } + + // Send numbers of tokens per rank/expert to NVL ranks + EP_DEVICE_ASSERT(NUM_MAX_NVL_PEERS <= num_threads); + if (thread_id < NUM_MAX_NVL_PEERS) { + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) + nvl_send_num_tokens_per_rank.buffer(nvl_rank)[i] = rdma_recv_num_tokens_mixed.recv_buffer(i)[thread_id]; + #pragma unroll + for (int i = 0; i < num_nvl_experts; ++ i) + nvl_send_num_tokens_per_expert.buffer(nvl_rank)[i] = nvl_reduced_num_tokens_per_expert[thread_id * num_nvl_experts + i]; + } + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Reduce the number of tokens per rank/expert + EP_DEVICE_ASSERT(num_nvl_experts <= num_threads); + if (thread_id == 0) { + int sum = 0; + #pragma unroll + for (int i = 0; i < num_ranks; ++ i) { + int src_rdma_rank = i / NUM_MAX_NVL_PEERS, src_nvl_rank = i % NUM_MAX_NVL_PEERS; + sum += nvl_recv_num_tokens_per_rank.buffer(src_nvl_rank)[src_rdma_rank]; + recv_gbl_rank_prefix_sum[i] = sum; + } + while (ld_volatile_global(moe_recv_counter_mapped) != -1); + *moe_recv_counter_mapped = sum; + } + if (thread_id < num_nvl_experts) { + int sum = 0; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + sum += nvl_recv_num_tokens_per_expert.buffer(i)[thread_id]; + sum = (sum + expert_alignment - 1) / expert_alignment * expert_alignment; + while (ld_volatile_global(moe_recv_expert_counter_mapped + thread_id) != -1); + moe_recv_expert_counter_mapped[thread_id] = sum; + } + + // Finally barrier + if (thread_id == 32) + nvshmem_sync_with_same_gpu_idx(rdma_team); + barrier_block(barrier_signal_ptrs, nvl_rank); + } else { + // Calculate meta data + int dst_rdma_rank = sm_id - 1; + for (int channel_id = warp_id; channel_id < num_channels; channel_id += num_warps) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over tokens + int total_count = 0, per_nvl_rank_count[NUM_MAX_NVL_PEERS] = {0}; + for (int64_t i = token_start_idx + lane_id; i < token_end_idx; i += 32) { + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto is_token_in_rank_uint64 = *reinterpret_cast(is_token_in_rank + i * num_ranks + dst_rdma_rank * NUM_MAX_NVL_PEERS); + auto is_token_in_rank_values = reinterpret_cast(&is_token_in_rank_uint64); + #pragma unroll + for (int j = 0; j < NUM_MAX_NVL_PEERS; ++ j) + per_nvl_rank_count[j] += is_token_in_rank_values[j]; + total_count += (is_token_in_rank_uint64 != 0); + } + + // Warp reduce + total_count = warp_reduce_sum(total_count); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + per_nvl_rank_count[i] = warp_reduce_sum(per_nvl_rank_count[i]); + + // Write into channel matrix + if (lane_id == 0) { + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + i) * num_channels + channel_id] = per_nvl_rank_count[i]; + rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] = total_count; + } + } + + // Calculate prefix sum + __syncthreads(); + if (thread_id == 0) { + auto prefix_row = rdma_channel_prefix_matrix + dst_rdma_rank * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++ i) + prefix_row[i] += prefix_row[i - 1]; + } + + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + if (thread_id < NUM_MAX_NVL_PEERS) { + auto prefix_row = gbl_channel_prefix_matrix + (dst_rdma_rank * NUM_MAX_NVL_PEERS + thread_id) * num_channels; + #pragma unroll + for (int i = 1; i < num_channels; ++ i) + prefix_row[i] += prefix_row[i - 1]; + } + } +} + +void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, + const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, + const bool* is_token_in_rank, int num_tokens, int num_channels, + int hidden_int4, int num_scales, int num_topk, int expert_alignment, + int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, int rank, + cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool low_latency_mode) { +#define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ + auto notify_dispatch_func = low_latency_mode ? \ + notify_dispatch : notify_dispatch; \ + LAUNCH_KERNEL(&cfg, notify_dispatch_func, \ + num_tokens_per_rank, moe_recv_counter_mapped, num_ranks, \ + num_tokens_per_rdma_rank, moe_recv_rdma_counter_mapped, \ + num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ + is_token_in_rank, num_tokens, num_channels, expert_alignment, \ + rdma_clean_meta.first, rdma_clean_meta.second, \ + nvl_clean_meta.first, nvl_clean_meta.second, \ + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ + rdma_buffer_ptr, \ + buffer_ptrs, barrier_signal_ptrs, rank, \ + cpu_rdma_team); } break + + constexpr int kNumThreads = 512; + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk, num_topk, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + + // Launch kernel + SETUP_LAUNCH_CONFIG(1 + num_rdma_ranks, kNumThreads, stream); + SWITCH_RDMA_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); +#undef NOTIFY_DISPATCH_LAUNCH_CASE +} + +// At most 8 RDMA ranks to be sent +constexpr int get_num_topk_rdma_ranks(int num_rdma_ranks) { + return num_rdma_ranks < 8 ? num_rdma_ranks : 8; +} + +template +__global__ void __launch_bounds__(((kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32), 1) +dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, SourceMeta* recv_src_meta, + const int4* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks) { + enum class WarpRole { + kRDMASender, + kRDMASenderCoordinator, + kRDMAAndNVLForwarder, + kForwarderCoordinator, + kNVLReceivers + }; + + const auto num_sms = static_cast(gridDim.x); + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_channels = num_sms / 2, channel_id = sm_id / 2; + const bool is_forwarder = sm_id % 2 == 0; + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe == num_channels or ibgda_get_state()->num_rc_per_pe >= num_sms); + + const auto role_meta = [=]() -> std::pair { + if (is_forwarder) { + if (warp_id < NUM_MAX_NVL_PEERS) { + return {WarpRole::kRDMAAndNVLForwarder, (warp_id + channel_id) % NUM_MAX_NVL_PEERS}; + } else { + return {WarpRole::kForwarderCoordinator, warp_id - NUM_MAX_NVL_PEERS}; + } + } else if (warp_id < kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASender, -1}; + } else if (warp_id == kNumDispatchRDMASenderWarps) { + return {WarpRole::kRDMASenderCoordinator, -1}; + } else { + return {WarpRole::kNVLReceivers, (warp_id + channel_id - kNumDispatchRDMASenderWarps) % NUM_MAX_NVL_PEERS}; + } + }(); + auto warp_role = role_meta.first; + auto target_rank = role_meta.second; // Not applicable for RDMA senders + EP_DEVICE_ASSERT(num_warps == kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS); + + // Data checks + EP_DEVICE_ASSERT(num_topk <= 32); + + // RDMA symmetric layout + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * sizeof(bool) == sizeof(uint64_t), "Invalid number of NVL peers"); + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, num_scales, num_topk, num_topk); + auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_meta = SymBuffer(rdma_buffer_ptr, NUM_MAX_NVL_PEERS * 2 + 2, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + // NVL buffer layouts + // NOTES: `rs_wr_buffer_ptr` means "Read for Senders, Write for Receivers", `ws_rr_buffer_ptr` means "Write for Senders, Read for Receivers" + void *rs_wr_buffer_ptr = nullptr, *ws_rr_buffer_ptr = nullptr; + int rs_wr_rank = 0, ws_rr_rank = 0; + if (warp_role == WarpRole::kRDMAAndNVLForwarder) + rs_wr_buffer_ptr = buffer_ptrs[nvl_rank], ws_rr_buffer_ptr = buffer_ptrs[target_rank], rs_wr_rank = nvl_rank, ws_rr_rank = target_rank; + if (warp_role == WarpRole::kNVLReceivers) + rs_wr_buffer_ptr = buffer_ptrs[target_rank], ws_rr_buffer_ptr = buffer_ptrs[nvl_rank], rs_wr_rank = target_rank, ws_rr_rank = nvl_rank; + + // Allocate buffers + auto nvl_channel_x = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_x_scales = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_scales, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_idx = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(ws_rr_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_start = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_prefix_end = AsymBuffer(ws_rr_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + auto nvl_channel_head = AsymBuffer(rs_wr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, ws_rr_rank).advance_also(ws_rr_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(ws_rr_buffer_ptr, 1, NUM_MAX_NVL_PEERS, channel_id, num_channels, rs_wr_rank).advance_also(rs_wr_buffer_ptr); + + // RDMA sender warp synchronization + // NOTES: `rdma_send_channel_tail` means the latest released tail + // NOTES: `rdma_send_channel_window` means the ongoing 32 transactions' status + __shared__ int rdma_send_channel_lock[kNumRDMARanks]; + __shared__ int rdma_send_channel_tail[kNumRDMARanks]; + __shared__ uint32_t rdma_send_channel_window[kNumRDMARanks]; + auto sync_rdma_sender_smem = []() { asm volatile("bar.sync 0, %0;" :: "r"((kNumDispatchRDMASenderWarps + 1) * 32)); }; + + // Forward warp synchronization + __shared__ volatile int forward_channel_head[NUM_MAX_NVL_PEERS][kNumRDMARanks]; + __shared__ volatile bool forward_channel_retired[NUM_MAX_NVL_PEERS]; + auto sync_forwarder_smem = []() { asm volatile("bar.sync 1, %0;" :: "r"((NUM_MAX_NVL_PEERS + 1) * 32)); }; + + if (warp_role == WarpRole::kRDMASender) { + // Get tasks + int token_start_idx, token_end_idx; + get_channel_task_range(num_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Send number of tokens in this channel by `-value - 1` + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS * 2 + 2 <= 32, "Invalid number of NVL peers"); + for (int dst_rdma_rank = warp_id; dst_rdma_rank < kNumRDMARanks; dst_rdma_rank += kNumDispatchRDMASenderWarps) { + auto dst_ptr = dst_rdma_rank == rdma_rank ? rdma_channel_meta.recv_buffer(dst_rdma_rank) : rdma_channel_meta.send_buffer(dst_rdma_rank); + if (lane_id < NUM_MAX_NVL_PEERS) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id) * num_channels + channel_id - 1]) - 1; + } else if (lane_id < NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -gbl_channel_prefix_matrix[(dst_rdma_rank * NUM_MAX_NVL_PEERS + lane_id - NUM_MAX_NVL_PEERS) * num_channels + channel_id] - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2) { + dst_ptr[lane_id] = -(channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]) - 1; + } else if (lane_id == NUM_MAX_NVL_PEERS * 2 + 1) { + dst_ptr[lane_id] = -rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id] - 1; + } + __syncwarp(); + + // Issue RDMA for non-local ranks + if (dst_rdma_rank != rdma_rank) { + nvshmemi_ibgda_put_nbi_warp(reinterpret_cast(rdma_channel_meta.recv_buffer(rdma_rank)), + reinterpret_cast(rdma_channel_meta.send_buffer(dst_rdma_rank)), + sizeof(int) * (NUM_MAX_NVL_PEERS * 2 + 2), + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), + channel_id, lane_id, 0); + } + } + sync_rdma_sender_smem(); + + // Iterate over tokens and copy into buffer + int64_t token_idx; + int cached_rdma_channel_head = 0, global_rdma_tail_idx = 0; + auto send_buffer = lane_id == rdma_rank ? rdma_channel_data.recv_buffer(lane_id) : rdma_channel_data.send_buffer(lane_id); + for (token_idx = token_start_idx; token_idx < token_end_idx; ++ token_idx) { + // Read RDMA rank existence + uint64_t is_token_in_rank_uint64 = 0; + if (lane_id < kNumRDMARanks) { + is_token_in_rank_uint64 = __ldg(reinterpret_cast(is_token_in_rank + token_idx * num_ranks + lane_id * NUM_MAX_NVL_PEERS)); + global_rdma_tail_idx += (is_token_in_rank_uint64 != 0); + } + __syncwarp(); + + // Skip the token which does not belong to this warp + if ((token_idx - token_start_idx) % kNumDispatchRDMASenderWarps != warp_id) + continue; + auto rdma_tail_idx = is_token_in_rank_uint64 == 0 ? -1 : global_rdma_tail_idx - 1; + + // Wait the remote buffer to be released + auto start_time = clock64(); + while (is_token_in_rank_uint64 != 0 and rdma_tail_idx - cached_rdma_channel_head >= num_max_rdma_chunked_recv_tokens) { + cached_rdma_channel_head = static_cast(ld_volatile_global(rdma_channel_head.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time >= NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch RDMA sender timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA lane: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_rdma_channel_head, rdma_tail_idx); + trap(); + } + } + __syncwarp(); + + // Store RDMA head for combine + if (lane_id < kNumRDMARanks and not kCachedMode) + send_rdma_head[token_idx * kNumRDMARanks + lane_id] = rdma_tail_idx; + + // Broadcast tails + SourceMeta src_meta; + int num_topk_ranks = 0, topk_ranks[kNumTopkRDMARanks]; + void* dst_send_buffers[kNumTopkRDMARanks]; + #pragma unroll + for (int i = 0, slot_idx; i < kNumRDMARanks; ++ i) if ((slot_idx = __shfl_sync(0xffffffff, rdma_tail_idx, i)) >= 0) { + slot_idx = slot_idx % num_max_rdma_chunked_recv_tokens; + topk_ranks[num_topk_ranks] = i; + auto recv_is_token_in_rank_uint64 = broadcast(is_token_in_rank_uint64, i); + auto recv_is_token_in_rank_values = reinterpret_cast(&recv_is_token_in_rank_uint64); + if (lane_id == num_topk_ranks) + src_meta = SourceMeta(rdma_rank, recv_is_token_in_rank_values); + dst_send_buffers[num_topk_ranks ++] = reinterpret_cast(broadcast(send_buffer, i)) + slot_idx * num_bytes_per_rdma_token; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kNumTopkRDMARanks); + + // Copy `x` into symmetric send buffer + auto st_broadcast = [=](const int key, const int4& value) { + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + key, value); + }; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, 0, x + token_idx * hidden_int4, ld_nc_global, st_broadcast); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + hidden_int4; + + // Copy source metadata into symmetric send buffer + if (lane_id < num_topk_ranks) + st_na_global(reinterpret_cast(dst_send_buffers[lane_id]), src_meta); + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + 1; + + // Copy `x_scales` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_scales; i += 32) { + auto offset = token_idx * scale_token_stride + i * scale_hidden_stride; + auto value = ld_nc_global(x_scales + offset); + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) + st_na_global(reinterpret_cast(dst_send_buffers[j]) + i, value); + } + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + dst_send_buffers[i] = reinterpret_cast(dst_send_buffers[i]) + num_scales; + + // Copy `topk_idx` and `topk_weights` into symmetric send buffer + #pragma unroll + for (int i = lane_id; i < num_topk * num_topk_ranks; i += 32) { + auto rank_idx = i / num_topk, copy_idx = i % num_topk; + auto idx_value = static_cast(ld_nc_global(topk_idx + token_idx * num_topk + copy_idx)); + auto weight_value = ld_nc_global(topk_weights + token_idx * num_topk + copy_idx); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + copy_idx, idx_value); + st_na_global(reinterpret_cast(dst_send_buffers[rank_idx]) + num_topk + copy_idx, weight_value); + } + __syncwarp(); + + // Release the transaction in the window + if (is_token_in_rank_uint64 != 0) { + // Acquire lock first + acquire_lock(rdma_send_channel_lock + lane_id); + auto latest_tail = rdma_send_channel_tail[lane_id]; + auto offset = rdma_tail_idx - latest_tail; + while (offset >= 32) { + release_lock(rdma_send_channel_lock + lane_id); + acquire_lock(rdma_send_channel_lock + lane_id); + latest_tail = rdma_send_channel_tail[lane_id]; + offset = rdma_tail_idx - latest_tail; + } + + // Release the transaction slot + // Add the bit and move the ones if possible + auto window = rdma_send_channel_window[lane_id] | (1u << offset); + if (offset == 0) { + auto num_empty_slots = (~window) == 0 ? 32 : __ffs(~window) - 1; + st_release_cta(rdma_send_channel_tail + lane_id, latest_tail + num_empty_slots); + window >>= num_empty_slots; + } + rdma_send_channel_window[lane_id] = window; + + // Release lock + release_lock(rdma_send_channel_lock + lane_id); + } + __syncwarp(); + } + } else if (warp_role == WarpRole::kRDMASenderCoordinator) { + // NOTES: in case of splitting, the issued put at the end of the buffer + EP_DEVICE_ASSERT(num_max_rdma_chunked_recv_tokens % num_max_rdma_chunked_send_tokens == 0); + + // Clean shared memory + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA ranks"); + (lane_id < kNumRDMARanks) ? (rdma_send_channel_lock[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_tail[lane_id] = 0) : 0; + (lane_id < kNumRDMARanks) ? (rdma_send_channel_window[lane_id] = 0) : 0; + + // Synchronize shared memory + sync_rdma_sender_smem(); + + // Get number of tokens to send for each RDMA rank + int num_tokens_to_send = 0; + if (lane_id < kNumRDMARanks) { + num_tokens_to_send = rdma_channel_prefix_matrix[lane_id * num_channels + channel_id]; + if (channel_id > 0) + num_tokens_to_send -= rdma_channel_prefix_matrix[lane_id * num_channels + channel_id - 1]; + } + + // Iterate all RDMA ranks + int last_issued_tail = 0; + auto start_time = clock64(); + while (__any_sync(0xffffffff, num_tokens_to_send > 0)) { + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP RDMA sender coordinator timeout, channel: %d, IB: %d, nvl %d, dst IB: %d, tail: %d, remaining: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, last_issued_tail, num_tokens_to_send); + trap(); + } + + // TODO: try thread-level `put_nbi`? + for (int i = 0, synced_num_tokens_to_send; i < kNumRDMARanks; ++ i) { + // To mitigate incast congestion, shuffle the starting index of target rank for different ranks and channels + int dst_rdma_rank = (i + channel_id + rdma_rank) % kNumRDMARanks; + synced_num_tokens_to_send = __shfl_sync(0xffffffff, num_tokens_to_send, dst_rdma_rank); + if (synced_num_tokens_to_send == 0) + continue; + + // Read the latest progress + // NOTES: `rdma_send_channel_tail` does not need to be protected by lock + auto processed_tail = __shfl_sync(0xffffffff, ld_acquire_cta(const_cast(rdma_send_channel_tail + dst_rdma_rank)), 0); + auto synced_last_issued_tail = __shfl_sync(0xffffffff, last_issued_tail, dst_rdma_rank); + auto num_tokens_processed = processed_tail - synced_last_issued_tail; + if (num_tokens_processed != synced_num_tokens_to_send and num_tokens_processed < num_max_rdma_chunked_send_tokens) + continue; + + // Issue RDMA send + auto num_tokens_to_issue = min(num_tokens_processed, num_max_rdma_chunked_send_tokens); + EP_DEVICE_ASSERT(num_tokens_to_issue >= 0 and num_tokens_to_issue <= synced_num_tokens_to_send); + if (dst_rdma_rank != rdma_rank) { + auto dst_slot_idx = synced_last_issued_tail % num_max_rdma_chunked_recv_tokens; + EP_DEVICE_ASSERT(dst_slot_idx + num_tokens_to_issue <= num_max_rdma_chunked_recv_tokens); + const size_t num_bytes_per_msg = num_bytes_per_rdma_token * num_tokens_to_issue; + const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); + const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + dst_slot_idx * num_bytes_per_rdma_token); + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0); + } else { + // Lighter fence for local RDMA rank + memory_fence(); + } + __syncwarp(); + + // Update tails + if (lane_id == dst_rdma_rank) { + last_issued_tail += num_tokens_to_issue; + num_tokens_to_send -= num_tokens_to_issue; + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_tokens_to_issue, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); + } + __syncwarp(); + } + } + } else if (warp_role == WarpRole::kRDMAAndNVLForwarder) { + // RDMA consumers and NVL producers + const auto dst_nvl_rank = target_rank; + const auto dst_rank = rdma_rank * NUM_MAX_NVL_PEERS + dst_nvl_rank; + const auto dst_rank_expert_begin = dst_rank * (num_experts / num_ranks); + const auto dst_rank_expert_end = dst_rank_expert_begin + (num_experts / num_ranks); + + // Wait counters to arrive + int num_tokens_to_recv_from_rdma = 0, src_rdma_channel_prefix = 0; + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + auto start_time = clock64(); + if (lane_id < kNumRDMARanks) { + while (true) { + auto meta_0 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + dst_nvl_rank); + auto meta_1 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS + dst_nvl_rank); + auto meta_2 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2); + auto meta_3 = ld_volatile_global(rdma_channel_meta.recv_buffer(lane_id) + NUM_MAX_NVL_PEERS * 2 + 1); + if (meta_0 < 0 and meta_1 < 0 and meta_2 < 0 and meta_3 < 0) { + // Notify NVL ranks + int start_sum = -meta_0 - 1, end_sum = -meta_1 - 1; + EP_DEVICE_ASSERT(start_sum >= 0 and end_sum >= 0 and end_sum >= start_sum); + st_relaxed_sys_global(nvl_channel_prefix_start.buffer() + lane_id, -start_sum - 1); + st_relaxed_sys_global(nvl_channel_prefix_end.buffer() + lane_id, -end_sum - 1); + + // Save RDMA channel received token count + src_rdma_channel_prefix = -meta_2 - 1; + auto src_rdma_channel_prefix_1 = -meta_3 - 1; + num_tokens_to_recv_from_rdma = src_rdma_channel_prefix_1 - src_rdma_channel_prefix; + if (not kCachedMode) + recv_rdma_channel_prefix_matrix[lane_id * num_channels + channel_id] = src_rdma_channel_prefix_1; + src_rdma_channel_prefix += lane_id == 0 ? 0 : recv_rdma_rank_prefix_sum[lane_id - 1]; + EP_DEVICE_ASSERT(num_tokens_to_recv_from_rdma >= 0); + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch forwarder timeout (RDMA meta), channel: %d, RDMA: %d, nvl: %d, src RDMA lane: %d, dst NVL: %d, meta: %d, %d, %d, %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_nvl_rank, meta_0, meta_1, meta_2, meta_3); + trap(); + } + } + } + __syncwarp(); + + // Shift cached head + send_nvl_head += src_rdma_channel_prefix * NUM_MAX_NVL_PEERS + dst_nvl_rank; + + // Wait shared memory to be cleaned + sync_forwarder_smem(); + + // Forward tokens from RDMA buffer + // NOTES: always start from the local rank + int src_rdma_rank = sm_id % kNumRDMARanks; + int cached_rdma_channel_head = 0, cached_rdma_channel_tail = 0; + int cached_nvl_channel_head = 0, cached_nvl_channel_tail = 0, rdma_nvl_token_idx = 0; + while (__any_sync(0xffffffff, num_tokens_to_recv_from_rdma > 0)) { + // Check destination queue emptiness, or wait a buffer to be released + start_time = clock64(); + while (lane_id == 0) { + int num_used_slots = cached_nvl_channel_tail - cached_nvl_channel_head; + if (num_max_nvl_chunked_recv_tokens - num_used_slots >= num_max_nvl_chunked_send_tokens) + break; + cached_nvl_channel_head = ld_volatile_global(nvl_channel_head.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch forwarder timeout (NVL check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, ld_volatile_global(nvl_channel_head.buffer()), cached_nvl_channel_tail); + trap(); + } + } + __syncwarp(); + + // Find next source RDMA rank (round-robin) + start_time = clock64(); + while (true) { + src_rdma_rank = (src_rdma_rank + 1) % kNumRDMARanks; + if (__shfl_sync(0xffffffff, num_tokens_to_recv_from_rdma, src_rdma_rank) > 0) { + if (lane_id == src_rdma_rank and cached_rdma_channel_head == cached_rdma_channel_tail) + cached_rdma_channel_tail = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(src_rdma_rank))); + if (__shfl_sync(0xffffffff, cached_rdma_channel_tail > cached_rdma_channel_head, src_rdma_rank)) + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP dispatch forwarder timeout (RDMA check), channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, src RDMA lane: %d, head: %d, tail: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, cached_rdma_channel_head, cached_rdma_channel_tail, num_tokens_to_recv_from_rdma); + trap(); + } + } + auto src_rdma_head = __shfl_sync(0xffffffff, cached_rdma_channel_head, src_rdma_rank); + auto src_rdma_tail = __shfl_sync(0xffffffff, cached_rdma_channel_tail, src_rdma_rank); + + // Iterate over every token from the RDMA buffer + for (int i = src_rdma_head, num_tokens_sent = 0; i < src_rdma_tail; ++ i) { + auto rdma_slot_idx = i % num_max_rdma_chunked_recv_tokens; + void* shifted = rdma_channel_data.recv_buffer(src_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token; + auto src_meta = ld_nc_global(reinterpret_cast(static_cast(shifted) + hidden_bytes)); + lane_id == src_rdma_rank ? (num_tokens_to_recv_from_rdma -= 1) : 0; + bool is_in_dst_nvl_rank = src_meta.is_token_in_nvl_rank(dst_nvl_rank); + if (lane_id == src_rdma_rank) { + auto cached_head = is_in_dst_nvl_rank ? rdma_nvl_token_idx : -1; + rdma_nvl_token_idx += is_in_dst_nvl_rank; + if (not kCachedMode) + send_nvl_head[i * NUM_MAX_NVL_PEERS] = cached_head; + } + if (not is_in_dst_nvl_rank) + continue; + + // Get an empty slot + int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, + nvl_channel_x.buffer() + dst_slot_idx * hidden_int4, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted = static_cast(shifted) + hidden_int4; + + // Copy source meta + if (lane_id == 0) + st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, src_meta); + shifted = static_cast(shifted) + 1; + + // Copy `x_scales` + UNROLLED_WARP_COPY(1, lane_id, num_scales, + nvl_channel_x_scales.buffer() + dst_slot_idx * num_scales, + reinterpret_cast(shifted), + ld_nc_global, st_na_global); + shifted = static_cast(shifted) + num_scales; + + // Copy `topk_idx` and `topk_weights` + // NOTES: do not use `shifted` after this `if`, because only several lanes are shifted + if (lane_id < num_topk) { + // Read + auto idx_value = ld_nc_global(static_cast(shifted) + lane_id); + shifted = static_cast(shifted) + num_topk; + auto weight_value = ld_nc_global(static_cast(shifted) + lane_id); + + // Transform and write + idx_value = (idx_value >= dst_rank_expert_begin and idx_value < dst_rank_expert_end) ? idx_value - dst_rank_expert_begin : -1; + st_na_global(nvl_channel_topk_idx.buffer() + dst_slot_idx * num_topk + lane_id, idx_value); + weight_value = idx_value >= 0 ? weight_value : 0.0f; + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, weight_value); + } + + // In case of insufficient NVL buffers, early stopping + if ((++ num_tokens_sent) == num_max_nvl_chunked_send_tokens) + src_rdma_tail = i + 1; + } + + // Sync head index + if (lane_id == src_rdma_rank) + forward_channel_head[dst_nvl_rank][src_rdma_rank] = (cached_rdma_channel_head = src_rdma_tail); + + // Move tail index + __syncwarp(); + if (lane_id == 0) + st_release_sys_global(nvl_channel_tail.buffer(), cached_nvl_channel_tail); + } + + // Retired + __syncwarp(); + if (lane_id == 0) + forward_channel_retired[dst_nvl_rank] = true; + } else if (warp_role == WarpRole::kForwarderCoordinator) { + // Extra warps for forwarder coordinator should exit directly + if (target_rank > 0) + return; + + // Forward warp coordinator + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Clean shared memory + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + #pragma unroll + for (int i = lane_id; i < kNumRDMARanks * NUM_MAX_NVL_PEERS; i += 32) + forward_channel_head[i % NUM_MAX_NVL_PEERS][i / NUM_MAX_NVL_PEERS] = 0; + if (lane_id < NUM_MAX_NVL_PEERS) + forward_channel_retired[lane_id] = false; + sync_forwarder_smem(); + + int last_head = 0, target_rdma = lane_id < kNumRDMARanks ? lane_id : 0; + while (true) { + // Find minimum head + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) if (not forward_channel_retired[i]) + min_head = min(min_head, forward_channel_head[i][target_rdma]); + if (__all_sync(0xffffffff, min_head == std::numeric_limits::max())) + break; + + // Update remote head + if (min_head != std::numeric_limits::max() and min_head >= last_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_head, + translate_dst_rdma_rank(lane_id, nvl_rank), channel_id + num_channels, lane_id == rdma_rank); + last_head = min_head; + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } else { + // NVL consumers + // Retrieve rank offset from barrier results (each lane's register stores an RDMA rank) + int src_nvl_rank = target_rank, total_offset = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + if (lane_id < kNumRDMARanks and lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank > 0) + total_offset = recv_gbl_rank_prefix_sum[lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank - 1]; + + // Receive channel offsets + int start_offset = 0, end_offset = 0, num_tokens_to_recv; + auto start_time = clock64(); + while (lane_id < kNumRDMARanks) { + start_offset = ld_volatile_global(nvl_channel_prefix_start.buffer() + lane_id); + end_offset = ld_volatile_global(nvl_channel_prefix_end.buffer() + lane_id); + if (start_offset < 0 and end_offset < 0) { + start_offset = -start_offset - 1, end_offset = -end_offset - 1; + total_offset += start_offset; + break; + } + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, src nvl: %d, start: %d, end: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, src_nvl_rank, start_offset, end_offset); + trap(); + } + } + num_tokens_to_recv = warp_reduce_sum(end_offset - start_offset); + + // Save for combine usage + if (lane_id < kNumRDMARanks and not kCachedMode) + recv_gbl_channel_prefix_matrix[(lane_id * NUM_MAX_NVL_PEERS + src_nvl_rank) * num_channels + channel_id] = total_offset; + __syncwarp(); + + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + while (num_tokens_to_recv > 0) { + // Check channel status by lane 0 + start_time = clock64(); + while (lane_id == 0) { + // Ready to copy + if (cached_channel_head_idx != cached_channel_tail_idx) + break; + cached_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer()); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP dispatch NVL receiver timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, head: %d, tail: %d\n", + channel_id, rdma_rank, nvl_rank, src_nvl_rank, cached_channel_head_idx, cached_channel_tail_idx); + trap(); + } + } + + // Sync queue tail + cached_channel_tail_idx = __shfl_sync(0xffffffff, cached_channel_tail_idx, 0); + + // Copy data + int num_recv_tokens = cached_channel_tail_idx - cached_channel_head_idx; + for (int chunk_idx = 0; chunk_idx < num_recv_tokens; ++ chunk_idx, -- num_tokens_to_recv) { + int token_idx_in_buffer = (cached_channel_head_idx ++) % num_max_nvl_chunked_recv_tokens; + auto meta = ld_nc_global(nvl_channel_src_meta.buffer() + token_idx_in_buffer); + int64_t recv_token_idx = __shfl_sync(0xffffffff, total_offset, meta.src_rdma_rank); + (lane_id == meta.src_rdma_rank) ? (total_offset += 1) : 0; + + // Copy data + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, + recv_x + recv_token_idx * hidden_int4, + nvl_channel_x.buffer() + token_idx_in_buffer * hidden_int4, + ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0 and not kCachedMode) + st_na_global(recv_src_meta + recv_token_idx, meta); + + // Copy scales + UNROLLED_WARP_COPY(1, lane_id, num_scales, + recv_x_scales + recv_token_idx * num_scales, + nvl_channel_x_scales.buffer() + token_idx_in_buffer * num_scales, + ld_nc_global, st_na_global); + + // Copy `topk_idx` and `topk_weights` + if (lane_id < num_topk) { + auto recv_idx = recv_token_idx * num_topk + lane_id; + auto buffer_idx = token_idx_in_buffer * num_topk + lane_id; + st_na_global(recv_topk_idx + recv_idx, static_cast(ld_nc_global(nvl_channel_topk_idx.buffer() + buffer_idx))); + st_na_global(recv_topk_weights + recv_idx, ld_nc_global(nvl_channel_topk_weights.buffer() + buffer_idx)); + } + } + + // Move queue + __syncwarp(); + if (lane_id == 0) + st_relaxed_sys_global(nvl_channel_head.buffer(), cached_channel_head_idx); + } + } +} + +void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, + const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, + int* send_rdma_head, int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, + const bool* is_token_in_rank, + int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, + int scale_token_stride, int scale_hidden_stride, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, bool is_cached_dispatch, + cudaStream_t stream, int num_channels, bool low_latency_mode) { + constexpr int kNumDispatchRDMASenderWarps = 7; + + // Make sure never OOB + EP_HOST_ASSERT(static_cast(num_scales) * scale_hidden_stride < std::numeric_limits::max()); + +#define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ + auto dispatch_func = low_latency_mode ? \ + (is_cached_dispatch ? dispatch : dispatch) : \ + (is_cached_dispatch ? dispatch : dispatch); \ + LAUNCH_KERNEL(&cfg, dispatch_func, \ + reinterpret_cast(recv_x), recv_x_scales, recv_topk_idx, recv_topk_weights, reinterpret_cast(recv_src_meta), \ + reinterpret_cast(x), x_scales, topk_idx, topk_weights, \ + send_rdma_head, send_nvl_head, \ + recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, \ + rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ + gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ + is_token_in_rank, \ + num_tokens, hidden_int4, num_scales, num_topk, num_experts, \ + scale_token_stride, scale_hidden_stride, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ + rank, num_ranks); } break + + EP_HOST_ASSERT((topk_idx == nullptr) == (topk_weights == nullptr)); + EP_HOST_ASSERT((recv_topk_idx == nullptr) == (recv_topk_weights == nullptr)); + + SETUP_LAUNCH_CONFIG(num_channels * 2, (kNumDispatchRDMASenderWarps + 1 + NUM_MAX_NVL_PEERS) * 32, stream); + SWITCH_RDMA_RANKS(DISPATCH_LAUNCH_CASE); +#undef DISPATCH_LAUNCH_CASE +} + +template +__global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_int_clean, + const int nvl_clean_offset, const int nvl_num_int_clean, + int* combined_rdma_head, int num_combined_tokens, int num_channels, + const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, + void* rdma_buffer_ptr, + void** buffer_ptrs, int** barrier_signal_ptrs, int rank, int num_ranks, + bool is_cached_dispatch, const nvshmem_team_t rdma_team) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + auto num_threads = static_cast(blockDim.x); + auto num_warps = num_threads / 32; + auto warp_id = thread_id / 32; + auto lane_id = get_lane_id(); + + auto nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Using two SMs, which clean the RDMA/NVL buffer respectively + if (sm_id == 0) { + // Barrier for RDMA + if (thread_id == 0) + nvshmem_sync_with_same_gpu_idx(rdma_team); + __syncthreads(); + + // Clean + auto rdma_buffer_ptr_int = static_cast(rdma_buffer_ptr); + #pragma unroll + for (int i = thread_id; i < rdma_num_int_clean; i += num_threads) + rdma_buffer_ptr_int[rdma_clean_offset + i] = 0; + __syncthreads(); + + // Barrier again + if (thread_id == 0) + nvshmem_sync_with_same_gpu_idx(rdma_team); + } else if (sm_id == 1) { + // Barrier for NVL + barrier_block(barrier_signal_ptrs, nvl_rank); + + // Clean + auto nvl_buffer_ptr_int = static_cast(buffer_ptrs[nvl_rank]); + #pragma unroll + for (int i = thread_id; i < nvl_num_int_clean; i += num_threads) + nvl_buffer_ptr_int[nvl_clean_offset + i] = 0; + + // Barrier again + barrier_block(barrier_signal_ptrs, nvl_rank); + } else if (sm_id == 2) { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(num_rdma_ranks <= 32); + + // Iterate in reverse order + if (lane_id < num_rdma_ranks and warp_id < num_channels) { + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, warp_id, token_start_idx, token_end_idx); + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { + auto current_head = __ldg(combined_rdma_head + token_idx * num_rdma_ranks + lane_id); + if (current_head < 0) { + combined_rdma_head[token_idx * num_rdma_ranks + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } else { + if (is_cached_dispatch) + return; + + EP_DEVICE_ASSERT(num_warps >= num_channels); + EP_DEVICE_ASSERT(rdma_channel_prefix_matrix != nullptr and rdma_rank_prefix_sum != nullptr); + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Too many NVL peers"); + + if (lane_id < NUM_MAX_NVL_PEERS and warp_id < num_channels) { + for (int dst_rdma_rank = sm_id - 3; dst_rdma_rank < num_rdma_ranks; dst_rdma_rank += num_channels * 2 - 3) { + // Iterate in reverse order + int token_start_idx = warp_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id - 1]; + int token_end_idx = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + warp_id]; + int shift = dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + token_start_idx += shift, token_end_idx += shift; + + // NOTES: `1 << 25` is a heuristic large number + int last_head = 1 << 25; + #pragma unroll + for (int token_idx = token_end_idx - 1; token_idx >= token_start_idx; -- token_idx) { + auto current_head = __ldg(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + if (current_head < 0) { + combined_nvl_head[token_idx * NUM_MAX_NVL_PEERS + lane_id] = -last_head - 1; + } else { + last_head = current_head; + } + } + } + } + } +} + +void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, + int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, + void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, int rank, cudaStream_t stream, + int64_t num_rdma_bytes, int64_t num_nvl_bytes, + bool is_cached_dispatch, bool low_latency_mode) { + const int num_threads = std::max(128, 32 * num_channels); + const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + + // Get clean meta + auto rdma_clean_meta = get_rdma_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, num_max_rdma_chunked_recv_tokens, num_channels); + auto nvl_clean_meta = get_nvl_clean_meta(hidden_int4, num_scales, num_topk_idx, num_topk_weights, num_rdma_ranks, NUM_MAX_NVL_PEERS, num_max_nvl_chunked_recv_tokens, num_channels); + EP_HOST_ASSERT((rdma_clean_meta.first + rdma_clean_meta.second) * sizeof(int) <= num_rdma_bytes); + EP_HOST_ASSERT((nvl_clean_meta.first + nvl_clean_meta.second) * sizeof(int) <= num_nvl_bytes); + EP_HOST_ASSERT(num_rdma_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_nvl_bytes < std::numeric_limits::max()); + EP_HOST_ASSERT(num_channels * 2 > 3); + + // Launch kernel + auto cached_notify_func = low_latency_mode ? cached_notify : cached_notify; + SETUP_LAUNCH_CONFIG(num_channels * 2, num_threads, stream); + LAUNCH_KERNEL(&cfg, cached_notify_func, + rdma_clean_meta.first, rdma_clean_meta.second, + nvl_clean_meta.first, nvl_clean_meta.second, + combined_rdma_head, num_combined_tokens, num_channels, + rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, + rdma_buffer_ptr, + buffer_ptrs, barrier_signal_ptrs, rank, num_ranks, + is_cached_dispatch, cpu_rdma_team); +} + +template +__device__ int combine_token(bool is_token_in_rank, int head_idx, + int lane_id, int hidden_int4, int num_topk, + int4* combined_row, float* combined_topk_weights, + const int4* bias_0_int4, const int4* bias_1_int4, + int num_max_recv_tokens, const ReceiveFn& recv_fn, const ReceiveTWFn& recv_tw_fn) { + constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); + + // Broadcast current heads + // Lane `i` holds the head of rank `i` and `is_token_in_rank` + EP_STATIC_ASSERT(kMaxNumRanks <= 32, "Too many ranks"); + int num_topk_ranks = 0, topk_ranks[kMaxNumRanks], slot_indices[kMaxNumRanks]; + #pragma unroll + for (int i = 0; i < kNumRanks; ++ i) if (__shfl_sync(0xffffffff, is_token_in_rank, i)) { + slot_indices[num_topk_ranks] = __shfl_sync(0xffffffff, head_idx, i) % num_max_recv_tokens; + topk_ranks[num_topk_ranks ++] = i; + } + EP_DEVICE_ASSERT(num_topk_ranks <= kMaxNumRanks); + + // Reduce data + #pragma unroll + for (int i = lane_id; i < hidden_int4; i += 32) { + // Read bias + // TODO: make it as a finer-grained template + int4 bias_0_value_int4, bias_1_value_int4; + if (kMaybeWithBias) { + bias_0_value_int4 = bias_0_int4 != nullptr ? ld_nc_global(bias_0_int4 + i) : make_int4(0, 0, 0, 0); + bias_1_value_int4 = bias_1_int4 != nullptr ? ld_nc_global(bias_1_int4 + i) : make_int4(0, 0, 0, 0); + } + + // Read buffers + // TODO: maybe too many registers here + int4 recv_value_int4[kMaxNumRanks]; + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) + recv_value_int4[j] = recv_fn(topk_ranks[j], slot_indices[j], i); + + // Clean + // Reduce bias + float values[kDtypePerInt4] = {0}; + if (kMaybeWithBias) { + auto bias_0_values = reinterpret_cast(&bias_0_value_int4); + auto bias_1_values = reinterpret_cast(&bias_1_value_int4); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++ j) + values[j] = static_cast(bias_0_values[j]) + static_cast(bias_1_values[j]); + } + + // Reduce all-to-all results + #pragma unroll + for (int j = 0; j < num_topk_ranks; ++ j) { + auto recv_value_dtypes = reinterpret_cast(&recv_value_int4[j]); + #pragma unroll + for (int k = 0; k < kDtypePerInt4; ++ k) + values[k] += static_cast(recv_value_dtypes[k]); + } + + // Cast back to `dtype_t` and write + int4 out_int4; + auto out_dtypes = reinterpret_cast(&out_int4); + #pragma unroll + for (int j = 0; j < kDtypePerInt4; ++ j) + out_dtypes[j] = static_cast(values[j]); + st_na_global(combined_row + i, out_int4); + } + + // Reduce `topk_weights` + if (lane_id < num_topk) { + float value = 0; + #pragma unroll + for (int i = 0; i < num_topk_ranks; ++ i) + value += recv_tw_fn(topk_ranks[i], slot_indices[i], lane_id); + st_na_global(combined_topk_weights + lane_id, value); + } + + // Return the minimum top-k rank + return topk_ranks[0]; +} + +template 0) ? kNumCombineForwarderWarps / kNumRDMARanks : 1, + int kNumForwarders = kNumRDMARanks * kNumWarpsPerForwarder, + int kNumRDMAReceivers = kNumForwarders + NUM_MAX_NVL_PEERS> +__global__ void __launch_bounds__((NUM_MAX_NVL_PEERS + 1 + kNumForwarders) * 32, 1) +combine(int4* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const int4* x, const float* topk_weights, + const int4* bias_0, const int4* bias_1, + const int* combined_rdma_head, const int* combined_nvl_head, + const SourceMeta* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks) { + enum class WarpRole { + kNVLSender, + kNVLAndRDMAForwarder, + kRDMAReceiver, + kCoordinator + }; + + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = static_cast(blockDim.x), num_warps = num_threads / 32; + const auto thread_id = static_cast(threadIdx.x), lane_id = get_lane_id(); + const auto num_channels = static_cast(gridDim.x) / 2, channel_id = sm_id / 2; + const bool is_rdma_receiver_sm = sm_id % 2 == 1; + + EP_DEVICE_ASSERT(num_topk <= 32); + EP_DEVICE_ASSERT(hidden % (sizeof(int4) / sizeof(dtype_t)) == 0); + const auto hidden_int4 = hidden / (sizeof(int4) / sizeof(dtype_t)); + + // NOTES: we decouple a channel into 2 SMs + const auto rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; + auto role_meta = [=]() -> std::pair { + auto warp_id = thread_id / 32; + if (not is_rdma_receiver_sm) { + if (warp_id < NUM_MAX_NVL_PEERS) { + auto shuffled_warp_id = warp_id; + shuffled_warp_id = (shuffled_warp_id + channel_id) % NUM_MAX_NVL_PEERS; + return {WarpRole::kNVLSender, shuffled_warp_id}; + } else if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + auto shuffled_warp_id = warp_id - NUM_MAX_NVL_PEERS; + shuffled_warp_id = (shuffled_warp_id + channel_id) % kNumForwarders; + return {WarpRole::kNVLAndRDMAForwarder, shuffled_warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } else { + if (warp_id < NUM_MAX_NVL_PEERS + kNumForwarders) { + return {WarpRole::kRDMAReceiver, warp_id}; + } else { + return {WarpRole::kCoordinator, 0}; + } + } + }(); + auto warp_role = role_meta.first; + auto warp_id = role_meta.second; + + EP_DEVICE_ASSERT(num_warps == NUM_MAX_NVL_PEERS + kNumForwarders + 1); + auto num_max_nvl_chunked_recv_tokens_per_rdma = num_max_nvl_chunked_recv_tokens / kNumRDMARanks; + + if (warp_role == WarpRole::kNVLSender) { + // NVL producers + const auto dst_nvl_rank = warp_id; + + // NVL layouts + // NOTES: to avoid deadlocks, we use separate NVL buffers for different RDMA sources + auto dst_buffer_ptr = buffer_ptrs[dst_nvl_rank], local_buffer_ptr = buffer_ptrs[nvl_rank]; + auto nvl_channel_x = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_src_meta = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_topk_weights = AsymBuffer(dst_buffer_ptr, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + auto nvl_channel_head = AsymBuffer(local_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, dst_nvl_rank).advance_also(dst_buffer_ptr); + auto nvl_channel_tail = AsymBuffer(dst_buffer_ptr, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_buffer_ptr); + + // Get tasks for each RDMA lane + int token_start_idx = 0, token_end_idx = 0; + if (lane_id < kNumRDMARanks) { + int prefix_idx = (lane_id * NUM_MAX_NVL_PEERS + dst_nvl_rank) * num_channels + channel_id; + token_start_idx = gbl_channel_prefix_matrix[prefix_idx]; + token_end_idx = (prefix_idx == num_channels * num_ranks - 1) ? num_tokens : gbl_channel_prefix_matrix[prefix_idx + 1]; + } + __syncwarp(); + + // NOTES: here the cached value of each lane is only responsible for a single RDMA buffer + int cached_channel_head_idx = 0, cached_channel_tail_idx = 0; + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + + // Iterate over all tokens and send by chunks + while (true) { + // Exit if possible + if (__all_sync(0xffffffff, token_start_idx >= token_end_idx)) + break; + + // Decide the next RDMA buffer to send + bool is_lane_ready = false; + auto start_time = clock64(); + while (true) { + int num_used_slots = cached_channel_tail_idx - cached_channel_head_idx; + is_lane_ready = lane_id < kNumRDMARanks and token_start_idx < token_end_idx and num_max_nvl_chunked_recv_tokens_per_rdma - num_used_slots >= num_max_nvl_chunked_send_tokens; + if (__any_sync(0xffffffff, is_lane_ready)) + break; + + // Retry + if (lane_id < kNumRDMARanks and token_start_idx < token_end_idx) + cached_channel_head_idx = ld_volatile_global(nvl_channel_head.buffer() + lane_id); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < kNumRDMARanks) { + printf("DeepEP combine NVL sender timeout, channel: %d, RDMA: %d, nvl: %d, dst NVL: %d, RDMA lane: %d, head: %d, tail: %d, start: %d, end: %d\n", + channel_id, rdma_rank, nvl_rank, dst_nvl_rank, lane_id, ld_volatile_global(nvl_channel_head.buffer() + lane_id), cached_channel_tail_idx, + token_start_idx, token_end_idx); + trap(); + } + } + + // Sync token start index and count + for (int current_rdma_idx = 0; current_rdma_idx < kNumRDMARanks; ++ current_rdma_idx) { + if (__shfl_sync(0xffffffff, (token_start_idx >= token_end_idx) or (not is_lane_ready), current_rdma_idx)) + continue; + + // Sync token start index + auto token_idx = static_cast(__shfl_sync(0xffffffff, token_start_idx, current_rdma_idx)); + int num_tokens_in_chunk = __shfl_sync(0xffffffff, min(num_max_nvl_chunked_send_tokens, token_end_idx - token_start_idx), current_rdma_idx); + + // Send by chunk + for (int chunk_idx = 0; chunk_idx < num_tokens_in_chunk; ++ chunk_idx, ++ token_idx) { + // Get an empty slot + int dst_slot_idx = 0; + if (lane_id == current_rdma_idx) { + dst_slot_idx = (cached_channel_tail_idx ++) % num_max_nvl_chunked_recv_tokens_per_rdma; + dst_slot_idx = current_rdma_idx * num_max_nvl_chunked_recv_tokens_per_rdma + dst_slot_idx; + } + dst_slot_idx = __shfl_sync(0xffffffff, dst_slot_idx, current_rdma_idx); + + // Copy data + auto shifted_x_buffers = nvl_channel_x.buffer() + dst_slot_idx * hidden_int4; + auto shifted_x = x + token_idx * hidden_int4; + UNROLLED_WARP_COPY(5, lane_id, hidden_int4, shifted_x_buffers, shifted_x, ld_nc_global, st_na_global); + + // Copy source meta + if (lane_id == 0) + st_na_global(nvl_channel_src_meta.buffer() + dst_slot_idx, ld_nc_global(src_meta + token_idx)); + + // Copy `topk_weights` + if (lane_id < num_topk) + st_na_global(nvl_channel_topk_weights.buffer() + dst_slot_idx * num_topk + lane_id, ld_nc_global(topk_weights + token_idx * num_topk + lane_id)); + } + lane_id == current_rdma_idx ? (token_start_idx = static_cast(token_idx)) : 0; + } + + // Move queue tail + __syncwarp(); + if (lane_id < kNumRDMARanks and is_lane_ready) + st_release_sys_global(nvl_channel_tail.buffer() + lane_id, cached_channel_tail_idx); + } + } else { + // Combiners and coordinators + // RDMA symmetric layout + auto hidden_bytes = hidden_int4 * sizeof(int4); + auto num_bytes_per_rdma_token = get_num_bytes_per_rdma_token(hidden_int4, 0, 0, num_topk); + auto rdma_channel_data = SymBuffer(rdma_buffer_ptr, num_max_rdma_chunked_recv_tokens * num_bytes_per_rdma_token, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_head = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + auto rdma_channel_tail = SymBuffer(rdma_buffer_ptr, 1, kNumRDMARanks, channel_id, num_channels); + + // NVL layouts + void* local_nvl_buffer = buffer_ptrs[nvl_rank]; + void* nvl_buffers[NUM_MAX_NVL_PEERS]; + #pragma unroll + for (int i = 0; i < NUM_MAX_NVL_PEERS; ++ i) + nvl_buffers[i] = buffer_ptrs[i]; + auto nvl_channel_x = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * hidden_int4, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_src_meta = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_topk_weights = AsymBuffer(local_nvl_buffer, num_max_nvl_chunked_recv_tokens * num_topk, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + auto nvl_channel_head = AsymBuffer(nvl_buffers, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels, nvl_rank).advance_also(local_nvl_buffer); + auto nvl_channel_tail = AsymBuffer(local_nvl_buffer, kNumRDMARanks, NUM_MAX_NVL_PEERS, channel_id, num_channels).advance_also(nvl_buffers); + + // Combiner warp synchronization + __shared__ volatile int forwarder_nvl_head[kNumForwarders][NUM_MAX_NVL_PEERS]; + __shared__ volatile bool forwarder_retired[kNumForwarders]; + __shared__ volatile int rdma_receiver_rdma_head[kNumRDMAReceivers][kNumRDMARanks]; + __shared__ volatile bool rdma_receiver_retired[kNumRDMAReceivers]; + auto sync_forwarder_smem = [=]() { asm volatile("bar.sync 0, %0;" :: "r"((kNumForwarders + 1) * 32)); }; + auto sync_rdma_receiver_smem = [=]() { asm volatile("bar.sync 1, %0;" :: "r"((kNumRDMAReceivers + 1) * 32)); }; + + if (warp_role == WarpRole::kNVLAndRDMAForwarder) { + // Receive from NVL ranks and forward to RDMA ranks + // NOTES: this part is using "large warps" for each RDMA ranks + const auto dst_rdma_rank = warp_id / kNumWarpsPerForwarder; + const auto sub_warp_id = warp_id % kNumWarpsPerForwarder; + auto send_buffer = dst_rdma_rank == rdma_rank ? rdma_channel_data.recv_buffer(dst_rdma_rank) : rdma_channel_data.send_buffer(dst_rdma_rank); + auto sync_large_warp = [=]() { + if (kNumWarpsPerForwarder == 1) { + __syncwarp(); + } else { + asm volatile("bar.sync %0, %1;" :: "r"(dst_rdma_rank + 2), "r"(kNumWarpsPerForwarder * 32)); + } + }; + EP_STATIC_ASSERT(kNumWarpsPerForwarder == 1 or kNumRDMARanks + 2 <= 16, "Barriers are not enough"); + + // Advance to the corresponding NVL buffer + nvl_channel_x.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * hidden_int4); + nvl_channel_src_meta.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma); + nvl_channel_topk_weights.advance(dst_rdma_rank * num_max_nvl_chunked_recv_tokens_per_rdma * num_topk); + nvl_channel_head.advance(dst_rdma_rank); + nvl_channel_tail.advance(dst_rdma_rank); + + // Clean shared memory and sync + EP_STATIC_ASSERT(NUM_MAX_NVL_PEERS <= 32, "Invalid number of NVL peers"); + lane_id < NUM_MAX_NVL_PEERS ? (forwarder_nvl_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (forwarder_retired[warp_id] = false) : false; + sync_forwarder_smem(); + + // Get count and cached head + int cached_nvl_channel_tail_idx = 0; + int num_tokens_to_combine = rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id]; + int num_tokens_prefix = channel_id == 0 ? 0 : rdma_channel_prefix_matrix[dst_rdma_rank * num_channels + channel_id - 1]; + num_tokens_to_combine -= num_tokens_prefix; + num_tokens_prefix += dst_rdma_rank == 0 ? 0 : rdma_rank_prefix_sum[dst_rdma_rank - 1]; + combined_nvl_head += num_tokens_prefix * NUM_MAX_NVL_PEERS; + + // Iterate over all tokens and combine by chunks + for (int token_start_idx = 0; token_start_idx < num_tokens_to_combine; token_start_idx += num_max_rdma_chunked_send_tokens) { + // Check destination queue emptiness, or wait a buffer to be released + auto token_end_idx = min(token_start_idx + num_max_rdma_chunked_send_tokens, num_tokens_to_combine); + auto num_chunked_tokens = token_end_idx - token_start_idx; + auto start_time = clock64(); + while (sub_warp_id == 0 and lane_id == 0) { + // Inequality: `num_max_rdma_chunked_recv_tokens - (tail - head) >= num_chunked_tokens` + // Here, `token_start_idx` is the actual tail + int num_used_slots = token_start_idx - ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)); + if (num_max_rdma_chunked_recv_tokens - num_used_slots >= num_chunked_tokens) + break; + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP combine forwarder (RDMA check) timeout, channel: %d, RDMA: %d, nvl: %d, dst RDMA: %d, head: %ld, tail: %d, chunked: %d\n", + channel_id, rdma_rank, nvl_rank, dst_rdma_rank, ld_volatile_global(rdma_channel_head.buffer(dst_rdma_rank)), token_start_idx, num_chunked_tokens); + trap(); + } + } + sync_large_warp(); + + // Combine and write to the RDMA buffer + for (int token_idx = token_start_idx + sub_warp_id; token_idx < token_end_idx; token_idx += kNumWarpsPerForwarder) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head = ld_nc_global(combined_nvl_head + token_idx * NUM_MAX_NVL_PEERS + lane_id); + + // Wait lanes to be ready + start_time = clock64(); + while (cached_nvl_channel_tail_idx <= expected_head) { + cached_nvl_channel_tail_idx = ld_acquire_sys_global(nvl_channel_tail.buffer(lane_id)); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and lane_id < NUM_MAX_NVL_PEERS) { + printf("DeepEP combine forwarder (NVL check) timeout, channel: %d, RDMA: %d, nvl: %d, src NVL: %d, dst RDMA: %d, tail: %d, waiting: %d, total: %d, sub: %d, large: %d, expected: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, dst_rdma_rank, cached_nvl_channel_tail_idx, token_idx, num_tokens_to_combine, sub_warp_id, kNumWarpsPerForwarder, expected_head); + trap(); + } + } + + // Combine current token + auto rdma_slot_idx = token_idx % num_max_rdma_chunked_recv_tokens; + void* shifted = send_buffer + rdma_slot_idx * num_bytes_per_rdma_token; + auto recv_fn = [&](int src_nvl_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(nvl_channel_x.buffer(src_nvl_rank) + slot_idx * hidden_int4 + hidden_int4_idx); }; + auto recv_tw_fn = [&](int src_nvl_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(nvl_channel_topk_weights.buffer(src_nvl_rank) + slot_idx * num_topk + topk_idx); }; + combine_token(expected_head >= 0, + expected_head, lane_id, + hidden_int4, num_topk, + static_cast(shifted), + reinterpret_cast(static_cast(shifted) + hidden_bytes + sizeof(SourceMeta)), + nullptr, nullptr, num_max_nvl_chunked_recv_tokens_per_rdma, recv_fn, recv_tw_fn); + + // Update head + if (lane_id < NUM_MAX_NVL_PEERS) + expected_head < 0 ? (forwarder_nvl_head[warp_id][lane_id] = -expected_head - 1) : (forwarder_nvl_head[warp_id][lane_id] = expected_head + 1); + } + sync_large_warp(); + + // Issue RDMA send + if (sub_warp_id == kNumWarpsPerForwarder - 1) { + if (dst_rdma_rank != rdma_rank) { + auto rdma_slot_idx = token_start_idx % num_max_rdma_chunked_recv_tokens; + const size_t num_bytes_per_msg = num_chunked_tokens * num_bytes_per_rdma_token; + const auto dst_ptr = reinterpret_cast(rdma_channel_data.recv_buffer(rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); + const auto src_ptr = reinterpret_cast(rdma_channel_data.send_buffer(dst_rdma_rank) + rdma_slot_idx * num_bytes_per_rdma_token); + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, lane_id, 0); + } else { + memory_fence(); + } + + // Write new RDMA tail + __syncwarp(); + if (lane_id == 0) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_tail.buffer(rdma_rank), num_chunked_tokens, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id, dst_rdma_rank == rdma_rank); + } + } + } + + // Retired + __syncwarp(); + if (lane_id == 0) + forwarder_retired[warp_id] = true; + } else if (warp_role == WarpRole::kRDMAReceiver) { + // Receive from RDMA ranks and write to the output tensor + // Clean shared memory and sync + EP_DEVICE_ASSERT(kNumRDMARanks <= 32); + lane_id < kNumRDMARanks ? (rdma_receiver_rdma_head[warp_id][lane_id] = 0) : 0; + lane_id == 0 ? (rdma_receiver_retired[warp_id] = false) : 0; + sync_rdma_receiver_smem(); + + // The same tokens as the dispatch process + int token_start_idx, token_end_idx; + get_channel_task_range(num_combined_tokens, num_channels, channel_id, token_start_idx, token_end_idx); + + // Iterate over all tokens and combine + int cached_channel_tail_idx = 0; + for (int64_t token_idx = token_start_idx + warp_id; token_idx < token_end_idx; token_idx += kNumRDMAReceivers) { + // Read expected head + EP_STATIC_ASSERT(kNumRDMARanks <= 32, "Invalid number of RDMA peers"); + int expected_head = -1; + if (lane_id < kNumRDMARanks) { + expected_head = ld_nc_global(combined_rdma_head + token_idx * kNumRDMARanks + lane_id); + (expected_head < 0) ? (rdma_receiver_rdma_head[warp_id][lane_id] = -expected_head - 1) : (rdma_receiver_rdma_head[warp_id][lane_id] = expected_head); + } + + // Wait lanes to be ready + auto start_time = clock64(); + while (cached_channel_tail_idx <= expected_head) { + cached_channel_tail_idx = static_cast(ld_acquire_sys_global(rdma_channel_tail.buffer(lane_id))); + + // Timeout check + if (clock64() - start_time > NUM_TIMEOUT_CYCLES) { + printf("DeepEP combine RDMA receiver timeout, channel: %d, RDMA: %d, nvl: %d, src RDMA: %d, tail: %d, waiting: %ld, expect: %d\n", + channel_id, rdma_rank, nvl_rank, lane_id, cached_channel_tail_idx, token_idx, expected_head); + trap(); + } + } + __syncwarp(); + + // Combine current token + auto recv_fn = [&](int src_rdma_rank, int slot_idx, int hidden_int4_idx) -> int4 { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token) + hidden_int4_idx);}; + auto recv_tw_fn = [&](int src_rdma_rank, int slot_idx, int topk_idx) -> float { return ld_nc_global(reinterpret_cast(rdma_channel_data.recv_buffer(src_rdma_rank) + slot_idx * num_bytes_per_rdma_token + hidden_bytes + sizeof(SourceMeta)) + topk_idx);}; + combine_token(expected_head >= 0, + expected_head, lane_id, + hidden_int4, num_topk, + combined_x + token_idx * hidden_int4, + combined_topk_weights + token_idx * num_topk, + bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4, + bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4, + num_max_rdma_chunked_recv_tokens, recv_fn, recv_tw_fn); + } + + // Retired + __syncwarp(); + if (lane_id == 0) + rdma_receiver_retired[warp_id] = true; + } else { + // Coordinator + // Sync shared memory status + is_rdma_receiver_sm ? sync_rdma_receiver_smem() : sync_forwarder_smem(); + const auto num_warps_per_rdma_rank = kNumForwarders / kNumRDMARanks; + + int last_rdma_head = 0; + int last_nvl_head[kNumRDMARanks] = {0}; + int dst_rdma_rank = lane_id < kNumRDMARanks ? lane_id : 0; + int dst_nvl_rank = lane_id < NUM_MAX_NVL_PEERS ? lane_id : 0; + EP_STATIC_ASSERT(kNumCombineForwarderWarps <= 32, "Invalid number of forwarder warps"); + while (true) { + // Retired + if (is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumRDMAReceivers or rdma_receiver_retired[lane_id])) + break; + if (not is_rdma_receiver_sm and __all_sync(0xffffffff, lane_id >= kNumForwarders or forwarder_retired[lane_id])) + break; + + // Find minimum head for RDMA ranks + if (is_rdma_receiver_sm) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int i = 0; i < kNumRDMAReceivers; ++ i) if (not rdma_receiver_retired[i]) + min_head = min(min_head, rdma_receiver_rdma_head[i][dst_rdma_rank]); + if (min_head != std::numeric_limits::max() and min_head >= last_rdma_head + num_max_rdma_chunked_send_tokens and lane_id < kNumRDMARanks) { + nvshmemi_ibgda_amo_nonfetch_add(rdma_channel_head.buffer(rdma_rank), min_head - last_rdma_head, + translate_dst_rdma_rank(dst_rdma_rank, nvl_rank), channel_id + num_channels, dst_rdma_rank == rdma_rank); + last_rdma_head = min_head; + } + } else { + // Find minimum head for NVL ranks + #pragma unroll + for (int i = 0; i < kNumRDMARanks; ++ i) { + int min_head = std::numeric_limits::max(); + #pragma unroll + for (int j = 0; j < num_warps_per_rdma_rank; ++ j) if (not forwarder_retired[i * num_warps_per_rdma_rank + j]) + min_head = min(min_head, forwarder_nvl_head[i * num_warps_per_rdma_rank + j][dst_nvl_rank]); + if (min_head != std::numeric_limits::max() and min_head > last_nvl_head[i] and lane_id < NUM_MAX_NVL_PEERS) + st_relaxed_sys_global(nvl_channel_head.buffer_by(dst_nvl_rank) + i, last_nvl_head[i] = min_head); + } + } + + // Nanosleep and let other warps work + __nanosleep(NUM_WAIT_NANOSECONDS); + } + } + } +} + +void combine(cudaDataType_t type, + void* combined_x, float* combined_topk_weights, + const bool* is_combined_token_in_rank, + const void* x, const float* topk_weights, + const void* bias_0, const void* bias_1, + const int* combined_rdma_head, const int* combined_nvl_head, + const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, + int num_tokens, int num_combined_tokens, int hidden, int num_topk, + void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, + int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) { + constexpr int kNumCombineForwarderWarps = 16; + +#define COMBINE_LAUNCH_CASE(num_rdma_ranks) { \ + auto combine_func = low_latency_mode ? \ + combine : combine; \ + LAUNCH_KERNEL(&cfg, combine_func, \ + reinterpret_cast(combined_x), combined_topk_weights, is_combined_token_in_rank, \ + reinterpret_cast(x), topk_weights, \ + reinterpret_cast(bias_0), reinterpret_cast(bias_1), \ + combined_rdma_head, combined_nvl_head, \ + reinterpret_cast(src_meta), rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, \ + num_tokens, num_combined_tokens, hidden, num_topk, \ + rdma_buffer_ptr, num_max_rdma_chunked_send_tokens, num_max_rdma_chunked_recv_tokens, \ + buffer_ptrs, num_max_nvl_chunked_send_tokens, num_max_nvl_chunked_recv_tokens, \ + rank, num_ranks); } break + + int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; + auto num_warps_per_forwarder = std::max(kNumCombineForwarderWarps / num_rdma_ranks, 1); + int num_forwarder_warps = num_rdma_ranks * num_warps_per_forwarder; + EP_HOST_ASSERT(num_forwarder_warps > 0 and num_forwarder_warps % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens % num_rdma_ranks == 0); + EP_HOST_ASSERT(num_max_nvl_chunked_recv_tokens / num_rdma_ranks > std::max(num_max_rdma_chunked_send_tokens, num_max_nvl_chunked_send_tokens)); + EP_HOST_ASSERT(type == CUDA_R_16BF); + + SETUP_LAUNCH_CONFIG(num_channels * 2, (NUM_MAX_NVL_PEERS + num_forwarder_warps + 1) * 32, stream); + SWITCH_RDMA_RANKS(COMBINE_LAUNCH_CASE); +#undef COMBINE_LAUNCH_CASE +} + +} // namespace internode + +} // namespace deep_ep diff --git a/DeepEP/csrc/kernels/internode_ll.cu b/DeepEP/csrc/kernels/internode_ll.cu new file mode 100644 index 000000000..db15bf5d2 --- /dev/null +++ b/DeepEP/csrc/kernels/internode_ll.cu @@ -0,0 +1,584 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" +#include "ibgda_device.cuh" + +namespace deep_ep { + +namespace internode_ll { + +template __launch_bounds__(kNumThreads, 1) +__global__ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1) { + // Barrier before cleaning (in case of unfinished chunked EP) + nvshmemx_barrier_all_block(); + + // Clean + auto thread_id = static_cast(threadIdx.x); + #pragma unroll + for (int i = thread_id; i < num_clean_int_0; i += kNumThreads) + clean_0[i] = 0; + #pragma unroll + for (int i = thread_id; i < num_clean_int_1; i += kNumThreads) + clean_1[i] = 0; + + // Barrier after cleaning (make sure the low-latency mode works fine) + nvshmemx_barrier_all_block(); +} + +void clean_low_latency_buffer(int* clean_0, int num_clean_int_0, + int* clean_1, int num_clean_int_1, + cudaStream_t stream) { + constexpr int kNumThreads = 256; + + SETUP_LAUNCH_CONFIG(1, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, clean_low_latency_buffer, + clean_0, num_clean_int_0, clean_1, num_clean_int_1); +} + +template +__global__ __launch_bounds__(1024, 1) void +dispatch(void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert, + int* next_clean, int num_next_clean_int, + int num_tokens, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + int num_warp_groups, int num_warps_per_group, + bool round_scale, int phases) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + const auto warp_id = thread_id / 32, lane_id = get_lane_id(); + const auto num_sms = static_cast(gridDim.x); + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_local_experts = num_experts / num_ranks; + const auto warp_group_id = warp_id / num_warps_per_group; + const auto sub_warp_id = warp_id % num_warps_per_group; + const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id; + + // May extract UE8M0 from the scales + using scale_t = std::conditional_t; + using packed_t = std::conditional_t; + EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length"); + + // FP8 staffs + constexpr int kNumPerChannels = 128; + const int num_scales = kHidden / kNumPerChannels; + const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16)); + const size_t hidden_int4 = hidden_bytes / sizeof(int4); + + // Message package: hidden data, FP8 scales, index at source + // NOTES: currently we have 3 reserved int fields for future use + using vec_t = typename std::conditional::type; + const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16))); + const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); + EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); + + // Expert counts + constexpr int kNumMaxWarpGroups = 32; + __shared__ int shared_num_tokens_sent_per_expert[kNumMaxWarpGroups]; + + // Sending phase + if ((phases & LOW_LATENCY_SEND_PHASE) == 0) + goto LOW_LATENCY_DISPATCH_RECV; + + // There are 2 kinds of warps in this part: + // 1. The first-kind warps for FP8 cast and sending top-k tokens + // 2. The last warp for reading `topk_idx` and count for per-expert information + if (warp_id < num_warps - 1) { + constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(nv_bfloat16); + EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); + EP_STATIC_ASSERT(kNumElemsPerRead * 32 % kNumPerChannels == 0, "Invalid vectorization"); + const auto num_threads = (num_warps - 1) * 32; + const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; + + for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { + const auto x_int4 = static_cast(x) + token_idx * hidden_bf16_int4; + const auto rdma_x_src_idx = reinterpret_cast(static_cast(rdma_x) + token_idx * num_bytes_per_msg); + const auto rdma_x_vec = reinterpret_cast(reinterpret_cast(rdma_x_src_idx) + sizeof(int4)); + const auto rdma_x_scales = reinterpret_cast(reinterpret_cast(rdma_x_vec) + hidden_bytes); + + // Overlap top-k index read and source token index writes + auto dst_expert_idx = warp_id < num_topk ? static_cast(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1; + thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0; + + // FP8 cast + #pragma unroll + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { + // Read + auto int4_value = __ldg(x_int4 + i); + + if constexpr (kUseFP8) { + // Calculate local amax + auto bf16_values = reinterpret_cast(&int4_value); + float fp32_values[kNumElemsPerRead]; + float amax = kFP8Margin, scale, scale_inv; + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; ++ j) { + fp32_values[j] = static_cast(bf16_values[j]); + amax = fmaxf(amax, fabsf(fp32_values[j])); + } + + // Reduce amax and scale + EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization"); + amax = half_warp_reduce_max(amax); + calculate_fp8_scales(amax, scale, scale_inv, round_scale); + if (lane_id == 0 or lane_id == 16) + rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv; + + // Cast into send buffer + vec_t int2_value; + auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value); + #pragma unroll + for (int j = 0; j < kNumElemsPerRead; j += 2) { + float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; + fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); + } + rdma_x_vec[i] = int2_value; + } else { + // Reinterpret-cast is for C++14 compatibility + rdma_x_vec[i] = *reinterpret_cast(&int4_value); + } + } + asm volatile("bar.sync 1, %0;" :: "r"(num_threads)); + + // Issue IBGDA sends + if (dst_expert_idx >= 0) { + int slot_idx = lane_id == 0 ? atomicAdd(atomic_counter_per_expert + dst_expert_idx, 1) : 0; + slot_idx = __shfl_sync(0xffffffff, slot_idx, 0); + const auto dst_rank = dst_expert_idx / num_local_experts; + const auto dst_expert_local_idx = dst_expert_idx % num_local_experts; + const auto src_ptr = reinterpret_cast(rdma_x_src_idx); + const auto dst_ptr = reinterpret_cast(rdma_recv_x) + + dst_expert_local_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + slot_idx * num_bytes_per_msg; + const auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); + } else { + // NOTES: only 2 load iterations for 7K hidden with 8 unrolls + const auto* src_int4_ptr = reinterpret_cast(src_ptr); + const auto* dst_int4_ptr = reinterpret_cast(dst_p2p_ptr); + UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + } + + // Increase counter after finishing + __syncwarp(); + lane_id == 0 ? atomic_add_release_global(atomic_finish_counter_per_expert + dst_expert_idx, 1) : 0; + } + } + } else if (warp_id == num_warps - 1) { + EP_DEVICE_ASSERT(num_sms > 1); + if (sm_id == 0) { + // The first SM is also responsible for checking QPs + EP_DEVICE_ASSERT(ibgda_get_state()->num_rc_per_pe >= num_local_experts); + + // The first SM is also responsible for cleaning the next buffer + #pragma unroll + for (int i = lane_id; i < num_next_clean_int; i += 32) + next_clean[i] = 0; + + // Notify before executing `int_p` + __syncwarp(); + #pragma unroll + for (int i = lane_id; i < num_experts; i += 32) + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + } + + // This SM should be responsible for some destination experts, read `topk_idx` for them + int expert_count[kNumMaxWarpGroups] = {0}; + const auto expert_begin_idx = sm_id * num_warp_groups; + const auto expert_end_idx = min(expert_begin_idx + num_warp_groups, num_experts); + + // Per lane count + #pragma unroll 8 + for (int i = lane_id; i < num_tokens * num_topk; i += 32) { + auto idx = static_cast(__ldg(topk_idx + i)); + if (idx >= expert_begin_idx and idx < expert_end_idx) + expert_count[idx - expert_begin_idx] ++; + } + + // Warp reduce + #pragma unroll + for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { + auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); + if (lane_id == 0) { + shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; + atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + } + } + } + __syncthreads(); + + // Issue count sends + if (responsible_expert_idx < num_experts and sub_warp_id == 0 and lane_id == 0) { + const auto dst_rank = responsible_expert_idx / num_local_experts; + const auto dst_expert_local_idx = responsible_expert_idx % num_local_experts; + const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * num_warp_groups]; + + // Wait local sends issued and send expert counts + while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); + auto dst_ptr = reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank); + auto dst_p2p_ptr = nvshmemi_get_p2p_ptr(dst_ptr, rank, dst_rank); + if (dst_p2p_ptr == 0) { + nvshmemi_ibgda_amo_nonfetch_add(reinterpret_cast(dst_ptr), -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); + } else { + st_release_sys_global(reinterpret_cast(dst_p2p_ptr), -num_tokens_sent - 1); + } + + // Clean workspace for next use + atomic_counter_per_expert[responsible_expert_idx] = 0; + atomic_finish_counter_per_expert[responsible_expert_idx] = 0; + + // Clean `packed_recv_count` + if (dst_rank == 0) + packed_recv_count[dst_expert_local_idx] = 0; + } + __syncwarp(); + + // Receiving phase + LOW_LATENCY_DISPATCH_RECV: + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + return; + + // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible + if (phases & LOW_LATENCY_SEND_PHASE) + cg::this_grid().sync(); + + // Receiving and packing + if (responsible_expert_idx < num_experts) { + const auto src_rank = responsible_expert_idx / num_local_experts; + const auto local_expert_idx = responsible_expert_idx % num_local_experts; + const auto rdma_recv_x_uint8 = static_cast(rdma_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_bytes_per_msg + + src_rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg; + const auto recv_x_int4 = static_cast(packed_recv_x) + + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4; + const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank; + const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks; + const auto num_aligned_scales = align(num_scales, sizeof(float) / sizeof(scale_t)); + const auto recv_x_scales = static_cast(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales; + + // Shared between sub-warps in warp groups + __shared__ int shared_num_recv_tokens[kNumMaxWarpGroups], shared_recv_token_begin_idx[kNumMaxWarpGroups]; + + // Wait tokens to arrive + // NOTES: using sub-warp 1 to overlap with sub-warp 0 + int num_recv_tokens, recv_token_begin_idx; + EP_DEVICE_ASSERT(num_warps_per_group > 1 and num_warp_groups < 15); + if (sub_warp_id == 1 and lane_id == 0) { + while ((num_recv_tokens = ld_acquire_sys_global(rdma_recv_count + local_expert_idx * num_ranks + src_rank)) == 0); + num_recv_tokens = -num_recv_tokens - 1; + recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); + shared_num_recv_tokens[warp_group_id] = num_recv_tokens; + shared_recv_token_begin_idx[warp_group_id] = recv_token_begin_idx; + recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); + if (cumulative_local_expert_recv_stats != nullptr) + atomicAdd(cumulative_local_expert_recv_stats + local_expert_idx, num_recv_tokens); + } + asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(num_warps_per_group * 32)); + num_recv_tokens = shared_num_recv_tokens[warp_group_id]; + recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id]; + + // Copy tokens + EP_DEVICE_ASSERT(num_scales <= 64); + for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) { + // Copy source info + const auto src_src_idx = reinterpret_cast(rdma_recv_x_uint8 + i * num_bytes_per_msg); + if (lane_id == 0) + recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx); + __syncwarp(); + + // Copy data + // NOTES: only 2 load iterations for 7K hidden with 7 unrolls + const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); + const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; + UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + + // Copy scales + if constexpr (kUseFP8) { + // Equivalent CuTe layout: + // (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1)) + const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); + const auto num_elems_per_pack = static_cast(sizeof(packed_t) / sizeof(scale_t)); + const auto token_idx = recv_token_begin_idx + i; + const auto token_stride = num_elems_per_pack; + const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack; + if (lane_id < num_scales) { + const auto pack_idx = lane_id / num_elems_per_pack; + const auto elem_idx = lane_id % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id)); + recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + if (lane_id + 32 < num_scales) { + const auto pack_idx = (lane_id + 32) / num_elems_per_pack; + const auto elem_idx = (lane_id + 32) % num_elems_per_pack; + auto scale = extract_required_scale_format(ld_nc_global(src_scales + lane_id + 32)); + recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale; + } + } + } + } +} + +void dispatch(void* packed_recv_x, void* packed_recv_x_scales, + int* packed_recv_src_info, int64_t* packed_recv_layout_range, + int* packed_recv_count, + int* cumulative_local_expert_recv_stats, + void* rdma_recv_x, int* rdma_recv_count, void* rdma_x, + const void* x, const int64_t* topk_idx, + int* next_clean, int num_next_clean_int, + int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + bool use_fp8, bool round_scale, bool use_ue8m0, + void* workspace, int num_device_sms, + cudaStream_t stream, int phases) { + constexpr int kNumMaxTopK = 9; + const int num_warp_groups = ceil_div(num_experts, num_device_sms); + const int num_warps_per_group = 32 / num_warp_groups; + EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0); + EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group); + + const auto num_warps = num_warp_groups * num_warps_per_group; + const auto num_sms = ceil_div(num_experts, num_warp_groups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + + // Workspace checks + auto atomic_counter_per_expert = static_cast(workspace); + auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; + EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); + + // FP8 checks + if (use_ue8m0) + EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`"); + +#define DISPATCH_LAUNCH_CASE(hidden) { \ +auto dispatch_func = dispatch