# File: mysora/opensora/models/vae/tensor_parallel.py
# (file-viewer metadata: 559 lines, 19 KiB, Python)

from typing import List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.device.device_mesh import DeviceMesh
from colossalai.shardformer.layer._operation import (
gather_forward_split_backward,
reduce_forward,
split_forward_gather_backward,
)
from colossalai.shardformer.layer.parallel_module import ParallelModule
from colossalai.tensor.d_tensor.api import (
distribute_tensor,
is_distributed_tensor,
shard_rowwise,
sharded_tensor_to_existing_param,
)
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from .utils import ChannelChunkConv3d, channel_chunk_conv3d
def shard_channelwise(
    tensor: torch.Tensor, group_or_device_mesh: Optional[Union[ProcessGroup, DeviceMesh]] = None
) -> torch.Tensor:
    """
    Shard the second dim (dim 1) of the given tensor.

    Args:
        tensor (torch.Tensor): The tensor to be sharded.
        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
            If None, the tensor will be sharded with respect to the global process group.
            Defaults to None.

    Returns:
        torch.Tensor: The sharded tensor.

    Raises:
        AssertionError: If a DeviceMesh is given and it is not one-dimensional.
    """
    # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
    if group_or_device_mesh is None:
        group_or_device_mesh = dist.GroupMember.WORLD

    if isinstance(group_or_device_mesh, ProcessGroup):
        device_mesh = DeviceMesh.from_process_group(group_or_device_mesh)
    else:
        assert len(group_or_device_mesh.shape) == 1, "Only 1D DeviceMesh is accepted for channel-wise sharding."
        device_mesh = group_or_device_mesh

    # Partition tensor dim 1 along mesh axis 0; every other dim is left whole.
    sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={1: [0]})
    return distribute_tensor(tensor, device_mesh, sharding_spec)
class Conv3dTPCol(nn.Conv3d):
    """Conv3d with column-wise tensor parallelism. This is only for inference.

    The weight and bias of an existing convolution are sharded with
    ``shard_rowwise`` over ``tp_group`` and written back into the same
    ``Parameter`` objects, so the source module's parameters are modified too.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        tp_group=None,
        gather_output: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
    ) -> None:
        """
        Args:
            in_channels .. padding_mode, device, dtype: Forwarded unchanged to ``nn.Conv3d``.
            tp_group: Process group used for tensor parallelism.
            gather_output (bool): If True, ``forward`` gathers the per-rank outputs along dim 1.
            weight (Parameter): Existing convolution weight. Required.
            bias_ (Parameter, optional): Existing convolution bias. Required when ``bias`` is True.

        Raises:
            AssertionError: If ``weight`` is missing, or ``bias_`` is inconsistent with ``bias``/``weight``.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        self.tp_group = tp_group
        self.gather_output = gather_output
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"
        assert weight is not None, "weight must be provided"

        # Parameters: reuse the caller's Parameter objects and shard them unless already sharded.
        weight.data = weight.data.to(device=device, dtype=dtype)
        self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_rowwise(self.weight.data, self.tp_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if bias:
            # The original check tested the boolean flag `bias`, which can never be None here.
            assert bias_ is not None, "bias_ must be provided when bias is True"
            bias_.data = bias_.data.to(device=device, dtype=dtype)
            self.bias = bias_
            if not is_distributed_tensor(self.bias):
                sharded_bias = shard_rowwise(self.bias.data, self.tp_group)
                sharded_tensor_to_existing_param(sharded_bias, self.bias)
        else:
            self.bias = None

    @staticmethod
    def from_native_module(
        module: nn.Conv3d, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch conv3d layer to a tensor parallelized layer.

        Args:
            module (nn.Conv3d): Source layer; its weight/bias Parameters are reused.
            process_group: A process group, or a list/tuple holding exactly one.
            **kwargs: Extra constructor arguments (e.g. ``gather_output``).
        """
        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        return Conv3dTPCol(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
            tp_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the convolution with the local weight shard; optionally gather the result along dim 1."""
        # NOTE(review): self.padding_mode is not passed on, so a non-"zeros" mode
        # appears to be ignored here -- confirm callers only use "zeros".
        out = channel_chunk_conv3d(
            input,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            ChannelChunkConv3d.CONV3D_NUMEL_LIMIT,
        )
        if not self.gather_output:
            return out
        # assumes dim 1 is the channel dim (batched input) -- TODO confirm
        return gather_forward_split_backward(out, 1, self.tp_group)
class Conv3dTPRow(nn.Conv3d):
    """Conv3d with row-wise tensor parallelism. This is only for inference.

    The weight is sharded along dim 1 with ``shard_channelwise``. Each rank
    convolves with its weight shard, the partial results are combined with
    ``reduce_forward`` over ``tp_group``, and the unsharded bias is added
    afterwards.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        tp_group=None,
        split_input: bool = False,
        split_output: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
    ) -> None:
        """
        Args:
            in_channels .. padding_mode, device, dtype: Forwarded unchanged to ``nn.Conv3d``.
            tp_group: Process group used for tensor parallelism.
            split_input (bool): If True, ``forward`` first splits the input along dim 1.
            split_output (bool): If True, ``forward`` splits the reduced output along dim 1.
            weight (Parameter): Existing convolution weight. Required.
            bias_ (Parameter, optional): Existing convolution bias. Required when ``bias`` is True.

        Raises:
            AssertionError: If ``weight`` is missing, or ``bias_`` is inconsistent with ``bias``/``weight``.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        self.tp_group = tp_group
        self.split_input = split_input
        self.split_output = split_output
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"
        assert weight is not None, "weight must be provided"

        # Parameters: reuse the caller's Parameter and shard it unless already sharded.
        weight.data = weight.data.to(device=device, dtype=dtype)
        self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_channelwise(self.weight.data, self.tp_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if bias:
            # The original check tested the boolean flag `bias`, which can never be None here.
            assert bias_ is not None, "bias_ must be provided when bias is True"
            # The bias is not sharded: it is added once, after the reduction in forward().
            bias_.data = bias_.data.to(device=device, dtype=dtype)
            self.bias = bias_
        else:
            self.bias = None

    @staticmethod
    def from_native_module(
        module: nn.Conv3d, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch conv3d layer to a tensor parallelized layer.

        Args:
            module (nn.Conv3d): Source layer; its weight/bias Parameters are reused.
            process_group: A process group, or a list/tuple holding exactly one.
            **kwargs: Extra constructor arguments (e.g. ``split_input``, ``split_output``).
        """
        # ensure only one process group is passed (same handling as Conv3dTPCol)
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        return Conv3dTPRow(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
            tp_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve with the local weight shard, reduce across ranks, then add the bias."""
        if self.split_input:
            input = split_forward_gather_backward(input, 1, self.tp_group)

        # Bias is withheld from the conv so that it is not included in every rank's partial result.
        out = channel_chunk_conv3d(
            input,
            self.weight,
            None,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            ChannelChunkConv3d.CONV3D_NUMEL_LIMIT,
        )
        out = reduce_forward(out, self.tp_group)

        if self.bias is not None:
            # Reshape (C,) -> (C, 1, 1, 1) so the bias broadcasts over the three trailing dims.
            out = out + self.bias[:, None, None, None]
        if self.split_output:
            out = split_forward_gather_backward(out, 1, self.tp_group)
        return out
class Conv2dTPRow(nn.Conv2d):
    """Conv2d with row-wise tensor parallelism. This is only for inference.

    The weight is sharded along dim 1 with ``shard_channelwise``. Each rank
    convolves with its weight shard, the partial results are summed in place
    with ``dist.all_reduce`` over ``tp_group``, and the unsharded bias is added
    afterwards.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        tp_group=None,
        split_input: bool = False,
        split_output: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
    ) -> None:
        """
        Args:
            in_channels .. padding_mode, device, dtype: Forwarded unchanged to ``nn.Conv2d``.
            tp_group: Process group used for tensor parallelism.
            split_input (bool): If True, ``forward`` first splits the input along dim 1.
            split_output (bool): If True, ``forward`` splits the reduced output along dim 1.
            weight (Parameter): Existing convolution weight. Required.
            bias_ (Parameter, optional): Existing convolution bias. Required when ``bias`` is True.

        Raises:
            AssertionError: If ``weight`` is missing, or ``bias_`` is inconsistent with ``bias``/``weight``.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        self.tp_group = tp_group
        self.split_input = split_input
        self.split_output = split_output
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"
        assert weight is not None, "weight must be provided"

        # Parameters: reuse the caller's Parameter and shard it unless already sharded.
        weight.data = weight.data.to(device=device, dtype=dtype)
        self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_channelwise(self.weight.data, self.tp_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if bias:
            # The original check tested the boolean flag `bias`, which can never be None here.
            assert bias_ is not None, "bias_ must be provided when bias is True"
            # The bias is not sharded: it is added once, after the reduction in forward().
            bias_.data = bias_.data.to(device=device, dtype=dtype)
            self.bias = bias_
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve with the local weight shard, all-reduce across ranks, then add the bias."""
        if self.split_input:
            input = split_forward_gather_backward(input, 1, self.tp_group)

        # Bias is withheld from the conv so that it is not included in every rank's partial result.
        out = F.conv2d(
            input,
            self.weight,
            None,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        dist.all_reduce(out, group=self.tp_group)

        if self.bias is not None:
            # Reshape (C,) -> (C, 1, 1) so the bias broadcasts over the two trailing dims.
            out += self.bias[:, None, None]
        if self.split_output:
            out = split_forward_gather_backward(out, 1, self.tp_group)
        return out

    @staticmethod
    def from_native_module(
        module: nn.Conv2d, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch conv2d layer to a tensor parallelized layer.

        Args:
            module (nn.Conv2d): Source layer; its weight/bias Parameters are reused.
            process_group: A process group, or a list/tuple holding exactly one.
            **kwargs: Extra constructor arguments (e.g. ``split_input``, ``split_output``).
        """
        # ensure only one process group is passed (same handling as Conv3dTPCol)
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        # __init__ already stores module.weight / module.bias, so no re-assignment is needed afterwards.
        return Conv2dTPRow(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
            tp_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )
class Conv1dTPRow(nn.Conv1d):
    """Conv1d with row-wise tensor parallelism. This is only for inference.

    The weight is sharded along dim 1 with ``shard_channelwise``. Each rank
    convolves with its weight shard, the partial results are summed in place
    with ``dist.all_reduce`` over ``tp_group``, and the unsharded bias is added
    afterwards.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        tp_group=None,
        split_input: bool = False,
        split_output: bool = False,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
    ) -> None:
        """
        Args:
            in_channels .. padding_mode, device, dtype: Forwarded unchanged to ``nn.Conv1d``.
            tp_group: Process group used for tensor parallelism.
            split_input (bool): If True, ``forward`` first splits the input along dim 1.
            split_output (bool): If True, ``forward`` splits the reduced output along dim 1.
            weight (Parameter): Existing convolution weight. Required.
            bias_ (Parameter, optional): Existing convolution bias. Required when ``bias`` is True.

        Raises:
            AssertionError: If ``weight`` is missing, or ``bias_`` is inconsistent with ``bias``/``weight``.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype
        )
        self.tp_group = tp_group
        self.split_input = split_input
        self.split_output = split_output
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)

        # sanity check
        if weight is not None:
            assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
        else:
            assert bias_ is None, "bias_ must be None if weight is None"
        assert weight is not None, "weight must be provided"

        # Parameters: reuse the caller's Parameter and shard it unless already sharded.
        weight.data = weight.data.to(device=device, dtype=dtype)
        self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_channelwise(self.weight.data, self.tp_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)

        if bias:
            # The original check tested the boolean flag `bias`, which can never be None here.
            assert bias_ is not None, "bias_ must be provided when bias is True"
            # The bias is not sharded: it is added once, after the reduction in forward().
            bias_.data = bias_.data.to(device=device, dtype=dtype)
            self.bias = bias_
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve with the local weight shard, all-reduce across ranks, then add the bias."""
        if self.split_input:
            input = split_forward_gather_backward(input, 1, self.tp_group)

        # Bias is withheld from the conv so that it is not included in every rank's partial result.
        out = F.conv1d(
            input,
            self.weight,
            None,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        dist.all_reduce(out, group=self.tp_group)

        if self.bias is not None:
            # Reshape (C,) -> (C, 1) so the bias broadcasts over the trailing dim.
            out += self.bias[:, None]
        if self.split_output:
            out = split_forward_gather_backward(out, 1, self.tp_group)
        return out

    @staticmethod
    def from_native_module(
        module: nn.Conv1d, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch conv1d layer to a tensor parallelized layer.

        Args:
            module (nn.Conv1d): Source layer; its weight/bias Parameters are reused.
            process_group: A process group, or a list/tuple holding exactly one.
            **kwargs: Extra constructor arguments (e.g. ``split_input``, ``split_output``).
        """
        # ensure only one process group is passed (same handling as Conv3dTPCol)
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        # __init__ already stores module.weight / module.bias, so no re-assignment is needed afterwards.
        return Conv1dTPRow(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
            tp_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )
class GroupNormTP(nn.GroupNorm):
    """GroupNorm with tensor-parallel sharded affine parameters.

    The affine weight and bias are sharded with ``shard_rowwise`` over
    ``tp_group``, and ``forward`` normalises its input with
    ``num_groups // tp_size`` groups.
    """

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
        device=None,
        dtype=None,
        tp_group=None,
        weight: Optional[Parameter] = None,
        bias: Optional[Parameter] = None,
    ) -> None:
        """
        Args:
            num_groups, num_channels, eps, affine, device, dtype: Forwarded unchanged to ``nn.GroupNorm``.
            tp_group: Process group used for tensor parallelism.
            weight (Parameter, optional): Existing affine weight. Required when ``affine`` is True.
            bias (Parameter, optional): Existing affine bias. Required when ``affine`` is True.

        Raises:
            AssertionError: If ``num_groups`` is not divisible by the tensor parallel size, or if
                ``affine`` is True and ``weight``/``bias`` is missing.
        """
        super().__init__(num_groups, num_channels, eps, affine, device, dtype)
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)
        self.tp_rank = dist.get_rank(tp_group)

        # forward() uses num_groups // tp_size; a remainder would silently change the grouping.
        assert (
            num_groups % self.tp_size == 0
        ), f"num_groups ({num_groups}) must be divisible by the tensor parallel size ({self.tp_size})"

        if affine:
            assert weight is not None, "weight must be provided"
            weight.data = weight.data.to(device=device, dtype=dtype)
            self.weight = weight
            if not is_distributed_tensor(self.weight):
                sharded_weight = shard_rowwise(self.weight.data, self.tp_group)
                sharded_tensor_to_existing_param(sharded_weight, self.weight)

            assert bias is not None, "bias must be provided"
            bias.data = bias.data.to(device=device, dtype=dtype)
            self.bias = bias
            if not is_distributed_tensor(self.bias):
                sharded_bias = shard_rowwise(self.bias.data, self.tp_group)
                sharded_tensor_to_existing_param(sharded_bias, self.bias)
        else:
            self.weight = None
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply group normalisation using ``num_groups // tp_size`` groups and the stored affine parameters."""
        return F.group_norm(
            input,
            self.num_groups // self.tp_size,
            self.weight,
            self.bias,
            self.eps,
        )

    @staticmethod
    def from_native_module(
        module: nn.GroupNorm, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> ParallelModule:
        r"""
        Convert a native PyTorch nn.GroupNorm layer to a tensor parallelized layer.

        Args:
            module (nn.GroupNorm): Source layer; its weight/bias Parameters are reused when affine.
            process_group: A process group, or a list/tuple holding exactly one.
            **kwargs: Extra constructor arguments.
        """
        # ensure only one process group is passed (same handling as Conv3dTPCol)
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        # With affine=False the native module has weight=None, so there is no
        # tensor to read device/dtype from; leave both unset in that case.
        if module.weight is not None:
            device, dtype = module.weight.device, module.weight.dtype
        else:
            device, dtype = None, None

        return GroupNormTP(
            num_groups=module.num_groups,
            num_channels=module.num_channels,
            eps=module.eps,
            affine=module.affine,
            device=device,
            dtype=dtype,
            tp_group=process_group,
            weight=module.weight,
            bias=module.bias,
            **kwargs,
        )