sglang0.4.5.post1/python/sglang/srt/model_parallel.py

156 lines
5.9 KiB
Python

"""
Common utilities for torch model parallelism.
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
try:
import torch.distributed.tensor as dt
except ImportError:
# torch 2.4 or older
import torch.distributed._tensor as dt
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
def _shard_tensor(
    full_tensor: torch.Tensor,
    device_mesh: DeviceMesh,
    placements: Sequence[dt.Shard],
) -> "dt.DTensor":
    """
    Locally shards a full tensor based on indicated sharding arrangement, and
    returns a DTensor containing the local shard.

    .. warning:: This is a private API that is subject to change. It skips the
        communication otherwise required by `distribute_tensor`. It is only
        applicable to cases where all ranks have the same `full_tensor`. For
        example, in distributed inference all ranks load from the same
        checkpoint. This API will not check for data equality between ranks, it
        is thus user's responsibility to ensure the `full_tensor` is the same
        across ranks.

    Args:
        full_tensor (torch.Tensor): the full tensor to be sharded.
        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
            DTensor. Must have same dimension as the number of placements.
        placements (Sequence[:class:`Shard`]): the placements that
            describes how to place the local tensor on DeviceMesh.

    Returns:
        A :class:`DTensor` object with the shard as its local tensor.

    Examples:
        >>> # xdoctest: +SKIP("need world_size and rank")
        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(0)])
    """
    # Per-dimension extent of this rank's shard and where it starts inside
    # the full tensor.
    shape, offset = dt._utils.compute_local_shape_and_global_offset(
        full_tensor.shape, device_mesh, placements
    )
    # Multidimensional basic indexing must use a tuple: indexing with a list
    # of slices is deprecated in PyTorch, and an empty list (0-dim tensor)
    # would not mean "select everything" the way an empty tuple does.
    slices = tuple(
        slice(cur_offset, cur_offset + cur_shape)
        for cur_shape, cur_offset in zip(shape, offset)
    )
    local_tensor = full_tensor[slices]
    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
class ColwiseParallelSharded(ColwiseParallel):
    """
    ColwiseParallel variant for modules whose local weights were sharded
    ahead of time. It targets the fused wqkv case: wq, wk and wv are each
    sharded during loading and only then fused, so the parameters held by
    the module are already the per-rank shards.
    """

    # Replaces ColwiseParallel._partition_linear_fn.
    def _partition_linear_fn(self, name, module, device_mesh):
        # Column-wise parallelism places weight and bias as Shard(0): Linear
        # computes input * weight^T + bias, so the transposed weight ends up
        # Shard(1). The parameters are wrapped as they are (no re-slicing)
        # because they already hold the local shard.
        for param_name, local_param in module.named_parameters():
            as_dtensor = dt.DTensor.from_local(
                local_param, device_mesh, [dt.Shard(0)]
            )
            module.register_parameter(
                param_name,
                torch.nn.Parameter(as_dtensor, requires_grad=False),
            )
class RowwiseParallelMaybeWait(RowwiseParallel):
    """
    RowwiseParallel variant that waits on the output before the next op runs,
    i.e. it establishes the dependency between the comm stream and the
    compute stream (in the CUDA sense). This works around the current
    interaction between AsyncCollectiveTensor and custom ops, such as
    `class RMSNorm(CustomOp)`.
    """

    def _partition_linear_fn(self, name, module, device_mesh):
        # Row-wise parallelism: weight is Shard(1) and bias is Replicate().
        # nn.Linear computes input * weight^T + bias, so the transposed
        # weight ends up Shard(0).
        sharded_weight = _shard_tensor(module.weight, device_mesh, [dt.Shard(1)])
        module.register_parameter("weight", nn.Parameter(sharded_weight))

        # Nothing more to do when the Linear module has no bias.
        if getattr(module, "bias", None) is None:
            return

        replicated_bias = dt.distribute_tensor(
            module.bias, device_mesh, [dt.Replicate()]
        )
        module.register_parameter("bias", nn.Parameter(replicated_bias))

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        # Let the parent style prepare the output first, then wait on the
        # result before handing it to the caller.
        prepared = RowwiseParallel._prepare_output_fn(
            output_layouts, use_local_output, mod, outputs, device_mesh
        )
        return torch.distributed._functional_collectives.wait_tensor(prepared)
def tensor_parallel(
module: torch.nn.Module,
device_mesh: Optional[DeviceMesh] = None,
):
"""
Tensor parallelize the model across the given device mesh.
Args:
module (`torch.nn.Module`):
The module to tensor parallelize.
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
for child_name, tp_style in tp_plan.items():
submod = mod.get_submodule(child_name)
if tp_style == "Colwise":
parallelize_module(submod, device_mesh, ColwiseParallel())
elif tp_style == "Rowwise":
parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
elif tp_style == "Colwise_Sharded":
parallelize_module(submod, device_mesh, ColwiseParallelSharded())
else:
raise ValueError(f"Unknown TP style {tp_style}")
# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
module.apply(tplize)