# sglang 0.4.5.post1: python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`; it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# invokes many other CUDA APIs that are not allowed while capturing
# a CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related to NCCL versions and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this: it requires
# recompiling the code every time we want to switch between different
# versions of NCCL. The current implementation, with a **pure** Python
# wrapper, is more flexible. We can easily switch between different versions
# of NCCL by changing the environment variable `SGLANG_NCCL_SO_PATH`, or the
# `so_file` variable in the code.
import ctypes
import logging
import os
import platform
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from torch.distributed import ReduceOp
logger = logging.getLogger(__name__)
def find_nccl_library() -> str:
"""
We either use the library file specified by the `SGLANG_NCCL_SO_PATH`
environment variable, or we find the library file brought by PyTorch.
After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
found by `ctypes` automatically.
"""
# so_file can be set to None in sglang
so_file = os.environ.get("SGLANG_NCCL_SO_PATH", None)
    # the nccl library was manually specified via SGLANG_NCCL_SO_PATH
if so_file:
logger.info(
"Found nccl from environment variable SGLANG_NCCL_SO_PATH=%s", so_file
)
else:
if torch.version.cuda is not None:
so_file = "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug("Found nccl from library %s", so_file)
return so_file
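# Illustrative example (the path below is an assumption; adjust it to wherever
# your NCCL/RCCL shared library actually lives):
#   export SGLANG_NCCL_SO_PATH=/usr/lib/x86_64-linux-gnu/libnccl.so.2
# Without the variable, the loader falls back to the soname shipped with
# PyTorch: "libnccl.so.2" on CUDA builds, "librccl.so.1" on ROCm builds.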
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
class ncclUniqueId(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p
ncclDataType_t = ctypes.c_int
class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
ncclInt32 = 2
ncclInt = 2
ncclUint32 = 3
ncclInt64 = 4
ncclUint64 = 5
ncclFloat16 = 6
ncclHalf = 6
ncclFloat32 = 7
ncclFloat = 7
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
return cls.ncclUint8
if dtype == torch.int32:
return cls.ncclInt32
if dtype == torch.int64:
return cls.ncclInt64
if dtype == torch.float16:
return cls.ncclFloat16
if dtype == torch.float32:
return cls.ncclFloat32
if dtype == torch.float64:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
ncclRedOp_t = ctypes.c_int
class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
ncclMin = 3
ncclAvg = 4
ncclNumOps = 5
@classmethod
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
return cls.ncclProd
if op == ReduceOp.MAX:
return cls.ncclMax
if op == ReduceOp.MIN:
return cls.ncclMin
if op == ReduceOp.AVG:
return cls.ncclAvg
raise ValueError(f"Unsupported op: {op}")
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
class NCCLLibrary:
exported_functions = [
# const char* ncclGetErrorString(ncclResult_t result)
Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]),
# ncclResult_t ncclGetVersion(int *version);
Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]),
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]),
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
Function(
"ncclCommInitRank",
ncclResult_t,
[ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int],
),
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclAllGather",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduceScatter",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
Function(
"ncclSend",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclRecv(
# void* recvbuff, size_t count, ncclDataType_t datatype,
# int src, ncclComm_t comm, cudaStream_t stream);
Function(
"ncclRecv",
ncclResult_t,
[
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclBroadcast(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, int root, ncclComm_t comm,
# cudaStream_t stream);
Function(
"ncclBroadcast",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
        # Be cautious! This is a collective call; it will block until all
        # processes in the communicator have called this function.
        # Because Python object destruction can happen in random order,
        # it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
    # class attribute to store the mapping from library path
    # to the dictionary of ctypes functions loaded from that library
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
so_file = so_file or find_nccl_library()
try:
if so_file not in NCCLLibrary.path_to_dict_mapping:
lib = ctypes.CDLL(so_file)
NCCLLibrary.path_to_library_cache[so_file] = lib
self.lib = NCCLLibrary.path_to_library_cache[so_file]
except Exception as e:
            logger.error(
                "Failed to load NCCL library from %s. "
                "This is expected if you are not running on NVIDIA/AMD GPUs. "
                "Otherwise, the nccl library might not exist, be corrupted, "
                "or not support the current platform %s. "
                "If you already have the library, please set the "
                "environment variable SGLANG_NCCL_SO_PATH "
                "to point to the correct nccl library path.",
                so_file,
                platform.platform(),
            )
raise e
if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
NCCLLibrary.path_to_dict_mapping[so_file] = _funcs
self._funcs = NCCLLibrary.path_to_dict_mapping[so_file]
def ncclGetErrorString(self, result: ncclResult_t) -> str:
return self._funcs["ncclGetErrorString"](result).decode("utf-8")
def NCCL_CHECK(self, result: ncclResult_t) -> None:
if result != 0:
error_str = self.ncclGetErrorString(result)
raise RuntimeError(f"NCCL error: {error_str}")
def ncclGetVersion(self) -> str:
version = ctypes.c_int()
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
version_str = str(version.value)
# something like 21903 --> "2.19.3"
major = version_str[0].lstrip("0")
minor = version_str[1:3].lstrip("0")
patch = version_str[3:].lstrip("0")
return f"{major}.{minor}.{patch}"
def ncclGetUniqueId(self) -> ncclUniqueId:
unique_id = ncclUniqueId()
self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id)))
return unique_id
def ncclCommInitRank(
self, world_size: int, unique_id: ncclUniqueId, rank: int
) -> ncclComm_t:
comm = ncclComm_t()
self.NCCL_CHECK(
self._funcs["ncclCommInitRank"](
ctypes.byref(comm), world_size, unique_id, rank
)
)
return comm
def ncclAllReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclAllReduce"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduceScatter"](
sendbuff, recvbuff, count, datatype, op, comm, stream
)
)
def ncclAllGather(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
        # `datatype` actually should be `ncclDataType_t`,
        # which is an alias of `ctypes.c_int`.
        # When we pass an int to the function, it will be converted to
        # `ctypes.c_int` by ctypes automatically.
self.NCCL_CHECK(
self._funcs["ncclAllGather"](
sendbuff, recvbuff, count, datatype, comm, stream
)
)
def ncclSend(
self,
sendbuff: buffer_type,
count: int,
datatype: int,
dest: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream)
)
def ncclRecv(
self,
recvbuff: buffer_type,
count: int,
datatype: int,
src: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
def ncclBroadcast(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
self.NCCL_CHECK(
self._funcs["ncclBroadcast"](
sendbuff, recvbuff, count, datatype, root, comm, stream
)
)
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
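# Usage sketch (illustrative only; how the 128-byte unique id is shared across
# ranks and how `rank` / `world_size` are obtained are assumptions, not part of
# this module):
#   nccl = NCCLLibrary()
#   unique_id = nccl.ncclGetUniqueId() if rank == 0 else ncclUniqueId()
#   ... broadcast unique_id.internal to all ranks out of band ...
#   comm = nccl.ncclCommInitRank(world_size, unique_id, rank)
#   tensor = torch.ones(1024, device="cuda")
#   nccl.ncclAllReduce(
#       buffer_type(tensor.data_ptr()), buffer_type(tensor.data_ptr()),
#       tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype),
#       ncclRedOpTypeEnum.from_torch(ReduceOp.SUM), comm,
#       cudaStream_t(torch.cuda.current_stream().cuda_stream),
#   )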
__all__ = [
"NCCLLibrary",
"ncclDataTypeEnum",
"ncclRedOpTypeEnum",
"ncclUniqueId",
"ncclComm_t",
"cudaStream_t",
"buffer_type",
]
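# Minimal smoke test: a sketch that only runs when this file is executed
# directly. It assumes a CUDA or ROCm build of PyTorch and a loadable
# libnccl.so.2 / librccl.so.1 on the system. Collective calls are not
# exercised here, since they additionally require a communicator initialized
# across ranks via ncclGetUniqueId + ncclCommInitRank.
if __name__ == "__main__":
    nccl = NCCLLibrary()
    print("NCCL version:", nccl.ncclGetVersion())
    unique_id = nccl.ncclGetUniqueId()
    print("ncclUniqueId size in bytes:", len(bytes(unique_id.internal)))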