"""For Now, MSCCL is only supported on TP16 and TP8 case export WORLD_SIZE=1 export RANK=0 export MASTER_ADDR=127.0.0.1 export MASTER_PORT=12345 torchrun --nproc_per_node gpu \ --nnodes $WORLD_SIZE \ --node_rank $RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT benchmark/kernels/all_reduce/benchmark_mscclpp.py """ import os from contextlib import nullcontext from typing import List import torch import torch.distributed as dist from torch.distributed import ProcessGroup from sglang.srt.distributed import init_distributed_environment from sglang.srt.distributed.device_communicators.pymscclpp import PyMscclppCommunicator from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator from sglang.srt.distributed.parallel_state import ( get_tensor_model_parallel_group, graph_capture, initialize_model_parallel, set_mscclpp_all_reduce, ) def torch_allreduce(torch_input: torch.Tensor, group: ProcessGroup) -> torch.Tensor: dist.all_reduce(torch_input, group=group) return torch_input def msccl_allreduce( msccl_input: torch.Tensor, msccl_comm: PyMscclppCommunicator ) -> torch.Tensor: return msccl_comm.all_reduce(msccl_input) def pynccl_allreduce( msccl_input: torch.Tensor, pynccl_comm: PyNcclCommunicator ) -> torch.Tensor: pynccl_comm.all_reduce(msccl_input) return msccl_input def _bench_graph_time(func, inp_randn, warmup_loop=2, graph_loop=10, test_loop=10): graph_input = inp_randn.clone() with graph_capture() as graph_capture_context: graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, stream=graph_capture_context.stream): for _ in range(graph_loop): graph_out = func(graph_input) graph.replay() func_output = graph_out.clone() for _ in range(warmup_loop): graph.replay() torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) latencies: List[float] = [] for _ in range(test_loop): torch.cuda.synchronize() dist.barrier() start_event.record() graph.replay() end_event.record() end_event.synchronize() latencies.append(start_event.elapsed_time(end_event)) func_cost_us = sum(latencies) / len(latencies) / graph_loop * 1000 graph.reset() return func_output, func_cost_us def _bench_eager_time(func, inp_randn, warmup_loop=2, test_loop=10): eager_input = inp_randn.clone() eager_output = func(eager_input) func_output = eager_output.clone() for _ in range(warmup_loop): func(eager_input) torch.cuda.synchronize() start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) torch.cuda.synchronize() start_event.record() for _ in range(test_loop): func(eager_input) end_event.record() torch.cuda.synchronize() func_cost_us = start_event.elapsed_time(end_event) / test_loop * 1000 return func_output, func_cost_us def get_torch_prof_ctx(do_prof: bool): ctx = ( torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], record_shapes=True, with_stack=True, ) if do_prof else nullcontext() ) return ctx def human_readable_size(size, decimal_places=1): for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]: if size < 1024.0 or unit == "PiB": break size /= 1024.0 return f"{size:.{decimal_places}f} {unit}" try: from tabulate import tabulate except ImportError: print("tabulate not installed, skipping table printing") tabulate = None def print_markdown_table(data): if tabulate is not None: print(tabulate(data, headers="keys", tablefmt="github")) return headers = data[0].keys() header_row = "| " + " | ".join(headers) + " |" separator = "| " + " | ".join(["---"] * len(headers)) + " |" rows = [] for item in data: row = "| " + " | ".join(str(item[key]) for key in headers) + " |" rows.append(row) markdown_table = "\n".join([header_row, separator] + rows) print(markdown_table) if __name__ == "__main__": import logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", force=True, ) if not dist.is_initialized(): dist.init_process_group(backend="nccl") world, world_size = dist.group.WORLD, dist.get_world_size() rank = dist.get_rank() torch.cuda.set_device(rank % 8) device = torch.cuda.current_device() set_mscclpp_all_reduce(True) init_distributed_environment( world_size=world_size, rank=rank, local_rank=rank % 8, ) initialize_model_parallel(tensor_model_parallel_size=world_size) group = get_tensor_model_parallel_group().device_group cpu_group = get_tensor_model_parallel_group().cpu_group pynccl_comm = get_tensor_model_parallel_group().pynccl_comm pymscclpp_comm = get_tensor_model_parallel_group().pymscclpp_comm dist.barrier() profile = False dtype = torch.bfloat16 ctx = get_torch_prof_ctx(profile) result = [] with ctx: for i in range(10, 20): sz = 2**i if sz * dtype.itemsize > 2**20: break inp_randn = torch.randint(1, 16, (sz,), dtype=dtype, device=device) memory = torch.empty_like(inp_randn) memory_out = torch.empty_like(memory) torch_eager_output, torch_eager_time = _bench_eager_time( lambda inp: torch_allreduce(inp, group), inp_randn ) msccl_eager_output, msccl_eager_time = _bench_eager_time( lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn ) msccl_graph_output, msccl_graph_time = _bench_graph_time( lambda inp: msccl_allreduce(inp, pymscclpp_comm), inp_randn ) # since pynccl is inplace op, this return result is not correct if graph loop > 1 _, pynccl_graph_time = _bench_graph_time( lambda inp: pynccl_allreduce(inp, pynccl_comm), inp_randn ) torch.testing.assert_close(torch_eager_output, msccl_graph_output) torch.testing.assert_close(torch_eager_output, msccl_eager_output) result.append( { "msg_size": human_readable_size(inp_randn.nbytes), "torch eager time": torch_eager_time, "msccl eager time": msccl_eager_time, "msccl graph time": msccl_graph_time, "pynccl graph time": pynccl_graph_time, } ) if rank == 0: print(f"sz={sz}, dtype={dtype}: correctness check PASS!") if rank == 0: print_markdown_table(result) if profile: prof_dir = f"prof/msccl" os.makedirs(prof_dir, exist_ok=True) ctx.export_chrome_trace(f"{prof_dir}/trace_rank{dist.get_rank()}.json.gz")