from functools import partial
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.shardformer.layer import (FusedLinear1D_Col, FusedLinear1D_Row,
                                          Linear1D_Col, Linear1D_Row)
from colossalai.shardformer.layer._operation import all_to_all_comm
from colossalai.shardformer.layer.attn import RingComm, _rescale_out_lse
from colossalai.shardformer.layer.utils import is_share_sp_tp
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription, Policy, SubModuleReplacementDescription)
from colossalai.shardformer.shard import ShardConfig
from einops import rearrange
from flash_attn.flash_attn_interface import (_flash_attn_backward,
                                             _flash_attn_forward)
from liger_kernel.ops.rope import LigerRopeFunction

try:
    from flash_attn_interface import \
        _flash_attn_backward as _flash_attn_backward_v3
    from flash_attn_interface import \
        _flash_attn_forward as _flash_attn_forward_v3

    SUPPORT_FA3 = True
except ImportError:
    SUPPORT_FA3 = False

from torch import Tensor

from opensora.acceleration.checkpoint import auto_grad_checkpoint

from .layers import DoubleStreamBlock, SingleStreamBlock
from .math import apply_rope, attention
from .model import MMDiTModel


class _SplitForwardGatherBackwardVarLen(torch.autograd.Function):
    """
    Split the input along ``dim`` and keep only the chunk belonging to this rank;
    the backward pass gathers the gradients from all ranks.

    Args:
        input_ (`torch.Tensor`): input matrix.
        dim (int): the dimension to perform split and gather
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
        splits (List[int]): chunk size along ``dim`` for each rank
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group, splits: List[int]):
        ctx.process_group = process_group
        ctx.dim = dim
        rank = dist.get_rank(process_group)
        ctx.grad_scale = splits[rank] / sum(splits)
        ctx.splits = splits
        return torch.split(input_, splits, dim=dim)[rank].clone()

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output * ctx.grad_scale
        grad_output = grad_output.contiguous()
        world_size = dist.get_world_size(ctx.process_group)
        shapes = [list(grad_output.shape) for _ in range(world_size)]
        for i, shape in enumerate(shapes):
            shape[ctx.dim] = ctx.splits[i]
        tensor_list = [torch.empty(shape, dtype=grad_output.dtype, device=grad_output.device) for shape in shapes]
        dist.all_gather(tensor_list, grad_output, group=ctx.process_group)
        return torch.cat(tensor_list, dim=ctx.dim), None, None, None


def split_forward_gather_backward_var_len(input_, dim, process_group, splits: List[int]):
    return _SplitForwardGatherBackwardVarLen.apply(input_, dim, process_group, splits)


class _GatherForwardSplitBackwardVarLen(torch.autograd.Function):
    """
    Gather variable-length chunks from all ranks along ``dim``; the backward pass
    splits the gradient and keeps only the chunk belonging to this rank.

    Args:
        input_ (`torch.Tensor`): input matrix.
        dim (int): the dimension to perform gather and split
        process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
        splits (List[int]): chunk size along ``dim`` for each rank
    """

    @staticmethod
    def forward(ctx, input_, dim, process_group, splits: List[int]):
        input_ = input_.contiguous()
        ctx.process_group = process_group
        ctx.dim = dim
        rank = dist.get_rank(process_group)

        ctx.grad_scale = sum(splits) / splits[rank]
        ctx.splits = splits
        world_size = dist.get_world_size(ctx.process_group)
        shapes = [list(input_.shape) for _ in range(world_size)]
        for i, shape in enumerate(shapes):
            shape[dim] = splits[i]
        tensor_list = [torch.empty(shape, dtype=input_.dtype, device=input_.device) for shape in shapes]
        dist.all_gather(tensor_list, input_, group=ctx.process_group)
        return torch.cat(tensor_list, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output * ctx.grad_scale
        rank = dist.get_rank(ctx.process_group)
        return torch.split(grad_output, ctx.splits, dim=ctx.dim)[rank].clone(), None, None, None


def gather_forward_split_backward_var_len(input_, dim, process_group, splits: List[int]):
    return _GatherForwardSplitBackwardVarLen.apply(input_, dim, process_group, splits)


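# Usage sketch for the variable-length split/gather pair (illustrative only;
# assumes an initialized 2-rank process group ``sp_group`` and uneven per-rank
# splits [3, 5] along dim 1):
#
#   local = split_forward_gather_backward_var_len(x, 1, sp_group, [3, 5])
#   local = some_block(local)  # per-rank compute on this rank's chunk
#   full = gather_forward_split_backward_var_len(local, 1, sp_group, [3, 5])
#
# The two autograd functions are inverses of each other, so gradients flow back
# through the gather/split pair with the per-rank rescaling applied above.

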
def _fa_forward(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0, softmax_scale: Optional[float] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if SUPPORT_FA3:
        out, softmax_lse, *_ = _flash_attn_forward_v3(
            q,
            k,
            v,
            None,
            None,
            None,
            None,  # k_new, v_new, qv, out
            None,
            None,
            None,  # cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new
            None,
            None,
            None,
            None,  # seqused_q, seqused_k, max_seqlen_q, max_seqlen_k
            None,
            None,
            None,  # page_table, kv_batch_idx, leftpad_k
            None,
            None,  # rotary_cos/sin
            None,
            None,
            None,  # q_descale, k_descale, v_descale
            softmax_scale,
            False,  # causal
            (-1, -1),
        )
        rng_state = None
    else:
        out, softmax_lse, _, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=False,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
    return out, softmax_lse, rng_state


def _fa_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: torch.Tensor,
    dk: torch.Tensor,
    dv: torch.Tensor,
    rng_state: torch.Tensor,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    deterministic: bool = False,
) -> None:
    if SUPPORT_FA3:
        _flash_attn_backward_v3(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None, None, None, None, None,  # cu_seqlens_q/k, seqused_q/k, max_seqlen_q/k
            dq,
            dk,
            dv,
            softmax_scale,
            False,  # causal
            (-1, -1),
            deterministic=deterministic,
        )
    else:
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=False,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            alibi_slopes=None,
            deterministic=deterministic,
            rng_state=rng_state,
        )


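# Ring attention lets each rank hold only its local KV shard: KV blocks circulate
# around the ring while every rank attends its local queries to each block in
# turn, merging partial results with the online-softmax correction. With running
# (out, lse) and a new block's (block_out, block_lse), the update is (a minimal
# sketch of what ColossalAI's ``_rescale_out_lse`` computes; the exact kernel may
# differ in layout):
#
#   new_lse = lse + log(1 + exp(block_lse - lse))
#   new_out = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out

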
class RingAttention(torch.autograd.Function):
    ATTN_DONE: torch.cuda.Event = None
    SP_STREAM: torch.cuda.Stream = None

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sp_group: dist.ProcessGroup,
        sp_stream: torch.cuda.Stream,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        deterministic: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Ring attention forward

        Args:
            ctx: autograd function context
            q (torch.Tensor): shape [B, S, N, D]
            k (torch.Tensor): shape [B, S, N, D]
            v (torch.Tensor): shape [B, S, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            sp_stream (torch.cuda.Stream): sequence parallel stream
            dropout_p (float, optional): dropout prob. Defaults to 0.0.
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            deterministic (Optional[bool], optional): backward deterministic mode. Defaults to False.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: output and log sum exp. Output's shape should be [B, S, N, D]. LSE's shape should be [B, N, S].
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        sp_size = dist.get_world_size(sp_group)
        kv_comms: List[RingComm] = [RingComm(sp_group) for _ in range(2)]

        # [B, S, N, D]
        q, k, v = [x.contiguous() for x in [q, k, v]]
        # Pre-allocate double buffer for overlapping and receiving next step's inputs
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))
        # outputs
        out = None
        block_out = [None, None]
        softmax_lse = [None, None]
        block_softmax_lse = [None, None]  # log sum exp, the denominator of softmax in attention
        rng_states = [None for _ in range(sp_size)]
        sp_streams = [torch.cuda.current_stream(), sp_stream]

        def _kv_comm(i):
            # Avoid overwriting attn input when it shares mem with buffer
            if not RingAttention.ATTN_DONE.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < sp_size - 1:
                kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

        for i in range(sp_size):
            with torch.cuda.stream(sp_streams[i % 2]):
                # Wait for current kv from prev rank
                # NOTE: waiting outside the current stream will NOT correctly synchronize.
                if i == 0:
                    _kv_comm(i)
                else:
                    kv_comms[(i + 1) % 2].wait()
                kv_block = kv_buffers[i % 2]
                q_block = q
                block_out[i % 2], block_softmax_lse[i % 2], rng_states[i] = _fa_forward(
                    q_block, kv_block[0], kv_block[1], dropout_p, softmax_scale
                )
                RingAttention.ATTN_DONE.record()
                # Pipeline the next KV comm with output correction instead of the next flash attn
                # to minimize idle time when comm takes longer than attn.
                _kv_comm(i + 1)
                block_softmax_lse[i % 2] = (
                    block_softmax_lse[i % 2].transpose(1, 2).unsqueeze(-1).contiguous().float()
                )  # [B, N, S] -> [B, S, N, 1]
                assert block_out[i % 2].shape[:-1] == block_softmax_lse[i % 2].shape[:-1]
                # Output and log sum exp correction. Ideally overlap this with the next flash attn kernel.
                # In reality this always finishes before next flash attn; no need for extra sync.
                if i == 0:
                    out = block_out[0]
                    softmax_lse = block_softmax_lse[0]
                else:
                    out, softmax_lse = _rescale_out_lse(out, block_out[i % 2], softmax_lse, block_softmax_lse[i % 2])
        torch.cuda.current_stream().wait_stream(sp_stream)
        out = out.to(q.dtype)
        softmax_lse = softmax_lse.squeeze(-1).transpose(1, 2).contiguous()

        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.deterministic = deterministic
        ctx.sp_group = sp_group
        ctx.save_for_backward(q, k, v, out, softmax_lse, *rng_states)  # lse [B, N, S]
        return out, softmax_lse

    @staticmethod
    def backward(ctx, grad_output, grad_softmax_lse):
        # q, k, v, out: [B, S, N, D], softmax_lse: [B, N, S]
        q, k, v, out, softmax_lse, *rng_states = ctx.saved_tensors

        sp_group = ctx.sp_group
        sp_size = dist.get_world_size(sp_group)
        kv_comm = RingComm(sp_group)
        dkv_comm = RingComm(sp_group)

        grad_output = grad_output.contiguous()
        kv_buffers = [torch.stack((k, v))]  # (2, B, S, N, D)
        kv_buffers.append(torch.empty_like(kv_buffers[0]))
        dq = None
        dq_block = torch.empty_like(q)
        dk_block = torch.empty_like(k)
        dv_block = torch.empty_like(v)
        dkv_buffers = [torch.empty_like(kv, dtype=torch.float) for kv in kv_buffers]
        del k, v

        for i in range(sp_size):
            if i > 0:
                kv_comm.wait()
            if i < sp_size - 1:
                kv_comm.send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

            k_block, v_block = kv_buffers[i % 2]
            _fa_backward(
                grad_output,
                q,
                k_block,
                v_block,
                out,
                softmax_lse,
                dq_block,
                dk_block,
                dv_block,
                rng_states[i],
                dropout_p=ctx.dropout_p,
                softmax_scale=ctx.softmax_scale,
                deterministic=ctx.deterministic,
            )

            if i == 0:
                dq = dq_block.float()
                dkv_buffers[i % 2][0] = dk_block.float()
                dkv_buffers[i % 2][1] = dv_block.float()
            else:
                dq += dq_block
                dkv_comm.wait()
                dkv_buffers[i % 2][0] += dk_block
                dkv_buffers[i % 2][1] += dv_block
            dkv_comm.send_recv(dkv_buffers[i % 2], dkv_buffers[(i + 1) % 2])
        dkv_comm.wait()
        dkv = dkv_buffers[sp_size % 2]

        dq, dk, dv = [x.to(q.dtype) for x in (dq, *dkv)]

        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None

    @staticmethod
    def attention(
        q,
        k,
        v,
        sp_group,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        """Ring attention

        Args:
            q (torch.Tensor): shape [B, S, N, D]
            k (torch.Tensor): shape [B, S, N, D]
            v (torch.Tensor): shape [B, S, N, D]
            sp_group (dist.ProcessGroup): sequence parallel group
            dropout_p (float, optional): dropout prob. Defaults to 0.0.
            softmax_scale (Optional[float], optional): softmax scale. Defaults to None.
            deterministic (bool, optional): backward deterministic mode. Defaults to False.
            return_softmax (bool, optional): whether to also return the log sum exp. Defaults to False.

        Returns:
            output of shape [B, S, N, D], plus the log sum exp of shape [B, N, S] when ``return_softmax`` is True.
        """
        if RingAttention.ATTN_DONE is None:
            RingAttention.ATTN_DONE = torch.cuda.Event()
        if RingAttention.SP_STREAM is None:
            RingAttention.SP_STREAM = torch.cuda.Stream()
        out, softmax_lse = RingAttention.apply(
            q, k, v, sp_group, RingAttention.SP_STREAM, dropout_p, softmax_scale, deterministic
        )
        if return_softmax:
            return out, softmax_lse
        return out


def ring_attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, sp_group: dist.ProcessGroup) -> Tensor:
    if isinstance(pe, torch.Tensor):
        q, k = apply_rope(q, k, pe)
    else:
        cos, sin = pe
        q, k = LigerRopeFunction.apply(q, k, cos, sin)
    q, k, v = [x.transpose(1, 2) for x in (q, k, v)]  # [B, H, L, D] -> [B, L, H, D]
    x = RingAttention.attention(q, k, v, sp_group)
    x = rearrange(x, "B L H D -> B L (H D)")
    return x


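# Shape flow for the "all_to_all" sequence-parallel path used by the processors
# below (a sketch; the shapes are assumptions read off the scatter/gather dims
# passed to ``all_to_all_comm``, with sp = sequence_parallel_size):
#
#   q, k, v:  [B, H, S/sp, D]  --a2a(scatter=1, gather=2)-->  [B, H/sp, S, D]
#   attn out: [B, S, (H/sp)*D] --a2a(scatter=1, gather=2)-->  [B, S/sp, H*D]
#
# i.e. heads are scattered so each rank attends over the full sequence, then the
# sequence is re-scattered and the heads gathered back into the hidden dim.

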
class DistributedDoubleStreamBlockProcessor:
    def __init__(self, shard_config: ShardConfig) -> None:
        self.shard_config = shard_config

    def __call__(
        self, attn: DoubleStreamBlock, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = attn.img_mod(vec)
        txt_mod1, txt_mod2 = attn.txt_mod(vec)

        # prepare image for attention
        img_modulated = attn.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        if attn.img_attn.fused_qkv:
            img_qkv = attn.img_attn.qkv(img_modulated)
            img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        else:
            img_q = rearrange(attn.img_attn.q_proj(img_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
            img_k = rearrange(attn.img_attn.k_proj(img_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
            img_v = rearrange(attn.img_attn.v_proj(img_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
        img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
        if not attn.img_attn.fused_qkv:
            img_q = rearrange(img_q, "B L H D -> B H L D")
            img_k = rearrange(img_k, "B L H D -> B H L D")
            img_v = rearrange(img_v, "B L H D -> B H L D")

        # prepare txt for attention
        txt_modulated = attn.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        if attn.txt_attn.fused_qkv:
            txt_qkv = attn.txt_attn.qkv(txt_modulated)
            txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
        else:
            txt_q = rearrange(attn.txt_attn.q_proj(txt_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
            txt_k = rearrange(attn.txt_attn.k_proj(txt_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
            txt_v = rearrange(attn.txt_attn.v_proj(txt_modulated), "B L (H D) -> B L H D", H=attn.num_heads)
        txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
        if not attn.txt_attn.fused_qkv:
            txt_q = rearrange(txt_q, "B L H D -> B H L D")
            txt_k = rearrange(txt_k, "B L H D -> B H L D")
            txt_v = rearrange(txt_v, "B L H D -> B H L D")

        txt_len = txt_q.size(2)
        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        if (
            self.shard_config.enable_sequence_parallelism
            and self.shard_config.sequence_parallelism_mode == "all_to_all"
        ):
            assert (
                attn.num_heads % self.shard_config.sequence_parallel_size == 0
            ), f"Expected num heads({attn.num_heads}) % sp size({self.shard_config.sequence_parallel_size}) == 0"
            # TODO: overlap the communication with computation
            # scatter heads, gather sequence: each rank attends over the full sequence
            q = all_to_all_comm(q, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)
            k = all_to_all_comm(k, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)
            v = all_to_all_comm(v, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)

        if self.shard_config.enable_sequence_parallelism and self.shard_config.sequence_parallelism_mode == "ring_attn":
            attn1 = ring_attention(q, k, v, pe, self.shard_config.sequence_parallel_process_group)
        else:
            attn1 = attention(q, k, v, pe=pe)
        if (
            self.shard_config.enable_sequence_parallelism
            and self.shard_config.sequence_parallelism_mode == "all_to_all"
        ):
            # scatter sequence, gather heads back into the hidden dim
            attn1 = all_to_all_comm(
                attn1, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2
            )
        txt_attn, img_attn = attn1[:, :txt_len], attn1[:, txt_len:]

        # calculate the img blocks
        img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
        img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
        return img, txt


class DistributedSingleStreamBlockProcessor:
    def __init__(self, shard_config: ShardConfig) -> None:
        self.shard_config = shard_config

    def __call__(self, attn: SingleStreamBlock, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = attn.modulation(vec)
        x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift

        if attn.fused_qkv:
            qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
            q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
        else:
            q = rearrange(attn.q_proj(x_mod), "B L (H D) -> B L H D", H=attn.num_heads)
            k = rearrange(attn.k_proj(x_mod), "B L (H D) -> B L H D", H=attn.num_heads)
            v, mlp = torch.split(attn.v_mlp(x_mod), [attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
            v = rearrange(v, "B L (H D) -> B L H D", H=attn.num_heads)
        q, k = attn.norm(q, k, v)
        if not attn.fused_qkv:
            q = rearrange(q, "B L H D -> B H L D")
            k = rearrange(k, "B L H D -> B H L D")
            v = rearrange(v, "B L H D -> B H L D")

        if (
            self.shard_config.enable_sequence_parallelism
            and self.shard_config.sequence_parallelism_mode == "all_to_all"
        ):
            assert (
                attn.num_heads % self.shard_config.sequence_parallel_size == 0
            ), f"Expected num heads({attn.num_heads}) % sp size({self.shard_config.sequence_parallel_size}) == 0"
            q = all_to_all_comm(q, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)
            k = all_to_all_comm(k, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)
            v = all_to_all_comm(v, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2)

        # compute attention
        if self.shard_config.enable_sequence_parallelism and self.shard_config.sequence_parallelism_mode == "ring_attn":
            attn_1 = ring_attention(q, k, v, pe, self.shard_config.sequence_parallel_process_group)
        else:
            attn_1 = attention(q, k, v, pe=pe)

        if (
            self.shard_config.enable_sequence_parallelism
            and self.shard_config.sequence_parallelism_mode == "all_to_all"
        ):
            attn_1 = all_to_all_comm(
                attn_1, self.shard_config.sequence_parallel_process_group, scatter_dim=1, gather_dim=2
            )

        # compute activation in mlp stream, cat again and run second linear layer
        output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
        output = x + mod.gate * output
        return output


class _TempSwitchCP(torch.autograd.Function):
    """Temporarily toggle ``enable_sequence_parallelism``: the flag is set when the
    forward pass crosses this op and restored when the backward pass crosses it."""

    @staticmethod
    def forward(ctx, input_, shard_config: ShardConfig, value: bool):
        ctx.old_value = shard_config.enable_sequence_parallelism
        ctx.shard_config = shard_config
        shard_config.enable_sequence_parallelism = value
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        ctx.shard_config.enable_sequence_parallelism = ctx.old_value
        return grad_output, None, None


def switch_sequence_parallelism(input_, shard_config: ShardConfig, value: bool):
    return _TempSwitchCP.apply(input_, shard_config, value)


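# Sketch of how the flag flip threads through autograd (illustrative only):
#
#   img = switch_sequence_parallelism(img, shard_config, False)  # SP off from here (fwd)
#   img = blocks(img)                                            # runs without SP
#   # on backward, the original flag is restored once grads flow back past the switch
#
# Because the toggle lives inside an autograd.Function, the disabled-SP region is
# consistent between the forward and backward passes.

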
def mmdit_model_forward(
    self: MMDiTModel,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y_vec: Tensor,
    cond: Optional[Tensor] = None,
    guidance: Tensor | None = None,
    shard_config: Optional[ShardConfig] = None,
    stage_index: Optional[List[int]] = None,
    internal_img: Optional[Tensor] = None,
    internal_txt: Optional[Tensor] = None,
    internal_pe: Optional[Tensor] = None,
    internal_vec: Optional[Tensor] = None,
    **kwargs,
):
    txt_len = txt.shape[1]
    if shard_config.pipeline_stage_manager is None or shard_config.pipeline_stage_manager.is_first_stage():
        img, txt, vec, pe = self.prepare_block_inputs(img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance)
        has_grad = img.grad_fn is not None
        old_sequence_parallelism = shard_config.enable_sequence_parallelism
        if shard_config.enable_sequence_parallelism:
            assert (
                txt.shape[1] + img.shape[1]
            ) % shard_config.sequence_parallel_size == 0, (
                f"Expected {txt.shape[1] + img.shape[1]} % {shard_config.sequence_parallel_size} == 0"
            )
            mask = torch.zeros(txt.shape[1] + img.shape[1], dtype=bool)
            mask[txt.shape[1] :] = 1
            mask_chunks = mask.chunk(shard_config.sequence_parallel_size)
            cur_mask = mask_chunks[dist.get_rank(shard_config.sequence_parallel_process_group)]
            txt_splits = [len(c) - c.sum().item() for c in mask_chunks]
            img_splits = [c.sum().item() for c in mask_chunks]
            if 0 in img_splits:
                # temporarily disable sequence parallelism to avoid hanging in the
                # variable-length collectives when a rank would get an empty chunk
                img = switch_sequence_parallelism(img, shard_config, False)
            else:
                img = split_forward_gather_backward_var_len(
                    img, 1, shard_config.sequence_parallel_process_group, img_splits
                )
                txt = split_forward_gather_backward_var_len(
                    txt, 1, shard_config.sequence_parallel_process_group, txt_splits
                )
                if shard_config.sequence_parallelism_mode == "ring_attn":
                    # pe does not require grad
                    sp_rank = dist.get_rank(shard_config.sequence_parallel_process_group)
                    if isinstance(pe, torch.Tensor):
                        pe = pe.chunk(shard_config.sequence_parallel_size, dim=2)[sp_rank].clone()
                    else:
                        cos, sin = pe
                        cos = cos.chunk(shard_config.sequence_parallel_size, dim=1)[sp_rank].clone()
                        sin = sin.chunk(shard_config.sequence_parallel_size, dim=1)[sp_rank].clone()
                        pe = (cos, sin)
    else:
        img, txt, vec, pe = internal_img, internal_txt, internal_vec, internal_pe

    double_start, double_end = 0, len(self.double_blocks)
    if shard_config.pipeline_stage_manager is not None:
        double_start = stage_index[0]
        double_end = min(stage_index[1], len(self.double_blocks))

    for block in self.double_blocks[double_start:double_end]:
        img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)

    if shard_config.pipeline_stage_manager is not None and stage_index[1] <= len(self.double_blocks):
        return {
            "internal_img": img,
            "internal_txt": txt,
            "internal_pe": pe,
            "internal_vec": vec,
        }
    single_start, single_end = 0, len(self.single_blocks)
    if shard_config.pipeline_stage_manager is not None:
        single_start = max(stage_index[0] - len(self.double_blocks), 0)
        single_end = stage_index[1] - len(self.double_blocks)

    if single_start == 0:
        img = torch.cat((txt, img), 1)

    for block in self.single_blocks[single_start:single_end]:
        img = auto_grad_checkpoint(block, img, vec, pe)

    if shard_config.pipeline_stage_manager is not None and single_end < len(self.single_blocks):
        return {
            "internal_img": img,
            "internal_pe": pe,
            "internal_vec": vec,
        }

    if shard_config.enable_sequence_parallelism:
        img = img[:, cur_mask]
    else:
        img = img[:, txt_len:]

    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

    if shard_config.enable_sequence_parallelism:
        img = gather_forward_split_backward_var_len(img, 1, shard_config.sequence_parallel_process_group, img_splits)

    if not has_grad:
        shard_config.enable_sequence_parallelism = old_sequence_parallelism
    return img


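# Layout sketch for the variable-length split above (assumed sizes): with
# txt_len = 5, img_len = 11 and sp = 2, the concatenated [txt | img] sequence of
# length 16 is chunked evenly into two chunks of 8, giving txt_splits = [5, 0]
# and img_splits = [3, 8]; ``cur_mask`` marks which positions of this rank's
# chunk are image tokens, so the final-layer input can be sliced out per rank.

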
class MMDiTPolicy(Policy):
    def config_sanity_check(self):
        if self.shard_config.enable_sequence_parallelism and is_share_sp_tp(
            self.shard_config.sequence_parallelism_mode
        ):
            assert self.shard_config.enable_tensor_parallelism, "Tensor parallelism should be enabled"

    def preprocess(self) -> nn.Module:
        return self.model

    def postprocess(self) -> nn.Module:
        return self.model

    def tie_weight_check(self) -> bool:
        return False

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        policy = {
            DoubleStreamBlock: ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[]),
            SingleStreamBlock: ModulePolicyDescription(attribute_replacement={}, sub_module_replacement=[]),
        }

        if self.shard_config.enable_sequence_parallelism:
            if not is_share_sp_tp(self.shard_config.sequence_parallelism_mode):
                policy[DoubleStreamBlock].attribute_replacement["processor"] = DistributedDoubleStreamBlockProcessor(
                    self.shard_config
                )
                policy[SingleStreamBlock].attribute_replacement["processor"] = DistributedSingleStreamBlockProcessor(
                    self.shard_config
                )
        if self.shard_config.enable_sequence_parallelism or self.shard_config.pipeline_stage_manager is not None:
            fwd_fn = partial(mmdit_model_forward, shard_config=self.shard_config)
            if self.shard_config.pipeline_stage_manager is not None:
                layers_per_stage = self.shard_config.pipeline_stage_manager.distribute_layers(
                    len(self.model.double_blocks) + len(self.model.single_blocks)
                )
                if self.shard_config.pipeline_stage_manager.is_interleave:
                    self.shard_config.pipeline_stage_manager.stage_indices = (
                        self.shard_config.pipeline_stage_manager.get_stage_index(layers_per_stage)
                    )
                else:
                    stage_index = self.shard_config.pipeline_stage_manager.get_stage_index(layers_per_stage)
                    fwd_fn = partial(mmdit_model_forward, shard_config=self.shard_config, stage_index=stage_index)
            self.append_or_create_method_replacement(
                description={
                    "forward": fwd_fn,
                },
                policy=policy,
                target_key=MMDiTModel,
            )

        if self.shard_config.enable_tensor_parallelism:
            mlp_hidden_size = int(self.model.config.hidden_size * self.model.config.mlp_ratio)
            assert (
                self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
                and mlp_hidden_size % self.shard_config.tensor_parallel_size == 0
            ), "num_heads and mlp hidden size should be divisible by tensor_parallel_size"
            for n in ["img", "txt"]:
                if self.model.config.fused_qkv:
                    policy[DoubleStreamBlock].sub_module_replacement.append(
                        SubModuleReplacementDescription(
                            suffix=f"{n}_attn.qkv",
                            target_module=FusedLinear1D_Col,
                            kwargs={
                                "split_sizes": [self.model.config.hidden_size] * 3,
                                "seq_parallel_mode": self.shard_config.sequence_parallelism_mode,
                            },
                        ),
                    )
                else:
                    policy[DoubleStreamBlock].sub_module_replacement.extend(
                        [
                            SubModuleReplacementDescription(
                                suffix=f"{n}_attn.q_proj",
                                target_module=Linear1D_Col,
                                kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                            ),
                            SubModuleReplacementDescription(
                                suffix=f"{n}_attn.k_proj",
                                target_module=Linear1D_Col,
                                kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                            ),
                            SubModuleReplacementDescription(
                                suffix=f"{n}_attn.v_proj",
                                target_module=Linear1D_Col,
                                kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                            ),
                        ]
                    )
                policy[DoubleStreamBlock].sub_module_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"{n}_attn.proj",
                            target_module=Linear1D_Row,
                            kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"{n}_mlp[0]",
                            target_module=Linear1D_Col,
                            kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"{n}_mlp[2]",
                            target_module=Linear1D_Row,
                            kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                        ),
                    ]
                )
            policy[DoubleStreamBlock].attribute_replacement["num_heads"] = (
                self.model.config.num_heads // self.shard_config.tensor_parallel_size
            )
            policy[SingleStreamBlock].attribute_replacement.update(
                {
                    "num_heads": self.model.config.num_heads // self.shard_config.tensor_parallel_size,
                    "hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    "mlp_hidden_dim": mlp_hidden_size // self.shard_config.tensor_parallel_size,
                }
            )
            if self.model.config.fused_qkv:
                policy[SingleStreamBlock].sub_module_replacement.append(
                    SubModuleReplacementDescription(
                        suffix="linear1",
                        target_module=FusedLinear1D_Col,
                        kwargs={
                            "split_sizes": [self.model.config.hidden_size] * 3 + [mlp_hidden_size],
                            "seq_parallel_mode": self.shard_config.sequence_parallelism_mode,
                        },
                    ),
                )
            else:
                policy[SingleStreamBlock].sub_module_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix="q_proj",
                            target_module=Linear1D_Col,
                            kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                        ),
                        SubModuleReplacementDescription(
                            suffix="k_proj",
                            target_module=Linear1D_Col,
                            kwargs={"seq_parallel_mode": self.shard_config.sequence_parallelism_mode},
                        ),
                        SubModuleReplacementDescription(
                            suffix="v_mlp",
                            target_module=FusedLinear1D_Col,
                            kwargs={
                                "split_sizes": [self.model.config.hidden_size] + [mlp_hidden_size],
                                "seq_parallel_mode": self.shard_config.sequence_parallelism_mode,
                            },
                        ),
                    ]
                )
            policy[SingleStreamBlock].sub_module_replacement.extend(
                [
                    SubModuleReplacementDescription(
                        suffix="linear2",
                        target_module=FusedLinear1D_Row,
                        kwargs={
                            "split_sizes": [self.model.config.hidden_size, mlp_hidden_size],
                            "seq_parallel_mode": self.shard_config.sequence_parallelism_mode,
                        },
                    ),
                ],
            )

        return policy

    def get_held_layers(self) -> List[nn.Module]:
        stage_manager = self.shard_config.pipeline_stage_manager
        assert stage_manager is not None, "Pipeline stage manager is not set"

        held_layers = []
        total_blocks = [*self.model.double_blocks, *self.model.single_blocks]
        if stage_manager.is_first_stage(ignore_chunk=stage_manager.is_interleave):
            held_layers.extend(
                [
                    self.model.pe_embedder,
                    self.model.img_in,
                    self.model.time_in,
                    self.model.vector_in,
                    self.model.guidance_in,
                    self.model.cond_in,
                    self.model.txt_in,
                ]
            )

        layers_per_stage = stage_manager.distribute_layers(len(total_blocks))
        if stage_manager.is_interleave:
            assert stage_manager.num_model_chunks is not None
            stage_indices = stage_manager.get_stage_index(layers_per_stage)
            for start_idx, end_idx in stage_indices:
                held_layers.extend(total_blocks[start_idx:end_idx])
        else:
            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
            held_layers.extend(total_blocks[start_idx:end_idx])
        if stage_manager.is_last_stage(ignore_chunk=stage_manager.is_interleave):
            held_layers.append(self.model.final_layer)
        return held_layers
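

# Minimal usage sketch (hypothetical driver code; assumes ColossalAI's ShardFormer
# API and an already-constructed MMDiTModel instance ``model``):
#
#   from colossalai.shardformer import ShardFormer
#
#   shard_config = ShardConfig(
#       enable_sequence_parallelism=True,
#       sequence_parallelism_mode="ring_attn",
#   )
#   shardformer = ShardFormer(shard_config=shard_config)
#   model, _ = shardformer.optimize(model, policy=MMDiTPolicy())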